Merge branch 'new-storage' into plugin

This commit is contained in:
SengokuCola
2025-05-16 21:14:16 +08:00
63 changed files with 2397 additions and 2008 deletions

View File

@@ -1,18 +1,18 @@
# 麦麦MaiCore-MaiMBot (编辑中) # 麦麦MaiCore-MaiMBot (编辑中)
<br /> <br />
<div style="text-align: center"> <div align="center">
![Python Version](https://img.shields.io/badge/Python-3.10+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者)
![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数)
![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数)
![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot)
![Python Version](https://img.shields.io/badge/Python-3.10+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者)
![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数)
![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数)
![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot)
</div> </div>
<p style="text-align: center"> <p align="center">
<a href="https://github.com/MaiM-with-u/MaiBot/"> <a href="https://github.com/MaiM-with-u/MaiBot/">
<img src="depends-data/maimai.png" alt="Logo" style="width: 200px"> <img src="depends-data/maimai.png" alt="Logo" style="width: 200px">
</a> </a>
@@ -21,8 +21,8 @@
画师略nd 画师略nd
</a> </a>
<h3 style="text-align: center">MaiBot(麦麦)</h3> <h3 align="center">MaiBot(麦麦)</h3>
<p style="text-align: center"> <p align="center">
一款专注于<strong> 群组聊天 </strong>的赛博网友 一款专注于<strong> 群组聊天 </strong>的赛博网友
<br /> <br />
<a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a> <a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a>

View File

@@ -1,6 +1,6 @@
from fastapi import HTTPException from fastapi import HTTPException
from rich.traceback import install from rich.traceback import install
from src.config.config import BotConfig from src.config.config import Config
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
import os import os
@@ -14,8 +14,8 @@ async def reload_config():
from src.config import config as config_module from src.config import config as config_module
logger.debug("正在重载配置文件...") logger.debug("正在重载配置文件...")
bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") bot_config_path = os.path.join(Config.get_config_dir(), "bot_config.toml")
config_module.global_config = BotConfig.load_config(config_path=bot_config_path) config_module.global_config = Config.load_config(config_path=bot_config_path)
logger.debug("配置文件重载成功") logger.debug("配置文件重载成功")
return {"status": "reloaded"} return {"status": "reloaded"}
except FileNotFoundError as e: except FileNotFoundError as e:

View File

@@ -5,12 +5,15 @@ import os
import random import random
import time import time
import traceback import traceback
from typing import Optional, Tuple from typing import Optional, Tuple, List, Any
from PIL import Image from PIL import Image
import io import io
import re import re
from ...common.database import db # from gradio_client import file
from ...common.database.database_model import Emoji
from ...common.database.database import db as peewee_db
from ...config.config import global_config from ...config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager from ..utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
@@ -51,7 +54,7 @@ class MaiEmoji:
self.is_deleted = False # 标记是否已被删除 self.is_deleted = False # 标记是否已被删除
self.format = "" self.format = ""
async def initialize_hash_format(self): async def initialize_hash_format(self) -> Optional[bool]:
"""从文件创建表情包实例, 计算哈希值和格式""" """从文件创建表情包实例, 计算哈希值和格式"""
try: try:
# 使用 full_path 检查文件是否存在 # 使用 full_path 检查文件是否存在
@@ -104,7 +107,7 @@ class MaiEmoji:
self.is_deleted = True self.is_deleted = True
return None return None
async def register_to_db(self): async def register_to_db(self) -> bool:
""" """
注册表情包 注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下 将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
@@ -143,22 +146,22 @@ class MaiEmoji:
# --- 数据库操作 --- # --- 数据库操作 ---
try: try:
# 准备数据库记录 for emoji collection # 准备数据库记录 for emoji collection
emoji_record = { emotion_str = ",".join(self.emotion) if self.emotion else ""
"filename": self.filename,
"path": self.path, # 存储目录路径
"full_path": self.full_path, # 存储完整文件路径
"embedding": self.embedding,
"description": self.description,
"emotion": self.emotion,
"hash": self.hash,
"format": self.format,
"timestamp": int(self.register_time),
"usage_count": self.usage_count,
"last_used_time": self.last_used_time,
}
# 使用upsert确保记录存在或被更新 Emoji.create(
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) hash=self.hash,
full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -166,14 +169,6 @@ class MaiEmoji:
except Exception as db_error: except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
# 可以考虑在这里尝试删除已移动的文件,避免残留
try:
if os.path.exists(self.full_path): # full_path 此时是目标路径
os.remove(self.full_path)
logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
except Exception as remove_error:
logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
return False return False
except Exception as e: except Exception as e:
@@ -181,7 +176,7 @@ class MaiEmoji:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
async def delete(self): async def delete(self) -> bool:
"""删除表情包 """删除表情包
删除表情包的文件和数据库记录 删除表情包的文件和数据库记录
@@ -201,10 +196,14 @@ class MaiEmoji:
# 文件删除失败,但仍然尝试删除数据库记录 # 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录 # 2. 删除数据库记录
result = db.emoji.delete_one({"hash": self.hash}) try:
deleted_in_db = result.deleted_count > 0 will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
if deleted_in_db: if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
# 3. 标记对象已被删除 # 3. 标记对象已被删除
self.is_deleted = True self.is_deleted = True
@@ -224,7 +223,7 @@ class MaiEmoji:
return False return False
def _emoji_objects_to_readable_list(emoji_objects): def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
"""将表情包对象列表转换为可读的字符串列表 """将表情包对象列表转换为可读的字符串列表
参数: 参数:
@@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects):
return emoji_info_list return emoji_info_list
def _to_emoji_objects(data): def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = [] emoji_objects = []
load_errors = 0 load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data) emoji_data_list = list(data)
for emoji_data in emoji_data_list: for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
full_path = emoji_data.get("full_path") full_path = emoji_data.full_path
if not full_path: if not full_path:
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") logger.warning(
f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
)
load_errors += 1 load_errors += 1
continue # 跳过缺少 full_path 的记录 continue
try: try:
# 使用 full_path 初始化 MaiEmoji 对象
emoji = MaiEmoji(full_path=full_path) emoji = MaiEmoji(full_path=full_path)
# 设置从数据库加载的属性 emoji.hash = emoji_data.emoji_hash
emoji.hash = emoji_data.get("hash", "")
# 如果 hash 为空,也跳过?取决于业务逻辑
if not emoji.hash: if not emoji.hash:
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
load_errors += 1 load_errors += 1
continue continue
emoji.description = emoji_data.get("description", "") emoji.description = emoji_data.description
emoji.emotion = emoji_data.get("emotion", []) # Deserialize emotion string from DB to list
emoji.usage_count = emoji_data.get("usage_count", 0) emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
# 优先使用 last_used_time否则用 timestamp最后用当前时间 emoji.usage_count = emoji_data.usage_count
last_used = emoji_data.get("last_used_time")
timestamp = emoji_data.get("timestamp")
emoji.last_used_time = (
last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
)
emoji.register_time = timestamp if timestamp is not None else time.time()
emoji.format = emoji_data.get("format", "") # 加载格式
# 不需要再手动设置 path 和 filename__init__ 会自动处理 db_last_used_time = emoji_data.last_used_time
db_register_time = emoji_data.register_time
# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
emoji.format = emoji_data.format
emoji_objects.append(emoji) emoji_objects.append(emoji)
except ValueError as ve: # 捕获 __init__ 可能的错误 except ValueError as ve:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1 load_errors += 1
except Exception as e: except Exception as e:
@@ -292,13 +292,13 @@ def _to_emoji_objects(data):
return emoji_objects, load_errors return emoji_objects, load_errors
def _ensure_emoji_dir(): def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在""" """确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
async def clear_temp_emoji(): async def clear_temp_emoji() -> None:
"""清理临时表情包 """清理临时表情包
清理/data/emoji和/data/image目录下的所有文件 清理/data/emoji和/data/image目录下的所有文件
当目录中文件数超过100时会全部删除 当目录中文件数超过100时会全部删除
@@ -320,7 +320,7 @@ async def clear_temp_emoji():
logger.success("[清理] 完成") logger.success("[清理] 完成")
async def clean_unused_emojis(emoji_dir, emoji_objects): async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
"""清理指定目录中未被 emoji_objects 追踪的表情包文件""" """清理指定目录中未被 emoji_objects 追踪的表情包文件"""
if not os.path.exists(emoji_dir): if not os.path.exists(emoji_dir):
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -360,74 +360,52 @@ async def clean_unused_emojis(emoji_dir, emoji_objects):
class EmojiManager: class EmojiManager:
_instance = None _instance = None
def __new__(cls): def __new__(cls) -> "EmojiManager":
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self) -> None:
self._initialized = None self._initialized = None
self._scan_task = None self._scan_task = None
self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest( self.llm_emotion_judge = LLMRequest(
model=global_config.llm_normal, max_tokens=600, request_type="emoji" model=global_config.model.normal, max_tokens=600, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度 ) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0 self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num self.emoji_num_max = global_config.emoji.max_reg_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型 self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器") logger.info("启动表情包管理器")
def initialize(self): def initialize(self) -> None:
"""初始化数据库连接和表情目录""" """初始化数据库连接和表情目录"""
if not self._initialized: peewee_db.connect(reuse_if_open=True)
try: if peewee_db.is_closed():
self._ensure_emoji_collection() raise RuntimeError("数据库连接失败")
_ensure_emoji_dir() _ensure_emoji_dir()
self._initialized = True Emoji.create_table(safe=True) # Ensures table exists
# 更新表情包数量
# 启动时执行一次完整性检查
# await self.check_emoji_file_integrity()
except Exception as e:
logger.exception(f"初始化表情管理器失败: {e}")
def _ensure_db(self): def _ensure_db(self) -> None:
"""确保数据库已初始化""" """确保数据库已初始化"""
if not self._initialized: if not self._initialized:
self.initialize() self.initialize()
if not self._initialized: if not self._initialized:
raise RuntimeError("EmojiManager not initialized") raise RuntimeError("EmojiManager not initialized")
@staticmethod def record_usage(self, emoji_hash: str) -> None:
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_hash: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
for emoji in self.emoji_objects: emoji_update.usage_count += 1
if emoji.hash == emoji_hash: emoji_update.last_used_time = time.time() # Update last used time
emoji.usage_count += 1 emoji_update.save() # Persist changes to DB
break except Emoji.DoesNotExist:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
@@ -447,7 +425,6 @@ class EmojiManager:
if not all_emojis: if not all_emojis:
logger.warning("内存中没有任何表情包对象") logger.warning("内存中没有任何表情包对象")
# 可以考虑再查一次数据库?或者依赖定期任务更新
return None return None
# 计算每个表情包与输入文本的最大情感相似度 # 计算每个表情包与输入文本的最大情感相似度
@@ -463,40 +440,38 @@ class EmojiManager:
# 计算与每个emotion标签的相似度取最大值 # 计算与每个emotion标签的相似度取最大值
max_similarity = 0 max_similarity = 0
best_matching_emotion = "" # 记录最匹配的 emotion 喵~ best_matching_emotion = ""
for emotion in emotions: for emotion in emotions:
# 使用编辑距离计算相似度 # 使用编辑距离计算相似度
distance = self._levenshtein_distance(text_emotion, emotion) distance = self._levenshtein_distance(text_emotion, emotion)
max_len = max(len(text_emotion), len(emotion)) max_len = max(len(text_emotion), len(emotion))
similarity = 1 - (distance / max_len if max_len > 0 else 0) similarity = 1 - (distance / max_len if max_len > 0 else 0)
if similarity > max_similarity: # 如果找到更相似的喵~ if similarity > max_similarity:
max_similarity = similarity max_similarity = similarity
best_matching_emotion = emotion # 就记下这个 emotion 喵~ best_matching_emotion = emotion
if best_matching_emotion: # 确保有匹配的情感才添加喵~ if best_matching_emotion:
emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~ emoji_similarities.append((emoji, max_similarity, best_matching_emotion))
# 按相似度降序排序 # 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前10个最相似的表情包 # 获取前10个最相似的表情包
top_emojis = ( top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
) # 改个名字,更清晰喵~
if not top_emojis: if not top_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
return None return None
# 从前几个中随机选择一个 # 从前几个中随机选择一个
selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~ selected_emoji, similarity, matched_emotion = random.choice(top_emojis)
# 更新使用次数 # 更新使用次数
self.record_usage(selected_emoji.hash) self.record_usage(selected_emoji.emoji_hash)
_time_end = time.time() _time_end = time.time()
logger.info( # 使用匹配到的 emotion 记录日志喵~ logger.info(
f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}" f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
) )
# 返回完整文件路径和描述 # 返回完整文件路径和描述
@@ -534,7 +509,7 @@ class EmojiManager:
return previous_row[-1] return previous_row[-1]
async def check_emoji_file_integrity(self): async def check_emoji_file_integrity(self) -> None:
"""检查表情包文件完整性 """检查表情包文件完整性
遍历self.emoji_objects中的所有对象检查文件是否存在 遍历self.emoji_objects中的所有对象检查文件是否存在
如果文件已被删除,则执行对象的删除方法并从列表中移除 如果文件已被删除,则执行对象的删除方法并从列表中移除
@@ -599,7 +574,7 @@ class EmojiManager:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def start_periodic_check_register(self): async def start_periodic_check_register(self) -> None:
"""定期检查表情包完整性和数量""" """定期检查表情包完整性和数量"""
await self.get_all_emoji_from_db() await self.get_all_emoji_from_db()
while True: while True:
@@ -613,18 +588,18 @@ class EmojiManager:
logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}") logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_DIR, exist_ok=True)
logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}") logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) await asyncio.sleep(global_config.emoji.check_interval * 60)
continue continue
# 检查目录是否为空 # 检查目录是否为空
files = os.listdir(EMOJI_DIR) files = os.listdir(EMOJI_DIR)
if not files: if not files:
logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}") logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) await asyncio.sleep(global_config.emoji.check_interval * 60)
continue continue
# 检查是否需要处理表情包(数量超过最大值或不足) # 检查是否需要处理表情包(数量超过最大值或不足)
if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or ( if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max self.emoji_num < self.emoji_num_max
): ):
try: try:
@@ -651,15 +626,16 @@ class EmojiManager:
except Exception as e: except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) await asyncio.sleep(global_config.emoji.check_interval * 60)
async def get_all_emoji_from_db(self): async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects""" """获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try: try:
self._ensure_db() self._ensure_db()
logger.info("[数据库] 开始加载所有表情包记录...") logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
# 更新内存中的列表和数量 # 更新内存中的列表和数量
self.emoji_objects = emoji_objects self.emoji_objects = emoji_objects
@@ -674,7 +650,7 @@ class EmojiManager:
self.emoji_objects = [] # 加载失败则清空列表 self.emoji_objects = [] # 加载失败则清空列表
self.emoji_num = 0 self.emoji_num = 0
async def get_emoji_from_db(self, emoji_hash=None): async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
参数: 参数:
@@ -686,15 +662,16 @@ class EmojiManager:
try: try:
self._ensure_db() self._ensure_db()
query = {}
if emoji_hash: if emoji_hash:
query = {"hash": emoji_hash} query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
else: else:
logger.warning( logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。" "[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
) )
query = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
if load_errors > 0: if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -705,7 +682,7 @@ class EmojiManager:
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
return [] return []
async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]: async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
"""从内存中的 emoji_objects 列表获取表情包 """从内存中的 emoji_objects 列表获取表情包
参数: 参数:
@@ -758,7 +735,7 @@ class EmojiManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
async def replace_a_emoji(self, new_emoji: MaiEmoji): async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
"""替换一个表情包 """替换一个表情包
Args: Args:
@@ -788,7 +765,7 @@ class EmojiManager:
# 构建提示词 # 构建提示词
prompt = ( prompt = (
f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})" f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n" f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
f"新表情包信息:\n" f"新表情包信息:\n"
f"描述: {new_emoji.description}\n\n" f"描述: {new_emoji.description}\n\n"
@@ -819,7 +796,7 @@ class EmojiManager:
# 删除选定的表情包 # 删除选定的表情包
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
delete_success = await self.delete_emoji(emoji_to_delete.hash) delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash)
if delete_success: if delete_success:
# 修复:等待异步注册完成 # 修复:等待异步注册完成
@@ -847,7 +824,7 @@ class EmojiManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]: async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
"""获取表情包描述和情感列表 """获取表情包描述和情感列表
Args: Args:
@@ -871,10 +848,10 @@ class EmojiManager:
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# 审核表情包 # 审核表情包
if global_config.EMOJI_CHECK: if global_config.emoji.content_filtration:
prompt = f''' prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下: 这是一个表情包,请对这个表情包进行审核,标准如下:
1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求 1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图 3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字 4. 不要出现5个以上文字

View File

@@ -76,9 +76,10 @@ def init_prompt():
class DefaultExpressor: class DefaultExpressor:
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.log_prefix = "expressor" self.log_prefix = "expressor"
# TODO: API-Adapter修改标记
self.express_model = LLMRequest( self.express_model = LLMRequest(
model=global_config.llm_normal, model=global_config.model.normal,
temperature=global_config.llm_normal["temp"], temperature=global_config.model.normal["temp"],
max_tokens=256, max_tokens=256,
request_type="response_heartflow", request_type="response_heartflow",
) )
@@ -102,8 +103,8 @@ class DefaultExpressor:
messageinfo = anchor_message.message_info messageinfo = anchor_message.message_info
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id) thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.bot.qq_account,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.bot.nickname,
platform=messageinfo.platform, platform=messageinfo.platform,
) )
# logger.debug(f"创建思考消息:{anchor_message}") # logger.debug(f"创建思考消息:{anchor_message}")
@@ -192,7 +193,7 @@ class DefaultExpressor:
try: try:
# 1. 获取情绪影响因子并调整模型温度 # 1. 获取情绪影响因子并调整模型温度
arousal_multiplier = mood_manager.get_arousal_multiplier() arousal_multiplier = mood_manager.get_arousal_multiplier()
current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
self.express_model.params["temperature"] = current_temp # 动态调整温度 self.express_model.params["temperature"] = current_temp # 动态调整温度
# 2. 获取信息捕捉器 # 2. 获取信息捕捉器
@@ -231,6 +232,7 @@ class DefaultExpressor:
try: try:
with Timer("LLM生成", {}): # 内部计时器,可选保留 with Timer("LLM生成", {}): # 内部计时器,可选保留
# TODO: API-Adapter修改标记
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
content, reasoning_content, model_name = await self.express_model.generate_response(prompt) content, reasoning_content, model_name = await self.express_model.generate_response(prompt)
@@ -482,8 +484,8 @@ class DefaultExpressor:
"""构建单个发送消息""" """构建单个发送消息"""
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.bot.qq_account,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform, platform=self.chat_stream.platform,
) )

View File

@@ -77,8 +77,9 @@ def init_prompt() -> None:
class ExpressionLearner: class ExpressionLearner:
def __init__(self) -> None: def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest( self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.llm_normal, model=global_config.model.normal,
temperature=0.1, temperature=0.1,
max_tokens=256, max_tokens=256,
request_type="response_heartflow", request_type="response_heartflow",
@@ -289,7 +290,7 @@ class ExpressionLearner:
# 构建prompt # 构建prompt
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"personality_expression_prompt", "personality_expression_prompt",
personality=global_config.expression_style, personality=global_config.personality.expression_style,
) )
# logger.info(f"个性表达方式提取prompt: {prompt}") # logger.info(f"个性表达方式提取prompt: {prompt}")

View File

@@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
Returns: Returns:
bool: 是否包含过滤词 bool: 是否包含过滤词
""" """
for word in global_config.ban_words: for word in global_config.chat.ban_words:
if word in text: if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊" chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
Returns: Returns:
bool: 是否匹配过滤正则 bool: 是否匹配过滤正则
""" """
for pattern in global_config.ban_msgs_regex: for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text): if pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊" chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")

View File

@@ -13,6 +13,9 @@ from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.knowledge.knowledge_lib import qa_manager
import random import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt") logger = get_logger("prompt")
@@ -45,7 +48,7 @@ def init_prompt():
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1} 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger} 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)只输出回复内容。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情at或 @等 )。只输出回复内容。
{moderation_prompt} {moderation_prompt}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""", 不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""",
"reasoning_prompt_main", "reasoning_prompt_main",
@@ -110,7 +113,7 @@ class PromptBuilder:
who_chat_in_group = get_recent_group_speaker( who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id, chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None, (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.observation_context_size, limit=global_config.chat.observation_context_size,
) )
elif chat_stream.user_info: elif chat_stream.user_info:
who_chat_in_group.append( who_chat_in_group.append(
@@ -158,7 +161,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
limit=global_config.observation_context_size, limit=global_config.chat.observation_context_size,
) )
chat_talking_prompt = await build_readable_messages( chat_talking_prompt = await build_readable_messages(
message_list_before_now, message_list_before_now,
@@ -170,18 +173,15 @@ class PromptBuilder:
# 关键词检测与反应 # 关键词检测与反应
keywords_reaction_prompt = "" keywords_reaction_prompt = ""
for rule in global_config.keywords_reaction_rules: for rule in global_config.keyword_reaction.rules:
if rule.get("enable", False): if rule.enable:
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): if any(keyword in message_txt for keyword in rule.keywords):
logger.info( logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}")
f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" keywords_reaction_prompt += f"{rule.reaction}"
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
else: else:
for pattern in rule.get("regex", []): for pattern in rule.regex:
result = pattern.search(message_txt) if result := pattern.search(message_txt):
if result: reaction = rule.reaction
reaction = rule.get("reaction", "")
for name, content in result.groupdict().items(): for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content) reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}") logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
@@ -227,8 +227,8 @@ class PromptBuilder:
chat_target_2=chat_target_2, chat_target_2=chat_target_2,
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt, message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES), bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality, prompt_personality=prompt_personality,
mood_prompt=mood_prompt, mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen, reply_style1=reply_style1_chosen,
@@ -249,8 +249,8 @@ class PromptBuilder:
prompt_info=prompt_info, prompt_info=prompt_info,
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt, message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES), bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality, prompt_personality=prompt_personality,
mood_prompt=mood_prompt, mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen, reply_style1=reply_style1_chosen,
@@ -269,30 +269,6 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法 # 1. 先从LLM获取主题类似于记忆系统的做法
topics = [] topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息 # 如果无法提取到主题,直接使用整个消息
if not topics: if not topics:
@@ -402,8 +378,6 @@ class PromptBuilder:
for _i, result in enumerate(results, 1): for _i, result in enumerate(results, 1):
_similarity = result["similarity"] _similarity = result["similarity"]
content = result["content"].strip() content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n" related_info += f"{content}\n"
related_info += "\n" related_info += "\n"
@@ -432,14 +406,14 @@ class PromptBuilder:
return related_info return related_info
else: else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索") logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info return related_info
except Exception as e: except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}") logger.error(f"获取知识库内容时发生异常: {str(e)}")
try: try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old related_info += knowledge_from_old
logger.debug( logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -455,70 +429,70 @@ class PromptBuilder:
) -> Union[str, list]: ) -> Union[str, list]:
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline)) results_with_similarity = []
logger.debug(f"知识库查询结果数量: {len(results)}") try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not results: if not all_knowledges:
return [] if return_raw else ""
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else []
if return_raw:
return limited_results
else:
return "\n".join(str(result["content"]) for result in limited_results)
except Exception as e:
logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else [] return "" if not return_raw else []
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
init_prompt() init_prompt()
prompt_builder = PromptBuilder() prompt_builder = PromptBuilder()

View File

@@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor):
def __init__(self): def __init__(self):
"""初始化观察处理器""" """初始化观察处理器"""
super().__init__() super().__init__()
# TODO: API-Adapter修改标记
self.llm_summary = LLMRequest( self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation" model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
) )
async def process_info( async def process_info(
@@ -110,12 +111,12 @@ class ChattingInfoProcessor(BaseProcessor):
"created_at": datetime.now().timestamp(), "created_at": datetime.now().timestamp(),
} }
obs.mid_memorys.append(mid_memory) obs.mid_memories.append(mid_memory)
if len(obs.mid_memorys) > obs.max_mid_memory_len: if len(obs.mid_memories) > obs.max_mid_memory_len:
obs.mid_memorys.pop(0) # 移除最旧的 obs.mid_memories.pop(0) # 移除最旧的
mid_memory_str = "之前聊天的内容概述是:\n" mid_memory_str = "之前聊天的内容概述是:\n"
for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分 for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60) time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
mid_memory_str += ( mid_memory_str += (
f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}){mid_memory_item['theme']}\n" f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}){mid_memory_item['theme']}\n"

View File

@@ -71,8 +71,8 @@ class MindProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.llm_sub_heartflow, model=global_config.model.sub_heartflow,
temperature=global_config.llm_sub_heartflow["temp"], temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800, max_tokens=800,
request_type="sub_heart_flow", request_type="sub_heart_flow",
) )

View File

@@ -49,7 +49,7 @@ class ToolProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id self.subheartflow_id = subheartflow_id
self.log_prefix = f"[{subheartflow_id}:ToolExecutor] " self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.llm_tool_use, model=global_config.model.tool_use,
max_tokens=500, max_tokens=500,
request_type="tool_execution", request_type="tool_execution",
) )

View File

@@ -34,8 +34,9 @@ def init_prompt():
class MemoryActivator: class MemoryActivator:
def __init__(self): def __init__(self):
# TODO: API-Adapter修改标记
self.summary_model = LLMRequest( self.summary_model = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation" model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
) )
self.running_memory = [] self.running_memory = []

View File

@@ -35,8 +35,9 @@ class Heartflow:
self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state) self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)
# LLM模型配置 # LLM模型配置
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
) )
# 外部依赖模块 # 外部依赖模块

View File

@@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1
class InterestChatting: class InterestChatting:
def __init__( def __init__(
self, self,
decay_rate=global_config.default_decay_rate_per_second, decay_rate=global_config.focus_chat.default_decay_rate_per_second,
max_interest=MAX_INTEREST, max_interest=MAX_INTEREST,
trigger_threshold=global_config.reply_trigger_threshold, trigger_threshold=global_config.focus_chat.reply_trigger_threshold,
max_probability=MAX_REPLY_PROBABILITY, max_probability=MAX_REPLY_PROBABILITY,
): ):
# 基础属性初始化 # 基础属性初始化

View File

@@ -18,19 +18,14 @@ enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
prevent_offline_state = True prevent_offline_state = True
# 目前默认不启用OFFLINE状态 # 目前默认不启用OFFLINE状态
# 不同状态下普通聊天的最大消息数 MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2)
base_normal_chat_num = global_config.base_normal_chat_num MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num
base_focused_chat_num = global_config.base_focused_chat_num MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1
MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2)
MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num
MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1
# 不同状态下专注聊天的最大消息数 # 不同状态下专注聊天的最大消息数
MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2) MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2)
MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num
MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2 MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2
# -- 状态定义 -- # -- 状态定义 --

View File

@@ -55,19 +55,20 @@ class ChattingObservation(Observation):
self.talking_message = [] self.talking_message = []
self.talking_message_str = "" self.talking_message_str = ""
self.talking_message_str_truncate = "" self.talking_message_str_truncate = ""
self.name = global_config.BOT_NICKNAME self.name = global_config.bot.nickname
self.nick_name = global_config.BOT_ALIAS_NAMES self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.observation_context_size self.max_now_obs_len = global_config.chat.observation_context_size
self.overlap_len = global_config.compressed_length self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memorys = [] self.mid_memories = []
self.max_mid_memory_len = global_config.compress_length_limit self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = "" self.mid_memory_info = ""
self.person_list = [] self.person_list = []
self.oldest_messages = [] self.oldest_messages = []
self.oldest_messages_str = "" self.oldest_messages_str = ""
self.compressor_prompt = "" self.compressor_prompt = ""
# TODO: API-Adapter修改标记
self.llm_summary = LLMRequest( self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation" model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
) )
async def initialize(self): async def initialize(self):
@@ -85,7 +86,7 @@ class ChattingObservation(Observation):
for id in ids: for id in ids:
print(f"id{id}") print(f"id{id}")
try: try:
for mid_memory in self.mid_memorys: for mid_memory in self.mid_memories:
if mid_memory["id"] == id: if mid_memory["id"] == id:
mid_memory_by_id = mid_memory mid_memory_by_id = mid_memory
msg_str = "" msg_str = ""
@@ -103,7 +104,7 @@ class ChattingObservation(Observation):
else: else:
mid_memory_str = "之前的聊天内容:\n" mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memorys: for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n" mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str

View File

@@ -76,8 +76,9 @@ class SubHeartflowManager:
# 为 LLM 状态评估创建一个 LLMRequest 实例 # 为 LLM 状态评估创建一个 LLMRequest 实例
# 使用与 Heartflow 相同的模型和参数 # 使用与 Heartflow 相同的模型和参数
# TODO: API-Adapter修改标记
self.llm_state_evaluator = LLMRequest( self.llm_state_evaluator = LLMRequest(
model=global_config.llm_heartflow, # 与 Heartflow 一致 model=global_config.model.heartflow, # 与 Heartflow 一致
temperature=0.6, # 与 Heartflow 一致 temperature=0.6, # 与 Heartflow 一致
max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
request_type="subheartflow_state_eval", # 保留特定的请求类型 request_type="subheartflow_state_eval", # 保留特定的请求类型
@@ -278,7 +279,7 @@ class SubHeartflowManager:
focused_limit = current_state.get_focused_chat_max_num() focused_limit = current_state.get_focused_chat_max_num()
# --- 新增:检查是否允许进入 FOCUS 模式 --- # # --- 新增:检查是否允许进入 FOCUS 模式 --- #
if not global_config.allow_focus_mode: if not global_config.chat.allow_focus_mode:
if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏 if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)") logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
return # 如果不允许,直接返回 return # 如果不允许,直接返回
@@ -766,7 +767,7 @@ class SubHeartflowManager:
focused_limit = current_mai_state.get_focused_chat_max_num() focused_limit = current_mai_state.get_focused_chat_max_num()
# --- 检查是否允许 FOCUS 模式 --- # # --- 检查是否允许 FOCUS 模式 --- #
if not global_config.allow_focus_mode: if not global_config.chat.allow_focus_mode:
# Log less frequently to avoid spam # Log less frequently to avoid spam
# if int(time.time()) % 60 == 0: # if int(time.time()) % 60 == 0:
# logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态") # logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态")

View File

@@ -10,7 +10,7 @@ import jieba
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from ...common.database import db from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
@@ -19,9 +19,10 @@ from ..utils.chat_message_builder import (
build_readable_messages, build_readable_messages,
) # 导入 build_readable_messages ) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable from ..utils.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install from rich.traceback import install
from ...config.config import global_config
install(extra_lines=3) install(extra_lines=3)
@@ -195,18 +196,16 @@ class Hippocampus:
self.llm_summary = None self.llm_summary = None
self.entorhinal_cortex = None self.entorhinal_cortex = None
self.parahippocampal_gyrus = None self.parahippocampal_gyrus = None
self.config = None
def initialize(self, global_config): def initialize(self):
# 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法
self.config = MemoryConfig.from_global_config(global_config)
# 初始化子组件 # 初始化子组件
self.entorhinal_cortex = EntorhinalCortex(self) self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图 # 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db() self.entorhinal_cortex.sync_memory_from_db()
self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory") # TODO: API-Adapter修改标记
self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory") self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory")
def get_all_node_names(self) -> list: def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
@@ -792,7 +791,6 @@ class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus): def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
def get_memory_sample(self): def get_memory_sample(self):
"""从数据库获取记忆样本""" """从数据库获取记忆样本"""
@@ -801,13 +799,13 @@ class EntorhinalCortex:
# 创建双峰分布的记忆调度器 # 创建双峰分布的记忆调度器
sample_scheduler = MemoryBuildScheduler( sample_scheduler = MemoryBuildScheduler(
n_hours1=self.config.memory_build_distribution[0], n_hours1=global_config.memory.memory_build_distribution[0],
std_hours1=self.config.memory_build_distribution[1], std_hours1=global_config.memory.memory_build_distribution[1],
weight1=self.config.memory_build_distribution[2], weight1=global_config.memory.memory_build_distribution[2],
n_hours2=self.config.memory_build_distribution[3], n_hours2=global_config.memory.memory_build_distribution[3],
std_hours2=self.config.memory_build_distribution[4], std_hours2=global_config.memory.memory_build_distribution[4],
weight2=self.config.memory_build_distribution[5], weight2=global_config.memory.memory_build_distribution[5],
total_samples=self.config.build_memory_sample_num, total_samples=global_config.memory.memory_build_sample_num,
) )
timestamps = sample_scheduler.get_timestamp_array() timestamps = sample_scheduler.get_timestamp_array()
@@ -818,7 +816,7 @@ class EntorhinalCortex:
for timestamp in timestamps: for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet # 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet( messages = self.random_get_msg_snippet(
timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
) )
if messages: if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -1099,7 +1097,6 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus): def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
async def memory_compress(self, messages: list, compress_rate=0.1): async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。 """压缩和总结消息内容,生成记忆主题和摘要。
@@ -1159,7 +1156,7 @@ class ParahippocampalGyrus:
# 3. 过滤掉包含禁用关键词的topic # 3. 过滤掉包含禁用关键词的topic
filtered_topics = [ filtered_topics = [
topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words) topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
] ]
logger.debug(f"过滤后话题: {filtered_topics}") logger.debug(f"过滤后话题: {filtered_topics}")
@@ -1222,7 +1219,7 @@ class ParahippocampalGyrus:
bar = "" * filled_length + "-" * (bar_length - filled_length) bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = self.config.memory_compress_rate compress_rate = global_config.memory.memory_compress_rate
try: try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e: except Exception as e:
@@ -1322,7 +1319,7 @@ class ParahippocampalGyrus:
edge_data = self.memory_graph.G[source][target] edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get("last_modified") last_modified = edge_data.get("last_modified")
if current_time - last_modified > 3600 * self.config.memory_forget_time: if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
current_strength = edge_data.get("strength", 1) current_strength = edge_data.get("strength", 1)
new_strength = current_strength - 1 new_strength = current_strength - 1
@@ -1430,8 +1427,8 @@ class ParahippocampalGyrus:
async def operation_consolidate_memory(self): async def operation_consolidate_memory(self):
"""整合记忆:合并节点内相似的记忆项""" """整合记忆:合并节点内相似的记忆项"""
start_time = time.time() start_time = time.time()
percentage = self.config.consolidate_memory_percentage percentage = global_config.memory.consolidate_memory_percentage
similarity_threshold = self.config.consolidation_similarity_threshold similarity_threshold = global_config.memory.consolidation_similarity_threshold
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}") logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
# 获取所有至少有2条记忆项的节点 # 获取所有至少有2条记忆项的节点
@@ -1544,7 +1541,6 @@ class ParahippocampalGyrus:
class HippocampusManager: class HippocampusManager:
_instance = None _instance = None
_hippocampus = None _hippocampus = None
_global_config = None
_initialized = False _initialized = False
@classmethod @classmethod
@@ -1559,19 +1555,15 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return cls._hippocampus return cls._hippocampus
def initialize(self, global_config): def initialize(self):
"""初始化海马体实例""" """初始化海马体实例"""
if self._initialized: if self._initialized:
return self._hippocampus return self._hippocampus
self._global_config = global_config
self._hippocampus = Hippocampus() self._hippocampus = Hippocampus()
self._hippocampus.initialize(global_config) self._hippocampus.initialize()
self._initialized = True self._initialized = True
# 输出记忆系统参数信息
config = self._hippocampus.config
# 输出记忆图统计信息 # 输出记忆图统计信息
memory_graph = self._hippocampus.memory_graph.G memory_graph = self._hippocampus.memory_graph.G
node_count = len(memory_graph.nodes()) node_count = len(memory_graph.nodes())
@@ -1579,9 +1571,9 @@ class HippocampusManager:
logger.success(f"""-------------------------------- logger.success(f"""--------------------------------
记忆系统参数配置: 记忆系统参数配置:
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate} 构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
记忆构建分布: {config.memory_build_distribution} 记忆构建分布: {global_config.memory.memory_build_distribution}
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后 遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count} 记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""") # noqa: E501 --------------------------------""") # noqa: E501

View File

@@ -7,7 +7,6 @@ import os
# 添加项目根目录到系统路径 # 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -19,7 +18,7 @@ async def test_memory_system():
# 初始化记忆系统 # 初始化记忆系统
print("开始初始化记忆系统...") print("开始初始化记忆系统...")
hippocampus_manager = HippocampusManager.get_instance() hippocampus_manager = HippocampusManager.get_instance()
hippocampus_manager.initialize(global_config=global_config) hippocampus_manager.initialize()
print("记忆系统初始化完成") print("记忆系统初始化完成")
# 测试记忆构建 # 测试记忆构建

View File

@@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
from src.common.logger import get_module_logger # noqa E402 from src.common.logger import get_module_logger # noqa E402
from src.common.database import db # noqa E402 from common.database.database import db # noqa E402
logger = get_module_logger("mem_alter") logger = get_module_logger("mem_alter")
console = Console() console = Console()

View File

@@ -1,48 +0,0 @@
from dataclasses import dataclass
from typing import List
@dataclass
class MemoryConfig:
"""记忆系统配置类"""
# 记忆构建相关配置
memory_build_distribution: List[float] # 记忆构建的时间分布参数
build_memory_sample_num: int # 每次构建记忆的样本数量
build_memory_sample_length: int # 每个样本的消息长度
memory_compress_rate: float # 记忆压缩率
# 记忆遗忘相关配置
memory_forget_time: int # 记忆遗忘时间(小时)
# 记忆过滤相关配置
memory_ban_words: List[str] # 记忆过滤词列表
# 新增:记忆整合相关配置
consolidation_similarity_threshold: float # 相似度阈值
consolidate_memory_percentage: float # 检查节点比例
consolidate_memory_interval: int # 记忆整合间隔
llm_topic_judge: str # 话题判断模型
llm_summary: str # 话题总结模型
@classmethod
def from_global_config(cls, global_config):
"""从全局配置创建记忆系统配置"""
# 使用 getattr 提供默认值,防止全局配置缺少这些项
return cls(
memory_build_distribution=getattr(
global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5)
), # 添加默认值
build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5),
build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30),
memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1),
memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7),
memory_ban_words=getattr(global_config, "memory_ban_words", []),
# 新增加载整合配置,并提供默认值
consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7),
consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01),
consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000),
llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名
llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名
)

View File

@@ -41,7 +41,7 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id) chat_id = str(message.chat_stream.stream_id)
private_name = str(message.message_info.user_info.user_nickname) private_name = str(message.message_info.user_info.user_nickname)
if global_config.enable_pfc_chatting: if global_config.experimental.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id, private_name) await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e: except Exception as e:
@@ -78,19 +78,19 @@ class ChatBot:
userinfo = message.message_info.user_info userinfo = message.message_info.user_info
# 用户黑名单拦截 # 用户黑名单拦截
if userinfo.user_id in global_config.ban_user_id: if userinfo.user_id in global_config.chat_target.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复") logger.debug(f"用户{userinfo.user_id}被禁止回复")
return return
if groupinfo is None: if groupinfo is None:
logger.trace("检测到私聊消息,检查") logger.trace("检测到私聊消息,检查")
# 好友黑名单拦截 # 好友黑名单拦截
if userinfo.user_id not in global_config.talk_allowed_private: if userinfo.user_id not in global_config.experimental.talk_allowed_private:
logger.debug(f"用户{userinfo.user_id}没有私聊权限") logger.debug(f"用户{userinfo.user_id}没有私聊权限")
return return
# 群聊黑名单拦截 # 群聊黑名单拦截
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups: if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复") logger.trace(f"{groupinfo.group_id}被禁止回复")
return return
@@ -112,7 +112,7 @@ class ChatBot:
if groupinfo is None: if groupinfo is None:
logger.trace("检测到私聊消息") logger.trace("检测到私聊消息")
# 是否在配置信息中开启私聊模式 # 是否在配置信息中开启私聊模式
if global_config.enable_friend_chat: if global_config.experimental.enable_friend_chat:
logger.trace("私聊模式已启用") logger.trace("私聊模式已启用")
# 是否进入PFC # 是否进入PFC
if global_config.enable_pfc_chatting: if global_config.enable_pfc_chatting:

View File

@@ -5,7 +5,8 @@ import copy
from typing import Dict, Optional from typing import Dict, Optional
from ...common.database import db from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
@@ -82,7 +83,13 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self._ensure_collection() try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True self._initialized = True
# 在事件循环中启动初始化 # 在事件循环中启动初始化
# asyncio.create_task(self._initialize()) # asyncio.create_task(self._initialize())
@@ -107,15 +114,6 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}") logger.error(f"聊天流自动保存失败: {str(e)}")
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
@staticmethod @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
@@ -151,16 +149,43 @@ class ChatManager:
stream = self.streams[stream_id] stream = self.streams[stream_id]
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time() stream.update_active_time()
stream = copy.deepcopy(stream) stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
data = db.chat_streams.find_one({"stream_id": stream_id}) def _db_find_stream_sync(s_id: str):
if data: return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
stream = ChatStream.from_dict(data)
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
@@ -175,7 +200,7 @@ class ChatManager:
group_info=group_info, group_info=group_info,
) )
except Exception as e: except Exception as e:
logger.error(f"创建聊天流失败: {e}") logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e raise e
# 保存到内存和数据库 # 保存到内存和数据库
@@ -205,15 +230,38 @@ class ChatManager:
elif stream.user_info and stream.user_info.user_nickname: elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊" return f"{stream.user_info.user_nickname}的私聊"
else: else:
# 如果没有群名或用户昵称,返回 None 或其他默认值
return None return None
@staticmethod @staticmethod
async def _save_stream(stream: ChatStream): async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) stream_data_dict = stream.to_dict()
stream.saved = True
def _db_save_stream_sync(s_data_dict: dict):
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self): async def _save_all_streams(self):
"""保存所有聊天流""" """保存所有聊天流"""
@@ -222,10 +270,44 @@ class ChatManager:
async def load_all_streams(self): async def load_all_streams(self):
"""从数据库加载所有聊天流""" """从数据库加载所有聊天流"""
all_streams = db.chat_streams.find({})
for data in all_streams: def _db_load_all_streams_sync():
stream = ChatStream.from_dict(data) loaded_streams_data = []
self.streams[stream.stream_id] = stream for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
# 创建全局单例 # 创建全局单例

View File

@@ -38,7 +38,7 @@ class MessageBuffer:
async def start_caching_messages(self, message: MessageRecv): async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲""" """添加消息,启动缓冲"""
if not global_config.message_buffer: if not global_config.chat.message_buffer:
person_id = person_info_manager.get_person_id( person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id message.message_info.user_info.platform, message.message_info.user_info.user_id
) )
@@ -107,7 +107,7 @@ class MessageBuffer:
async def query_buffer_result(self, message: MessageRecv) -> bool: async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理""" """查询缓冲结果,并清理"""
if not global_config.message_buffer: if not global_config.chat.message_buffer:
return True return True
person_id_ = self.get_person_id_( person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info

View File

@@ -279,7 +279,7 @@ class MessageManager:
) )
# 检查是否超时 # 检查是否超时
if thinking_time > global_config.thinking_timeout: if thinking_time > global_config.normal_chat.thinking_timeout:
logger.warning( logger.warning(
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}" f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
) )

View File

@@ -1,9 +1,10 @@
import re import re
from typing import Union from typing import Union
from ...common.database import db # from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
logger = get_module_logger("message_storage") logger = get_module_logger("message_storage")
@@ -29,42 +30,66 @@ class MessageStorage:
else: else:
filtered_detailed_plain_text = "" filtered_detailed_plain_text = ""
message_data = { chat_info_dict = chat_stream.to_dict()
"message_id": message.message_info.message_id, user_info_dict = message.message_info.user_info.to_dict()
"time": message.message_info.time,
"chat_id": chat_stream.stream_id, # message_id 现在是 TextField直接使用字符串值
"chat_info": chat_stream.to_dict(), msg_id = message.message_info.message_id
"user_info": message.message_info.user_info.to_dict(),
# 使用过滤后的文本 # 安全地获取 group_info, 如果为 None 则视为空字典
"processed_plain_text": filtered_processed_plain_text, group_info_from_chat = chat_info_dict.get("group_info") or {}
"detailed_plain_text": filtered_detailed_plain_text, # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
"memorized_times": message.memorized_times, user_info_from_chat = chat_info_dict.get("user_info") or {}
}
db.messages.insert_one(message_data) Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
detailed_plain_text=filtered_detailed_plain_text,
memorized_times=message.memorized_times,
)
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")
@staticmethod @staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None: async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库""" """存储撤回消息到数据库"""
if "recalled_messages" not in db.list_collection_names(): # Table creation is handled by initialize_database in database_model.py
db.create_collection("recalled_messages") try:
else: RecalledMessages.create(
try: message_id=message_id,
message_data = { time=float(time), # Assuming time is a string representing a float timestamp
"message_id": message_id, stream_id=chat_stream.stream_id,
"time": time, )
"stream_id": chat_stream.stream_id, except Exception:
} logger.exception("存储撤回消息失败")
db.recalled_messages.insert_one(message_data)
except Exception:
logger.exception("存储撤回消息失败")
@staticmethod @staticmethod
async def remove_recalled_message(time: str) -> None: async def remove_recalled_message(time: str) -> None:
"""删除撤回消息""" """删除撤回消息"""
try: try:
db.recalled_messages.delete_many({"time": {"$lt": time - 300}}) # Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
except Exception: except Exception:
logger.exception("删除撤回消息失败") logger.exception("删除撤回消息失败")

View File

@@ -12,7 +12,8 @@ import base64
from PIL import Image from PIL import Image
import io import io
import os import os
from ...common.database import db from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config from ...config.config import global_config
from rich.traceback import install from rich.traceback import install
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}" f"{image_base64[:10]}...{image_base64[-10:]}"
) )
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload return payload
@@ -111,8 +110,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs): def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值 # 将大写的配置键转换为小写并从config中获取实际值
try: try:
self.api_key = os.environ[model["key"]] self.api_key = os.environ[f"{model['provider']}_KEY"]
self.base_url = os.environ[model["base_url"]] self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
except AttributeError as e: except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}") logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
@@ -134,13 +133,11 @@ class LLMRequest:
def _init_database(): def _init_database():
"""初始化数据库集合""" """初始化数据库集合"""
try: try:
# 创建llm_usage集合的索引 # 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.llm_usage.create_index([("timestamp", 1)]) db.create_tables([LLMUsage], safe=True)
db.llm_usage.create_index([("model_name", 1)]) logger.debug("LLMUsage 表已初始化/确保存在。")
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
except Exception as e: except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}") logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage( def _record_usage(
self, self,
@@ -165,19 +162,19 @@ class LLMRequest:
request_type = self.request_type request_type = self.request_type
try: try:
usage_data = { # 使用 Peewee 模型创建记录
"model_name": self.model_name, LLMUsage.create(
"user_id": user_id, model_name=self.model_name,
"request_type": request_type, user_id=user_id,
"endpoint": endpoint, request_type=request_type,
"prompt_tokens": prompt_tokens, endpoint=endpoint,
"completion_tokens": completion_tokens, prompt_tokens=prompt_tokens,
"total_tokens": total_tokens, completion_tokens=completion_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens), total_tokens=total_tokens,
"status": "success", cost=self._calculate_cost(prompt_tokens, completion_tokens),
"timestamp": datetime.now(), status="success",
} timestamp=datetime.now(), # Peewee 会处理 DateTimeField
db.llm_usage.insert_one(usage_data) )
logger.trace( logger.trace(
f"Token使用情况 - 模型: {self.model_name}, " f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, " f"用户: {user_id}, 类型: {request_type}, "
@@ -500,11 +497,11 @@ class LLMRequest:
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}") logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新 # 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name: if global_config.model.normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name global_config.model.normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name: if global_config.model.reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name global_config.model.reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
if payload and "model" in payload: if payload and "model" in payload:
@@ -636,7 +633,7 @@ class LLMRequest:
**params_copy, **params_copy,
} }
if "max_tokens" not in payload and "max_completion_tokens" not in payload: if "max_tokens" not in payload and "max_completion_tokens" not in payload:
payload["max_tokens"] = global_config.model_max_output_length payload["max_tokens"] = global_config.model.model_max_output_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens") payload["max_completion_tokens"] = payload.pop("max_tokens")

View File

@@ -73,8 +73,8 @@ class NormalChat:
messageinfo = message.message_info messageinfo = message.message_info
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.bot.qq_account,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.bot.nickname,
platform=messageinfo.platform, platform=messageinfo.platform,
) )
@@ -121,8 +121,8 @@ class NormalChat:
message_id=thinking_id, message_id=thinking_id,
chat_stream=self.chat_stream, # 使用 self.chat_stream chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo( bot_user_info=UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.bot.qq_account,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.bot.nickname,
platform=message.message_info.platform, platform=message.message_info.platform,
), ),
sender_info=message.message_info.user_info, sender_info=message.message_info.user_info,
@@ -147,7 +147,7 @@ class NormalChat:
# 改为实例方法 # 改为实例方法
async def _handle_emoji(self, message: MessageRecv, response: str): async def _handle_emoji(self, message: MessageRecv, response: str):
"""处理表情包""" """处理表情包"""
if random() < global_config.emoji_chance: if random() < global_config.normal_chat.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response) emoji_raw = await emoji_manager.get_emoji_for_text(response)
if emoji_raw: if emoji_raw:
emoji_path, description = emoji_raw emoji_path, description = emoji_raw
@@ -160,8 +160,8 @@ class NormalChat:
message_id="mt" + str(thinking_time_point), message_id="mt" + str(thinking_time_point),
chat_stream=self.chat_stream, # 使用 self.chat_stream chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo( bot_user_info=UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.bot.qq_account,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.bot.nickname,
platform=message.message_info.platform, platform=message.message_info.platform,
), ),
sender_info=message.message_info.user_info, sender_info=message.message_info.user_info,
@@ -186,7 +186,7 @@ class NormalChat:
label=emotion, label=emotion,
stance=stance, # 使用 self.chat_stream stance=stance, # 使用 self.chat_stream
) )
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor) self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
async def _reply_interested_message(self) -> None: async def _reply_interested_message(self) -> None:
""" """
@@ -430,7 +430,7 @@ class NormalChat:
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息中是否包含过滤词""" """检查消息中是否包含过滤词"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for word in global_config.ban_words: for word in global_config.chat.ban_words:
if word in text: if word in text:
logger.info( logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]" f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -445,7 +445,7 @@ class NormalChat:
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式""" """检查消息是否匹配过滤正则表达式"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for pattern in global_config.ban_msgs_regex: for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text): if pattern.search(text):
logger.info( logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]" f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"

View File

@@ -15,21 +15,22 @@ logger = get_logger("llm")
class NormalChatGenerator: class NormalChatGenerator:
def __init__(self): def __init__(self):
# TODO: API-Adapter修改标记
self.model_reasoning = LLMRequest( self.model_reasoning = LLMRequest(
model=global_config.llm_reasoning, model=global_config.model.reasoning,
temperature=0.7, temperature=0.7,
max_tokens=3000, max_tokens=3000,
request_type="response_reasoning", request_type="response_reasoning",
) )
self.model_normal = LLMRequest( self.model_normal = LLMRequest(
model=global_config.llm_normal, model=global_config.model.normal,
temperature=global_config.llm_normal["temp"], temperature=global_config.model.normal["temp"],
max_tokens=256, max_tokens=256,
request_type="response_reasoning", request_type="response_reasoning",
) )
self.model_sum = LLMRequest( self.model_sum = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation" model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation"
) )
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model" self.current_model_name = "unknown model"
@@ -37,7 +38,7 @@ class NormalChatGenerator:
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]: async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
if random.random() < global_config.model_reasoning_probability: if random.random() < global_config.normal_chat.reasoning_model_probability:
self.current_model_type = "深深地" self.current_model_type = "深深地"
current_model = self.model_reasoning current_model = self.model_reasoning
else: else:
@@ -51,7 +52,7 @@ class NormalChatGenerator:
model_response = await self._generate_response_with_model(message, current_model, thinking_id) model_response = await self._generate_response_with_model(message, current_model, thinking_id)
if model_response: if model_response:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}") logger.info(f"{global_config.bot.nickname}的回复是:{model_response}")
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
return model_response return model_response
@@ -113,7 +114,7 @@ class NormalChatGenerator:
- "中立":不表达明确立场或无关回应 - "中立":不表达明确立场或无关回应
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签 2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒" 3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
4. 考虑回复者的人格设定为{global_config.personality_core} 4. 考虑回复者的人格设定为{global_config.personality.personality_core}
对话示例: 对话示例:
被回复「A就是笨」 被回复「A就是笨」

View File

@@ -1,18 +1,20 @@
import asyncio import asyncio
from src.config.config import global_config
from .willing_manager import BaseWillingManager from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager): class ClassicalWillingManager(BaseWillingManager):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._decay_task: asyncio.Task = None self._decay_task: asyncio.Task | None = None
async def _decay_reply_willing(self): async def _decay_reply_willing(self):
"""定期衰减回复意愿""" """定期衰减回复意愿"""
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
for chat_id in self.chat_reply_willing: for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9) self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9)
async def async_task_starter(self): async def async_task_starter(self):
if self._decay_task is None: if self._decay_task is None:
@@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = willing_info.chat_id chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)
interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
if interested_rate > 0.4: if interested_rate > 0.4:
current_willing += interested_rate - 0.3 current_willing += interested_rate - 0.3
if willing_info.is_mentioned_bot and current_willing < 1.0: if willing_info.is_mentioned_bot:
current_willing += 1 current_willing += 1 if current_willing < 1.0 else 0.05
elif willing_info.is_mentioned_bot:
current_willing += 0.05
is_emoji_not_reply = False is_emoji_not_reply = False
if willing_info.is_emoji: if willing_info.is_emoji:
if self.global_config.emoji_response_penalty != 0: if global_config.normal_chat.emoji_response_penalty != 0:
current_willing *= self.global_config.emoji_response_penalty current_willing *= global_config.normal_chat.emoji_response_penalty
else: else:
is_emoji_not_reply = True is_emoji_not_reply = True
self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min( reply_probability = min(
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1 max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1
) )
# 检查群组权限(如果是群聊) # 检查群组权限(如果是群聊)
if ( if (
willing_info.group_info willing_info.group_info
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
): ):
reply_probability = reply_probability / self.global_config.down_frequency_rate reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate
if is_emoji_not_reply: if is_emoji_not_reply:
reply_probability = 0 reply_probability = 0
@@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager):
async def before_generate_reply_handle(self, message_id): async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)
self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8) self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
async def after_generate_reply_handle(self, message_id): async def after_generate_reply_handle(self, message_id):
if message_id not in self.ongoing_messages: if message_id not in self.ongoing_messages:
@@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = self.ongoing_messages[message_id].chat_id chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)
if current_willing < 1: if current_willing < 1:
self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4) self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
async def bombing_buffer_message_handle(self, message_id): async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id) return await super().bombing_buffer_message_handle(message_id)

View File

@@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔) 下下策是询问一个菜鸟(@梦溪畔)
""" """
from src.config.config import global_config
from .willing_manager import BaseWillingManager from .willing_manager import BaseWillingManager
from typing import Dict from typing import Dict
import asyncio import asyncio
@@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager):
self.mention_willing_gain = 0.6 # 提及意愿增益 self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益 self.interest_willing_gain = 0.3 # 兴趣意愿增益
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益 self.single_chat_gain = 0.12 # 单聊增益
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int) self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
@@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager):
probability = self._willing_to_probability(current_willing) probability = self._willing_to_probability(current_willing)
if w_info.is_emoji: if w_info.is_emoji:
probability *= self.emoji_response_penalty probability *= global_config.normal_chat.emoji_response_penalty
if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups: if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
probability /= self.down_frequency_rate probability /= global_config.normal_chat.down_frequency_rate
self.temporary_willing = current_willing self.temporary_willing = current_willing

View File

@@ -1,6 +1,6 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass from dataclasses import dataclass
from src.config.config import global_config, BotConfig from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
@@ -93,7 +93,6 @@ class BaseWillingManager(ABC):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id) self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id) self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.global_config: BotConfig = global_config
self.logger: LoguruLogger = logger self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
@@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager:
Returns: Returns:
对应mode的WillingManager实例 对应mode的WillingManager实例
""" """
mode = global_config.willing_mode.lower() mode = global_config.normal_chat.willing_mode.lower()
return BaseWillingManager.create(mode) return BaseWillingManager.create(mode)

View File

@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from ...common.database import db from ...common.database.database import db
from ...common.database.database_model import PersonInfo # 新增导入
import copy import copy
import hashlib import hashlib
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
import json import json # 新增导入
import re import re
@@ -38,47 +39,49 @@ logger = get_logger("person_info")
person_info_default = { person_info_default = {
"person_id": None, "person_id": None,
"person_name": None, "person_name": None, # 模型中已设为 null=True此默认值OK
"name_reason": None, "name_reason": None,
"platform": None, "platform": "unknown", # 提供非None的默认值
"user_id": None, "user_id": "unknown", # 提供非None的默认值
"nickname": None, "nickname": "Unknown", # 提供非None的默认值
# "age" : 0,
"relationship_value": 0, "relationship_value": 0,
# "saved" : True, "know_time": 0, # 修正拼写konw_time -> know_time
# "impression" : None,
# "gender" : Unkown,
"konw_time": 0,
"msg_interval": 2000, "msg_interval": 2000,
"msg_interval_list": [], "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
"user_cardname": None, # 添加群名片 "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
"user_avatar": None, # 添加头像信息例如URL或标识符 "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 }
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
self.person_name_list = {} self.person_name_list = {}
# TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest( self.qv_name_llm = LLMRequest(
model=global_config.llm_normal, model=global_config.model.normal,
max_tokens=256, max_tokens=256,
request_type="qv_name", request_type="qv_name",
) )
if "person_info" not in db.list_collection_names(): try:
db.create_collection("person_info") db.connect(reuse_if_open=True)
db.person_info.create_index("person_id", unique=True) db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name # 初始化时读取所有person_name
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) try:
for doc in cursor: for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
if doc.get("person_name"): PersonInfo.person_name.is_null(False)
self.person_name_list[doc["person_id"]] = doc["person_name"] ):
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod @staticmethod
def get_person_id(platform: str, user_id: int): def get_person_id(platform: str, user_id: int):
"""获取唯一id""" """获取唯一id"""
# 如果platform中存在-,就截取-后面的部分
if "-" in platform: if "-" in platform:
platform = platform.split("-")[1] platform = platform.split("-")[1]
@@ -86,13 +89,17 @@ class PersonInfoManager:
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
def is_person_known(self, platform: str, user_id: int): async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人""" """判断是否认识某人"""
person_id = self.get_person_id(platform, user_id) person_id = self.get_person_id(platform, user_id)
document = db.person_info.find_one({"person_id": person_id})
if document: def _db_check_known_sync(p_id: str):
return True return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
else:
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False return False
def get_person_id_by_person_name(self, person_name: str): def get_person_id_by_person_name(self, person_name: str):
@@ -111,73 +118,111 @@ class PersonInfoManager:
return return
_person_info_default = copy.deepcopy(person_info_default) _person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id model_fields = PersonInfo._meta.fields.keys()
final_data = {"person_id": person_id}
if data: if data:
for key in _person_info_default: for key, value in data.items():
if key != "person_id" and key in data: if key in model_fields:
_person_info_default[key] = data[key] final_data[key] = value
db.person_info.insert_one(_person_info_default) for key, default_value in _person_info_default.items():
if key in model_fields and key not in final_data:
final_data[key] = default_value
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
final_data["msg_interval_list"] = json.dumps([])
def _db_create_sync(p_data: dict):
try:
PersonInfo.create(**p_data)
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
"""更新某一个字段,会补全""" """更新某一个字段,会补全"""
if field_name not in person_info_default.keys(): if field_name not in PersonInfo._meta.fields:
logger.debug(f"更新'{field_name}'失败,未定义的字段") if field_name in person_info_default:
logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
return
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return return
document = db.person_info.find_one({"person_id": person_id}) def _db_update_sync(p_id: str, f_name: str, val):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
if f_name == "msg_interval_list" and isinstance(val, list):
setattr(record, f_name, json.dumps(val))
else:
setattr(record, f_name, val)
record.save()
return True, False
return False, True
if document: found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
else: if needs_creation:
data[field_name] = value logger.debug(f"更新时 {person_id} 不存在,将新建。")
logger.debug(f"更新时{person_id}不存在,已新建") creation_data = data if data is not None else {}
await self.create_person_info(person_id, data) creation_data[field_name] = value
if "platform" not in creation_data or "user_id" not in creation_data:
logger.warning(f"{person_id} 创建记录时platform/user_id 可能缺失。")
await self.create_person_info(person_id, creation_data)
@staticmethod @staticmethod
async def has_one_field(person_id: str, field_name: str): async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段""" """判断是否存在某一个字段"""
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) if field_name not in PersonInfo._meta.fields:
if document: logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return True return False
else:
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return True
return False
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
return False return False
@staticmethod @staticmethod
def _extract_json_from_text(text: str) -> dict: def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
try: try:
# 尝试直接解析
parsed_json = json.loads(text) parsed_json = json.loads(text)
# 如果解析结果是列表,尝试取第一个元素
if isinstance(parsed_json, list): if isinstance(parsed_json, list):
if parsed_json: # 检查列表是否为空 if parsed_json:
parsed_json = parsed_json[0] parsed_json = parsed_json[0]
else: # 如果列表为空,重置为 None走后续逻辑 else:
parsed_json = None parsed_json = None
# 确保解析结果是字典
if isinstance(parsed_json, dict): if isinstance(parsed_json, dict):
return parsed_json return parsed_json
except json.JSONDecodeError: except json.JSONDecodeError:
# 解析失败,继续尝试其他方法
pass pass
except Exception as e: except Exception as e:
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
pass # 继续尝试其他方法 pass
# 如果直接解析失败或结果不是字典
try: try:
# 尝试找到JSON对象格式的部分
json_pattern = r"\{[^{}]*\}" json_pattern = r"\{[^{}]*\}"
matches = re.findall(json_pattern, text) matches = re.findall(json_pattern, text)
if matches: if matches:
parsed_obj = json.loads(matches[0]) parsed_obj = json.loads(matches[0])
if isinstance(parsed_obj, dict): # 确保是字典 if isinstance(parsed_obj, dict):
return parsed_obj return parsed_obj
# 如果上面都失败了,尝试提取键值对
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
reason_pattern = r'"reason"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"'
@@ -192,7 +237,6 @@ class PersonInfoManager:
except Exception as e: except Exception as e:
logger.error(f"后备JSON提取失败: {str(e)}") logger.error(f"后备JSON提取失败: {str(e)}")
# 如果所有方法都失败了,返回默认字典
logger.warning(f"无法从文本中提取有效的JSON字典: {text}") logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
return {"nickname": "", "reason": ""} return {"nickname": "", "reason": ""}
@@ -207,9 +251,11 @@ class PersonInfoManager:
old_name = await self.get_value(person_id, "person_name") old_name = await self.get_value(person_id, "person_name")
old_reason = await self.get_value(person_id, "name_reason") old_reason = await self.get_value(person_id, "name_reason")
max_retries = 5 # 最大重试次数 max_retries = 5
current_try = 0 current_try = 0
existing_names = "" existing_names_str = ""
current_name_set = set(self.person_name_list.values())
while current_try < max_retries: while current_try < max_retries:
individuality = Individuality.get_instance() individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1) prompt_personality = individuality.get_prompt(x_person=2, level=1)
@@ -224,45 +270,58 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}" qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
qv_name_prompt += ( qv_name_prompt += (
"\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改" "\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
) )
if existing_names:
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}\n" if existing_names_str:
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}\n"
if len(current_name_set) < 50 and current_name_set:
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
qv_name_prompt += "请用json给出你的想法并给出理由示例如下" qv_name_prompt += "请用json给出你的想法并给出理由示例如下"
qv_name_prompt += """{ qv_name_prompt += """{
"nickname": "昵称", "nickname": "昵称",
"reason": "理由" "reason": "理由"
}""" }"""
# logger.debug(f"取名提示词:{qv_name_prompt}")
response = await self.qv_name_llm.generate_response(qv_name_prompt) response = await self.qv_name_llm.generate_response(qv_name_prompt)
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response[0]) result = self._extract_json_from_text(response[0])
if not result["nickname"]: if not result or not result.get("nickname"):
logger.error("生成的昵称为空,重试中...") logger.error("生成的昵称为空或结果格式不正确,重试中...")
current_try += 1 current_try += 1
continue continue
# 检查生成的昵称是否已存在 generated_nickname = result["nickname"]
if result["nickname"] not in self.person_name_list.values():
# 更新数据库和内存中的列表
await self.update_one_field(person_id, "person_name", result["nickname"])
# await self.update_one_field(person_id, "nickname", user_nickname)
# await self.update_one_field(person_id, "avatar", user_avatar)
await self.update_one_field(person_id, "name_reason", result["reason"])
self.person_name_list[person_id] = result["nickname"] is_duplicate = False
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") if generated_nickname in current_name_set:
is_duplicate = True
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
self.person_name_list[person_id] = generated_nickname
return result return result
else: else:
existing_names += f"{result['nickname']}" if existing_names_str:
existing_names_str += ""
existing_names_str += generated_nickname
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
current_try += 1
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
current_try += 1
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称")
return None return None
@staticmethod @staticmethod
@@ -272,30 +331,56 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空") logger.debug("删除失败person_id 不能为空")
return return
result = db.person_info.delete_one({"person_id": person_id}) def _db_delete_sync(p_id: str):
if result.deleted_count > 0: try:
logger.debug(f"删除成功person_id={person_id}") query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
else: else:
logger.debug(f"删除失败:未找到 person_id={person_id}") logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod @staticmethod
async def get_value(person_id: str, field_name: str): async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值""" """获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id: if not person_id:
logger.debug("get_value获取失败person_id不能为空") logger.debug("get_value获取失败person_id不能为空")
return person_info_default.get(field_name)
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
logger.trace(f"字段'{field_name}'不在Peewee模型中但存在于默认配置中。返回配置默认值。")
return copy.deepcopy(person_info_default[field_name])
logger.debug(f"get_value获取失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
return None return None
if field_name not in person_info_default: def _db_get_value_sync(p_id: str, f_name: str):
logger.debug(f"get_value获取失败字段'{field_name}'未定义") record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
val = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
return copy.deepcopy(person_info_default.get(f_name, []))
return val
return None return None
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if document and field_name in document: if value is not None:
return document[field_name] return value
else: else:
default_value = copy.deepcopy(person_info_default[field_name]) default_value = copy.deepcopy(person_info_default.get(field_name))
logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}") logger.trace(f"获取{person_id}{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
return default_value return default_value
@staticmethod @staticmethod
@@ -305,93 +390,84 @@ class PersonInfoManager:
logger.debug("get_values获取失败person_id不能为空") logger.debug("get_values获取失败person_id不能为空")
return {} return {}
# 检查所有字段是否有效
for field in field_names:
if field not in person_info_default:
logger.debug(f"get_values获取失败字段'{field}'未定义")
return {}
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
document = db.person_info.find_one({"person_id": person_id}, projection)
result = {} result = {}
for field in field_names:
result[field] = copy.deepcopy( def _db_get_record_sync(p_id: str):
document.get(field, person_info_default[field]) if document else person_info_default[field] return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
)
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.trace(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if field_name == "msg_interval_list" and isinstance(value, str):
try:
result[field_name] = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
elif value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result return result
@staticmethod @staticmethod
async def del_all_undefined_field(): async def del_all_undefined_field():
"""删除所有项里的未定义字段""" """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
# 获取所有已定义的字段名 logger.info(
defined_fields = set(person_info_default.keys()) "del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。"
)
try: return
# 遍历集合中的所有文档
for document in db.person_info.find({}):
# 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one(
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
)
if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
return
except Exception as e:
logger.error(f"清理未定义字段时出错: {e}")
return
@staticmethod @staticmethod
async def get_specific_value_list( async def get_specific_value_list(
field_name: str, field_name: str,
way: Callable[[Any], bool], # 接受任意类型值 way: Callable[[Any], bool],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取满足条件的字段值字典 获取满足条件的字段值字典
Args:
field_name: 目标字段名
way: 判断函数 (value: Any) -> bool
Returns:
{person_id: value} | {}
Example:
# 查找所有nickname包含"admin"的用户
result = manager.specific_value_list(
"nickname",
lambda x: "admin" in x.lower()
)
""" """
if field_name not in person_info_default: if field_name not in PersonInfo._meta.fields:
logger.error(f"字段检查失败:'{field_name}'未定义") logger.error(f"字段检查失败:'{field_name}'在 PersonInfo Peewee 模型中定义")
return {} return {}
def _db_get_specific_sync(f_name: str):
found_results = {}
try:
for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
value = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(value, str):
try:
processed_value = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
continue
else:
processed_value = value
if way(processed_value):
found_results[record.person_id] = processed_value
except Exception as e_query:
logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
try: try:
result = {} return await asyncio.to_thread(_db_get_specific_sync, field_name)
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
try:
value = doc[field_name]
if way(value):
result[doc["person_id"]] = value
except (KeyError, TypeError, ValueError) as e:
logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
continue
return result
except Exception as e: except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True) logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {} return {}
async def personal_habit_deduction(self): async def personal_habit_deduction(self):
@@ -399,35 +475,31 @@ class PersonInfoManager:
try: try:
while 1: while 1:
await asyncio.sleep(600) await asyncio.sleep(600)
current_time = datetime.datetime.now() current_time_dt = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
# "msg_interval"推断 msg_interval_map_generated = False
msg_interval_map = False msg_interval_lists_map = await self.get_specific_value_list(
msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
) )
for person_id, msg_interval_list_ in msg_interval_lists.items():
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
try: try:
time_interval = [] time_interval = []
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
delta = t2 - t1 delta = t2 - t1
if delta > 0: if delta > 0:
time_interval.append(delta) time_interval.append(delta)
time_interval = [t for t in time_interval if 200 <= t <= 8000] time_interval = [t for t in time_interval if 200 <= t <= 8000]
# --- 修改后的逻辑 ---
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
time_interval.sort()
# 画图(log) - 这部分保留 if len(time_interval) >= 30 + 10:
msg_interval_map = True time_interval.sort()
msg_interval_map_generated = True
log_dir = Path("logs/person_info") log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
# 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval) time_series_original = pd.Series(time_interval)
plt.hist( plt.hist(
time_series_original, time_series_original,
@@ -449,34 +521,29 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path) plt.savefig(img_path)
plt.close() plt.close()
# 画图结束
# 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5] trimmed_interval = time_interval[5:-5]
if trimmed_interval:
# 计算截断后数据的 37% 分位数 msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
if trimmed_interval: # 确保截断后列表不为空 await self.update_one_field(person_id, "msg_interval", msg_interval_val)
msg_interval = int(round(np.percentile(trimmed_interval, 37))) logger.trace(
# 更新数据库 f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
await self.update_one_field(person_id, "msg_interval", msg_interval) )
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
else: else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval") logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else: else:
logger.trace( logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
) )
# --- 修改结束 --- except Exception as e_inner:
except Exception as e: logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
continue continue
# 其他... if msg_interval_map_generated:
if msg_interval_map:
logger.trace("已保存分布图到: logs/person_info") logger.trace("已保存分布图到: logs/person_info")
current_time = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") current_time_dt_end = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400) await asyncio.sleep(86400)
except Exception as e: except Exception as e:
@@ -489,41 +556,27 @@ class PersonInfoManager:
""" """
根据 platform 和 user_id 获取 person_id。 根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。 如果对应的用户不存在,则使用提供的可选信息创建新用户。
Args:
platform: 平台标识
user_id: 用户在该平台上的ID
nickname: 用户的昵称 (可选,用于创建新用户)
user_cardname: 用户的群名片 (可选,用于创建新用户)
user_avatar: 用户的头像信息 (可选,用于创建新用户)
Returns:
对应的 person_id。
""" """
person_id = self.get_person_id(platform, user_id) person_id = self.get_person_id(platform, user_id)
# 检查用户是否已存在 def _db_check_exists_sync(p_id: str):
# 使用静态方法 get_person_id因此可以直接调用 db return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
document = db.person_info.find_one({"person_id": person_id})
if document is None: record = await asyncio.to_thread(_db_check_exists_sync, person_id)
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
if record is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
initial_data = { initial_data = {
"platform": platform, "platform": platform,
"user_id": user_id, "user_id": str(user_id),
"nickname": nickname, "nickname": nickname,
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 "know_time": int(datetime.datetime.now().timestamp()), # 修正拼写konw_time -> know_time
# 注意:这里没有添加 user_cardname 和 user_avatar因为它们不在 person_info_default 中
# 如果需要存储它们,需要先在 person_info_default 中定义
} }
# 过滤掉值为 None 的初始数据 model_fields = PersonInfo._meta.fields.keys()
initial_data = {k: v for k, v in initial_data.items() if v is not None} filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
# 注意create_person_info 是静态方法 await self.create_person_info(person_id, data=filtered_initial_data)
await PersonInfoManager.create_person_info(person_id, data=initial_data) logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
return person_id return person_id
@@ -533,35 +586,55 @@ class PersonInfoManager:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空") logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None return None
# 优先从内存缓存查找 person_id
found_person_id = None found_person_id = None
for pid, name in self.person_name_list.items(): for pid, name_in_cache in self.person_name_list.items():
if name == person_name: if name_in_cache == person_name:
found_person_id = pid found_person_id = pid
break # 找到第一个匹配就停止 break
if not found_person_id: if not found_person_id:
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
document = db.person_info.find_one({"person_name": person_name})
if document:
found_person_id = document.get("person_id")
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
return None # 数据库也找不到
# 根据找到的 person_id 获取所需信息 def _db_find_by_name_sync(p_name_to_find: str):
if found_person_id: return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
person_data = await self.get_values(found_person_id, required_fields) record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if person_data: # 确保 get_values 成功返回 if record:
return person_data found_person_id = record.person_id
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else: else:
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None return None
else:
# 这理论上不应该发生,因为上面已经处理了找不到的情况 if found_person_id:
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") required_fields = [
return None "person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else:
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None
person_info_manager = PersonInfoManager() person_info_manager = PersonInfoManager()

View File

@@ -190,8 +190,8 @@ async def _build_readable_messages_internal(
person_id = person_info_manager.get_person_id(platform, user_id) person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称 # 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.BOT_QQ: if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.BOT_NICKNAME}(你)" person_name = f"{global_config.bot.nickname}(你)"
else: else:
person_name = await person_info_manager.get_value(person_id, "person_name") person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = [] output_lines = []
def get_anon_name(platform, user_id): def get_anon_name(platform, user_id):
if user_id == global_config.BOT_QQ: if user_id == global_config.bot.qq_account:
return "SELF" return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id) person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map: if person_id not in person_map:
@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def reply_replacer(match, platform=platform): def reply_replacer(match, platform=platform):
# aaa = match.group(1) # aaa = match.group(1)
bbb = match.group(2) bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb) anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}" return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1) content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def at_replacer(match, platform=platform): def at_replacer(match, platform=platform):
# aaa = match.group(1) # aaa = match.group(1)
bbb = match.group(2) bbb = match.group(2)
anon_at = get_anon_name(platform, bbb) anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}" return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content) content = re.sub(at_pattern, at_replacer, content)
@@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id") user_id = user_info.get("user_id")
# 检查必要信息是否存在 且 不是机器人自己 # 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.BOT_QQ: if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue continue
person_id = person_info_manager.get_person_id(platform, user_id) person_id = person_info_manager.get_person_id(platform, user_id)

View File

@@ -1,15 +1,15 @@
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from src.common.database import db from src.common.database.database_model import Messages, ThinkingLog
import time import time
import traceback import traceback
from typing import List from typing import List
import json
class InfoCatcher: class InfoCatcher:
def __init__(self): def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~ self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~ self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~ self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
@@ -60,8 +60,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1: if len(past_mind) > 1:
@@ -72,25 +70,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
# self.reasoning_data["thinking_log"] = reasoning_content
# self.reasoning_data["prompt"] = prompt
# self.reasoning_data["response"] = response
# self.reasoning_data["model"] = model_name
# 直接记录信息喵~
self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name self.reasoning_data["model"] = model_name
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
self.response_text = response self.response_text = response
@@ -102,6 +85,7 @@ class InfoCatcher:
): ):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
self.response_time = time.time() self.response_time = time.time()
self.response_messages = []
for msg in response_message: for msg in response_message:
self.response_messages.append(msg) self.response_messages.append(msg)
@@ -112,107 +96,112 @@ class InfoCatcher:
@staticmethod @staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message): def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try: try:
# 从数据库中获取消息的时间戳
time_start = message_start.message_info.time time_start = message_start.message_info.time
time_end = message_end.message_info.time time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 messages_between_query = (
messages_between = db.messages.find( Messages.select()
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
).sort("time", -1) .order_by(Messages.time.desc())
)
result = list(messages_between) result = list(messages_between_query)
print(f"查询结果数量: {len(result)}") print(f"查询结果数量: {len(result)}")
if result: if result:
print(f"第一条消息时间: {result[0]['time']}") print(f"第一条消息时间: {result[0].time}")
print(f"最后一条消息时间: {result[-1]['time']}") print(f"最后一条消息时间: {result[-1].time}")
return result return result
except Exception as e: except Exception as e:
print(f"获取消息时出错: {str(e)}") print(f"获取消息时出错: {str(e)}")
print(traceback.format_exc())
return [] return []
def get_message_from_db_before_msg(self, message: MessageRecv): def get_message_from_db_before_msg(self, message: MessageRecv):
# 从数据库中获取消息 message_id_val = message.message_info.message_id
message_id = message.message_info.message_id chat_id_val = message.chat_stream.stream_id
chat_id = message.chat_stream.stream_id
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 messages_before_query = (
messages_before = ( Messages.select()
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.sort("time", -1) .order_by(Messages.time.desc())
.limit(self.context_length * 3) .limit(global_config.chat.observation_context_size * 3)
) # 获取更多历史信息 )
return list(messages_before) return list(messages_before_query)
def message_list_to_dict(self, message_list): def message_list_to_dict(self, message_list):
# 存储简化的聊天记录
result = [] result = []
for message in message_list: for msg_item in message_list:
if not isinstance(message, dict): processed_msg_item = msg_item
message = self.message_to_dict(message) if not isinstance(msg_item, dict):
# print(message) processed_msg_item = self.message_to_dict(msg_item)
if not processed_msg_item:
continue
lite_message = { lite_message = {
"time": message["time"], "time": processed_msg_item.get("time"),
"user_nickname": message["user_info"]["user_nickname"], "user_nickname": processed_msg_item.get("user_nickname"),
"processed_plain_text": message["processed_plain_text"], "processed_plain_text": processed_msg_item.get("processed_plain_text"),
} }
result.append(lite_message) result.append(lite_message)
return result return result
@staticmethod @staticmethod
def message_to_dict(message): def message_to_dict(msg_obj):
if not message: if not msg_obj:
return None return None
if isinstance(message, dict): if isinstance(msg_obj, dict):
return message return msg_obj
return {
# "message_id": message.message_info.message_id,
"time": message.message_info.time,
"user_id": message.message_info.user_info.user_id,
"user_nickname": message.message_info.user_info.user_nickname,
"processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text
}
def done_catch(self): if isinstance(msg_obj, Messages):
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~""" return {
try: "time": msg_obj.time,
# 将消息对象转换为可序列化的字典喵~ "user_id": msg_obj.user_id,
"user_nickname": msg_obj.user_nickname,
thinking_log_data = { "processed_plain_text": msg_obj.processed_plain_text,
"chat_id": self.chat_id,
"trigger_text": self.trigger_response_text,
"response_text": self.response_text,
"trigger_info": {
"time": self.trigger_response_time,
"message": self.message_to_dict(self.trigger_response_message),
},
"response_info": {
"time": self.response_time,
"message": self.response_messages,
},
"timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
"heartflow_data": self.heartflow_data,
"reasoning_data": self.reasoning_data,
} }
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
# if self.response_mode == "heart_flow": return {
# thinking_log_data["mode_specific_data"] = self.heartflow_data "time": msg_obj.message_info.time,
# elif self.response_mode == "reasoning": "user_id": msg_obj.message_info.user_info.user_id,
# thinking_log_data["mode_specific_data"] = self.reasoning_data "user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
# 将数据插入到 thinking_log 集合中喵~ print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
db.thinking_log.insert_one(thinking_log_data) return {}
def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
try:
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
response_info_dict = {
"time": self.response_time,
"message": self.response_messages,
}
chat_history_list = self.message_list_to_dict(self.chat_history)
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
log_entry = ThinkingLog(
chat_id=self.chat_id,
trigger_text=self.trigger_response_text,
response_text=self.response_text,
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
response_info_json=json.dumps(response_info_dict),
timing_results_json=json.dumps(self.timing_results),
chat_history_json=json.dumps(chat_history_list),
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data),
)
log_entry.save()
return True return True
except Exception as e: except Exception as e:

View File

@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask from src.manager.async_task_manager import AsyncTask
from ...common.database import db from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic") logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self): def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60) super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: str | None = None self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID""" """记录ID"""
self._init_database() # 初始化数据库 self._init_database() # 初始化数据库
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod @staticmethod
def _init_database(): def _init_database():
"""初始化数据库""" """初始化数据库"""
if "online_time" not in db.list_collection_names(): with db.atomic(): # Use atomic operations for schema changes
# 初始化数据库(在线时长) OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
db.create_collection("online_time")
# 创建索引
if ("end_timestamp", 1) not in db.online_time.list_indexes():
db.online_time.create_index([("end_timestamp", 1)])
async def run(self): async def run(self):
try: try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id: if self.record_id:
# 如果有记录,则更新结束时间 # 如果有记录,则更新结束时间
db.online_time.update_one( query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
{"_id": self.record_id}, updated_rows = query.execute()
{ if updated_rows == 0:
"$set": { # Record might have been deleted or ID is stale, try to find/create
"end_timestamp": datetime.now() + timedelta(minutes=1), self.record_id = None # Reset record_id to trigger find/create logic below
}
}, if not self.record_id: # Check again if record_id was reset or initially None
)
else:
# 如果没有记录,检查一分钟以内是否已有记录 # 如果没有记录,检查一分钟以内是否已有记录
current_time = datetime.now() # Look for a record whose end_timestamp is recent enough to be considered ongoing
if recent_record := db.online_time.find_one( recent_record = (
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} OnlineTime.select()
): .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
if recent_record:
# 如果有记录,则更新结束时间 # 如果有记录,则更新结束时间
self.record_id = recent_record["_id"] self.record_id = recent_record.id
db.online_time.update_one( recent_record.end_timestamp = extended_end_time
{"_id": self.record_id}, recent_record.save()
{
"$set": {
"end_timestamp": current_time + timedelta(minutes=1),
}
},
)
else: else:
# 若没有记录,则插入新的在线时间记录 # 若没有记录,则插入新的在线时间记录
self.record_id = db.online_time.insert_one( new_record = OnlineTime.create(
{ timestamp=current_time.timestamp(), # 添加此行
"start_timestamp": current_time, start_timestamp=current_time,
"end_timestamp": current_time + timedelta(minutes=1), end_timestamp=extended_end_time,
} duration=5, # 初始时长为5分钟
).inserted_id )
self.record_id = new_record.id
except Exception as e: except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}") logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段 :param collect_period: 统计时间段
""" """
if len(collect_period) <= 0: if not collect_period:
return {} return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前) # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True) collect_period.sort(key=lambda x: x[1], reverse=True)
stats = { stats = {
period_key: { period_key: {
# 总LLM请求数
TOTAL_REQ_CNT: 0, TOTAL_REQ_CNT: 0,
# 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int), REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int), REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int), REQ_CNT_BY_MODEL: defaultdict(int),
# 输入Token数
IN_TOK_BY_TYPE: defaultdict(int), IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int), IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int), IN_TOK_BY_MODEL: defaultdict(int),
# 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int), OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int), OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int), OUT_TOK_BY_MODEL: defaultdict(int),
# 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int), TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int), TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int), TOTAL_TOK_BY_MODEL: defaultdict(int),
# 总开销
TOTAL_COST: 0.0, TOTAL_COST: 0.0,
# 请求开销统计
COST_BY_TYPE: defaultdict(float), COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float), COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float), COST_BY_MODEL: defaultdict(float),
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
} }
# 以最早的时间戳为起始时间获取记录 # 以最早的时间戳为起始时间获取记录
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}): # Assuming LLMUsage.timestamp is a DateTimeField
record_timestamp = record.get("timestamp") query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period): for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start: if record_timestamp >= period_start:
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1 stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type", "unknown") # 请求类型 request_type = record.request_type or "unknown"
user_id = str(record.get("user_id", "unknown")) # 用户ID user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.get("model_name", "unknown") # 模型名称 model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1 stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数 prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.get("completion_tokens", 0) # 输出Token数 completion_tokens = record.completion_tokens or 0
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数 total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0) cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost stats[period_key][COST_BY_MODEL][model_name] += cost
break # 取消更早时间段的判断 break
return stats return stats
@staticmethod @staticmethod
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段 :param collect_period: 统计时间段
""" """
if len(collect_period) <= 0: if not collect_period:
return {} return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前) collect_period.sort(key=lambda x: x[1], reverse=True)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = { stats = {
period_key: { period_key: {
# 在线时间统计
ONLINE_TIME: 0.0, ONLINE_TIME: 0.0,
} }
for period_key, _ in collect_period for period_key, _ in collect_period
} }
# 统计在线时间 query_start_time = collect_period[-1][1]
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}): # Assuming OnlineTime.end_timestamp is a DateTimeField
end_timestamp: datetime = record.get("end_timestamp") for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
for idx, (_, period_start) in enumerate(collect_period): # record.end_timestamp and record.start_timestamp are datetime objects
if end_timestamp >= period_start: record_end_timestamp = record.end_timestamp
# 由于end_timestamp会超前标记时间所以我们需要判断是否晚于当前时间如果是则使用当前时间作为结束时间 record_start_timestamp = record.start_timestamp
end_timestamp = min(end_timestamp, now)
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _period_start in collect_period[idx:]:
start_timestamp: datetime = record.get("start_timestamp")
if start_timestamp < _period_start:
# 如果开始时间在查询边界之前,则使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
else:
# 否则,使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
break # 取消更早时间段的判断
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
return stats return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段 :param collect_period: 统计时间段
""" """
if len(collect_period) <= 0: if not collect_period:
return {} return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前) collect_period.sort(key=lambda x: x[1], reverse=True)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = { stats = {
period_key: { period_key: {
# 消息统计
TOTAL_MSG_CNT: 0, TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int), MSG_CNT_BY_CHAT: defaultdict(int),
} }
for period_key, _ in collect_period for period_key, _ in collect_period
} }
# 统计消息量 query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}): for message in Messages.select().where(Messages.time >= query_start_timestamp):
chat_info = message.get("chat_info", None) # 聊天信息 message_time_ts = message.time # This is a float timestamp
user_info = message.get("user_info", None) # 用户信息(消息发送人)
message_time = message.get("time", 0) # 消息时间
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息 chat_id = None
if group_info is not None: chat_name = None
# 若有群聊信息
chat_id = f"g{group_info.get('group_id')}" # Logic based on Peewee model structure, aiming to replicate original intent
chat_name = group_info.get("group_name", f"{group_info.get('group_id')}") if message.chat_info_group_id:
elif user_info: chat_id = f"g{message.chat_info_group_id}"
# 若没有群聊信息,则尝试获取用户信息 chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
chat_id = f"u{user_info['user_id']}" elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
chat_name = user_info["user_nickname"] # This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else: else:
continue # 如果没有群组信息也没有用户信息,则跳过 # If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
)
continue
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
if chat_id in self.name_mapping: if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]: if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称 self.name_mapping[chat_id] = (chat_name, message_time_ts)
self.name_mapping[chat_id] = (chat_name, message_time)
else: else:
self.name_mapping[chat_id] = (chat_name, message_time) self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start) in enumerate(collect_period): for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time >= period_start.timestamp(): if message_time_ts >= period_start_dt.timestamp():
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break break
return stats return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator from .typo_generator import ChineseTypoGenerator
from ...common.database import db from ...common.database.database import db
from ...config.config import global_config from ...config.config import global_config
logger = get_module_logger("chat_utils") logger = get_module_logger("chat_utils")
@@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人""" """检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME] keywords = [global_config.bot.nickname]
nicknames = global_config.BOT_ALIAS_NAMES nicknames = global_config.bot.alias_names
reply_probability = 0.0 reply_probability = 0.0
is_at = False is_at = False
is_mentioned = False is_mentioned = False
@@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
) )
# 判断是否被@ # 判断是否被@
if re.search(f"@[\s\S]*?id:{global_config.BOT_QQ}", message.processed_plain_text): if re.search(f"@[\s\S]*?id:{global_config.bot.qq_account}", message.processed_plain_text):
is_at = True is_at = True
is_mentioned = True is_mentioned = True
if is_at and global_config.at_bot_inevitable_reply: if is_at and global_config.normal_chat.at_bot_inevitable_reply:
reply_probability = 1.0 reply_probability = 1.0
logger.info("被@回复概率设置为100%") logger.info("被@回复概率设置为100%")
else: else:
if not is_mentioned: if not is_mentioned:
# 判断是否被回复 # 判断是否被回复
if re.match( if re.match(
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?],说:", message.processed_plain_text f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\)[\s\S]*?],说:", message.processed_plain_text
): ):
is_mentioned = True is_mentioned = True
else: else:
@@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
for nickname in nicknames: for nickname in nicknames:
if nickname in message_content: if nickname in message_content:
is_mentioned = True is_mentioned = True
if is_mentioned and global_config.mentioned_bot_inevitable_reply: if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
reply_probability = 1.0 reply_probability = 1.0
logger.info("被提及回复概率设置为100%") logger.info("被提及回复概率设置为100%")
return is_mentioned, reply_probability return is_mentioned, reply_probability
@@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding"): async def get_embedding(text, request_type="embedding"):
"""获取文本的embedding向量""" """获取文本的embedding向量"""
llm = LLMRequest(model=global_config.embedding, request_type=request_type) # TODO: API-Adapter修改标记
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text) # return llm.get_embedding_sync(text)
try: try:
embedding = await llm.get_embedding(text) embedding = await llm.get_embedding(text)
@@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
user_info = UserInfo.from_dict(msg_db_data["user_info"]) user_info = UserInfo.from_dict(msg_db_data["user_info"])
if ( if (
(user_info.platform, user_info.user_id) != sender (user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.BOT_QQ and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5 and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目 ): # 排除重复排除消息发送者排除bot限制加载的关系数目
@@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> list[str]: def process_llm_response(text: str) -> list[str]:
# 先保护颜文字 # 先保护颜文字
if global_config.enable_kaomoji_protection: if global_config.response_splitter.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text) protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}") logger.trace(f"保护颜文字后的文本: {protected_text}")
else: else:
@@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]:
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}") logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
# 对清理后的文本进行进一步处理 # 对清理后的文本进行进一步处理
max_length = global_config.response_max_length * 2 max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_max_sentence_num max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤 # 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1: if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length: if len(cleaned_text) > max_length:
@@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]:
return ["懒得说"] return ["懒得说"]
typo_generator = ChineseTypoGenerator( typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo_error_rate, error_rate=global_config.chinese_typo.error_rate,
min_freq=global_config.chinese_typo_min_freq, min_freq=global_config.chinese_typo.min_freq,
tone_error_rate=global_config.chinese_typo_tone_error_rate, tone_error_rate=global_config.chinese_typo.tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate, word_replace_rate=global_config.chinese_typo.word_replace_rate,
) )
if global_config.enable_response_splitter: if global_config.response_splitter.enable:
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else: else:
split_sentences = [cleaned_text] split_sentences = [cleaned_text]
sentences = [] sentences = []
for sentence in split_sentences: for sentence in split_sentences:
if global_config.chinese_typo_enable: if global_config.chinese_typo.enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence) typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text) sentences.append(typoed_text)
if typo_corrections: if typo_corrections:
@@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]:
if len(sentences) > max_sentence_num: if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"] return [f"{global_config.bot.nickname}不知道哦"]
# if extracted_contents: # if extracted_contents:
# for content in extracted_contents: # for content in extracted_contents:

View File

@@ -8,7 +8,8 @@ import io
import numpy as np import numpy as np
from ...common.database import db from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config from ...config.config import global_config
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
@@ -32,40 +33,23 @@ class ImageManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True self._initialized = True
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
def _ensure_image_dir(self): def _ensure_image_dir(self):
"""确保图像存储目录存在""" """确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True) os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 删除旧索引
db.image_descriptions.drop_indexes()
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
@staticmethod @staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
@@ -77,8 +61,14 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) try:
return result["description"] if result else None record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@staticmethod @staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +80,17 @@ class ImageManager:
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
try: try:
db.image_descriptions.update_one( current_timestamp = time.time()
{"hash": image_hash, "type": description_type}, defaults = {"description": description, "timestamp": current_timestamp}
{ desc_obj, created = ImageDescriptions.get_or_create(
"$set": { hash=image_hash, type=description_type, defaults=defaults
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
}
},
upsert=True,
) )
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
except Exception as e: except Exception as e:
logger.error(f"保存描述到数据库失败: {str(e)}") logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str: async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能""" """获取表情包描述,带查重和保存功能"""
@@ -116,51 +103,64 @@ class ImageManager:
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
# logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]" return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
if image_format == "gif" or image_format == "GIF": if image_format == "gif" or image_format == "GIF":
image_base64 = self.transform_gif(image_base64) image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些" prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else: else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]" return f"[表情包,含义看起来是:{cached_description}]"
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.save_emoji: if global_config.emoji.save_emoji:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) current_timestamp = time.time()
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) file_path = os.path.join(emoji_dir, filename)
try: try:
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库 (Images表)
image_doc = { try:
"hash": image_hash, img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
"path": file_path, img_obj.path = file_path
"type": "emoji", img_obj.description = description
"description": description, img_obj.timestamp = current_timestamp
"timestamp": timestamp, img_obj.save()
} except Images.DoesNotExist:
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) Images.create(
logger.trace(f"保存表情包: {file_path}") hash=image_hash,
path=file_path,
type="emoji",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}") logger.error(f"保存表情包文件或元数据失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "emoji") self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]" return f"[表情包:{description}]"
@@ -188,6 +188,11 @@ class ImageManager:
) )
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 再次检查缓存
cached_description = self._get_description_from_db(image_hash, "image") cached_description = self._get_description_from_db(image_hash, "image")
if cached_description: if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +200,40 @@ class ImageManager:
logger.debug(f"描述是{description}") logger.debug(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.save_pic: if global_config.emoji.save_pic:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) current_timestamp = time.time()
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(os.path.join(self.IMAGE_DIR, "image")) os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(self.IMAGE_DIR, "image", filename) file_path = os.path.join(image_dir, filename)
try: try:
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库 (Images表)
image_doc = { try:
"hash": image_hash, img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
"path": file_path, img_obj.path = file_path
"type": "image", img_obj.description = description
"description": description, img_obj.timestamp = current_timestamp
"timestamp": timestamp, img_obj.save()
} except Images.DoesNotExist:
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) Images.create(
logger.trace(f"保存图片: {file_path}") hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存图片元数据: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}") logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "image") self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]" return f"[图片:{description}]"

View File

@@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
# 现在可以导入src模块 # 现在可以导入src模块
from src.common.database import db # noqa E402 from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件 # 加载根目录下的env.edv文件

View File

@@ -1,5 +1,6 @@
import os import os
from pymongo import MongoClient from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database from pymongo.database import Database
from rich.traceback import install from rich.traceback import install
@@ -57,4 +58,15 @@ class DBWrapper:
# 全局数据库访问点 # 全局数据库访问点
db: Database = DBWrapper() memory_db: Database = DBWrapper()
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(_DB_FILE)

View File

@@ -0,0 +1,358 @@
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
from .database import db
import datetime
from ..logger_manager import get_logger
logger = get_logger("database_model")
# The shared database instance is imported above from .database (a SQLite
# database by default). To use a different backend, configure it there
# instead — e.g. peewee's PostgresqlDatabase('name', user=..., password=...,
# host='localhost', port=5432) or MySQLDatabase(..., port=3306).
# A shared base model lets every model specify its database in one place;
# all concrete models below inherit from it.
class BaseModel(Model):
    """Common base for all ORM models in this module.

    Declaring the database once here lets every subclass inherit the
    connection instead of repeating it in each model's ``Meta``.
    """

    class Meta:
        # All models share the global Peewee database instance.
        database = db
class ChatStreams(BaseModel):
    """Stores chat-stream records, mirroring the original MongoDB document layout."""

    # e.g. "a544edeb1a9b73e3e1d77dff36e41264"; assumed unique, indexed for
    # fast lookups.
    stream_id = TextField(unique=True, index=True)

    # Creation time as a float UNIX timestamp (sub-second precision), hence
    # DoubleField.
    create_time = DoubleField()

    # Flattened group_info fields, e.g. platform="qq", group_id="941657197",
    # group_name="测试".
    group_platform = TextField()
    group_id = TextField()
    group_name = TextField()

    # Last activity as a float UNIX timestamp (sub-second precision).
    last_active_time = DoubleField()

    # Top-level platform field, e.g. "qq".
    platform = TextField()

    # Flattened user_info fields for the stream's user, e.g. platform="qq",
    # user_id="1787882683".
    user_platform = TextField()
    user_id = TextField()
    user_nickname = TextField()
    # user_cardname may be an empty string or absent in source data, so it is
    # nullable for flexibility.
    user_cardname = TextField(null=True)

    class Meta:
        # Database connection is inherited from BaseModel.Meta; uncomment and
        # set `database = ...` here only to override it for this model.
        table_name = "chat_streams"  # explicit table name in the database
class LLMUsage(BaseModel):
    """Stores API usage log entries for LLM requests."""

    model_name = TextField(index=True)  # indexed: queried per model
    user_id = TextField(index=True)  # indexed: queried per user
    request_type = TextField(index=True)  # indexed: queried per request type
    endpoint = TextField()  # API endpoint that served the request
    prompt_tokens = IntegerField()
    completion_tokens = IntegerField()
    total_tokens = IntegerField()
    cost = DoubleField()  # monetary cost of the request
    status = TextField()  # request outcome/status
    # Stored as a DateTimeField (not a raw timestamp) and indexed for
    # time-range queries.
    timestamp = DateTimeField(index=True)

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "llm_usage"
class Emoji(BaseModel):
    """Emoji/sticker registry with usage statistics."""

    full_path = TextField(unique=True, index=True)  # full file path (incl. filename)
    format = TextField()  # image format
    emoji_hash = TextField(index=True)  # content hash of the emoji image
    description = TextField()  # textual description of the emoji
    query_count = IntegerField(default=0)  # how many times its description was queried
    is_registered = BooleanField(default=False)  # registered as usable
    is_banned = BooleanField(default=False)  # banned from registration
    # Emotion tags (conceptually list[str]) stored as text; serialization/
    # deserialization is handled by the application layer.
    emotion = TextField(null=True)
    record_time = FloatField()  # UNIX timestamp when the record was created
    register_time = FloatField(null=True)  # UNIX timestamp when registered as usable
    usage_count = IntegerField(default=0)  # how many times it has been used
    last_used_time = FloatField(null=True)  # UNIX timestamp of last use

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "emoji"
class Messages(BaseModel):
    """Stores chat messages with flattened chat/user metadata."""

    message_id = TextField(index=True)  # message ID (changed from IntegerField)
    time = DoubleField()  # message UNIX timestamp
    chat_id = TextField(index=True)  # matching ChatStreams.stream_id

    # Fields flattened from the nested chat_info document.
    chat_info_stream_id = TextField()
    chat_info_platform = TextField()
    chat_info_user_platform = TextField()
    chat_info_user_id = TextField()
    chat_info_user_nickname = TextField()
    chat_info_user_cardname = TextField(null=True)
    chat_info_group_platform = TextField(null=True)  # group info may be absent (private chats)
    chat_info_group_id = TextField(null=True)
    chat_info_group_name = TextField(null=True)
    chat_info_create_time = DoubleField()
    chat_info_last_active_time = DoubleField()

    # Fields flattened from the top-level user_info (message sender).
    user_platform = TextField()
    user_id = TextField()
    user_nickname = TextField()
    user_cardname = TextField(null=True)

    processed_plain_text = TextField(null=True)  # processed plain-text content
    detailed_plain_text = TextField(null=True)  # detailed plain-text content

    memorized_times = IntegerField(default=0)  # how many times this message was memorized

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "messages"
class Images(BaseModel):
    """Stores image metadata (hash, path, description)."""

    # Content hash of the image.
    # NOTE(review): call sites in this change query `Images.emoji_hash` but
    # create rows with `hash=...` — confirm the field name is consistent with
    # all callers before relying on it.
    emoji_hash = TextField(index=True)
    description = TextField(null=True)  # optional description of the image
    path = TextField(unique=True)  # filesystem path of the stored image file
    timestamp = FloatField()  # UNIX timestamp
    type = TextField()  # image category, e.g. "emoji"

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "images"
class ImageDescriptions(BaseModel):
    """Caches generated image descriptions keyed by image hash and type."""

    type = TextField()  # description category, e.g. "emoji"
    # Content hash of the described image.
    # NOTE(review): call sites in this change upsert with `hash=...` while the
    # field here is `image_description_hash` — confirm the names match.
    image_description_hash = TextField(index=True)
    description = TextField()  # the cached description text
    timestamp = FloatField()  # UNIX timestamp when the description was stored

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "image_descriptions"
class OnlineTime(BaseModel):
    """Records online-duration entries."""

    # Originally a MongoDB date (e.g. {"$date": "2025-05-01T18:52:18.191Z"})
    # stored as a string.
    # NOTE(review): this is a TextField with a callable datetime default, so
    # the value is stored via the datetime's string form — confirm this is
    # intended (a DateTimeField would be the usual choice).
    timestamp = TextField(default=datetime.datetime.now)
    duration = IntegerField()  # duration in minutes
    start_timestamp = DateTimeField(default=datetime.datetime.now)
    end_timestamp = DateTimeField(index=True)  # indexed for time-range queries

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "online_time"
class PersonInfo(BaseModel):
    """Stores per-person profile and relationship data."""

    person_id = TextField(unique=True, index=True)  # unique person ID
    person_name = TextField(null=True)  # assigned name (may be unset)
    name_reason = TextField(null=True)  # rationale for the assigned name
    platform = TextField()  # source platform
    user_id = TextField(index=True)  # user ID on that platform
    nickname = TextField()  # user nickname
    relationship_value = IntegerField(default=0)  # relationship score
    know_time = FloatField()  # UNIX timestamp of first acquaintance
    msg_interval = IntegerField()  # message interval
    # List of message intervals stored as a JSON string; (de)serialization is
    # the caller's responsibility.
    msg_interval_list = TextField(null=True)

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "person_info"
class Knowledges(BaseModel):
    """Stores knowledge-base entries."""

    content = TextField()  # knowledge text content
    # Embedding vector stored as a JSON string encoding a list of floats.
    embedding = TextField()
    # Further metadata fields (source, create_time, ...) can be added here.

    class Meta:
        # Database connection is inherited from BaseModel.Meta.
        table_name = "knowledges"
class ThinkingLog(BaseModel):
    """Stores one reasoning/response cycle for later inspection."""

    chat_id = TextField(index=True)  # chat this log entry belongs to
    trigger_text = TextField(null=True)  # text that triggered the cycle
    response_text = TextField(null=True)  # generated response text

    # Complex dicts/lists are serialized as JSON strings.
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)

    # Creation time of the log entry itself.
    created_at = DateTimeField(default=datetime.datetime.now)

    class Meta:
        table_name = "thinking_logs"
class RecalledMessages(BaseModel):
    """Records message recall (retraction) events."""

    message_id = TextField(index=True)  # ID of the recalled message
    time = DoubleField()  # UNIX timestamp of the recall operation
    stream_id = TextField()  # matching ChatStreams.stream_id

    class Meta:
        table_name = "recalled_messages"
def create_tables():
    """Create every database table defined by the models in this module."""
    all_models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Knowledges,
        ThinkingLog,
        RecalledMessages,
    ]
    # Open the connection for the duration of table creation.
    with db:
        db.create_tables(all_models)
def initialize_database():
    """Verify the schema at module load time.

    Checks that every defined table exists, creating all tables if any is
    missing. If the tables exist, checks that every model field has a
    matching column; on a missing column it logs an error and exits the
    process (manual migration is required).
    """
    import sys

    models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Knowledges,
        ThinkingLog,
        RecalledMessages,  # newly added model
    ]
    needs_creation = False
    try:
        with db:  # manages the connection for the table_exists checks
            for model in models:
                table_name = model._meta.table_name
                if not db.table_exists(model):
                    logger.warning(f"'{table_name}' 未找到。")
                    needs_creation = True
                    break  # one table missing — no need to check further

            if not needs_creation:
                # All tables exist: verify each model field has a column.
                for model in models:
                    table_name = model._meta.table_name
                    # SQLite-specific: PRAGMA table_info lists columns;
                    # row[1] is the column name.
                    cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
                    existing_columns = {row[1] for row in cursor.fetchall()}
                    model_fields = model._meta.fields
                    for field_name in model_fields:
                        if field_name not in existing_columns:
                            # Missing column: require manual migration, then exit.
                            logger.error(f"'{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。")
                            sys.exit(1)
    except Exception as e:
        logger.exception(f"检查表或字段是否存在时出错: {e}")
        # Abort initialization if the check itself failed (e.g. DB unavailable).
        return

    if needs_creation:
        logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
        try:
            create_tables()  # manages its own `with db:` context
            logger.info("数据库表创建过程完成。")
        except Exception as e:
            logger.exception(f"创建表期间出错: {e}")
    else:
        logger.info("所有数据库表及字段均已存在。")


# Run schema initialization when the module is imported.
initialize_database()

View File

@@ -1,11 +1,19 @@
from src.common.database import db from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
import traceback import traceback
from typing import List, Any, Optional from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
"""
将 Peewee 模型实例转换为字典。
"""
return model_instance.__data__
def find_messages( def find_messages(
message_filter: dict[str, Any], message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None, sort: Optional[List[tuple[str, int]]] = None,
@@ -16,39 +24,84 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。 根据提供的过滤器、排序和限制条件查找消息。
Args: Args:
message_filter: MongoDB 查询过滤器。 message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。 limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest' limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
Returns: Returns:
消息文档列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
query = db.messages.find(message_filter) query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
if limit > 0: if limit > 0:
if limit_mode == "earliest": if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.sort([("time", 1)]).limit(limit) query = query.order_by(Messages.time.asc()).limit(limit)
results = list(query) peewee_results = list(query)
else: # 默认为 'latest' else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录 # 获取时间最晚的 limit 条记录
query = query.sort([("time", -1)]).limit(limit) query = query.order_by(Messages.time.desc()).limit(limit)
latest_results = list(query) latest_results_peewee = list(query)
# 将结果按时间正序排列 # 将结果按时间正序排列
# 假设消息文档中总是有 'time' 字段且可排序 peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
results = sorted(latest_results, key=lambda msg: msg.get("time"))
else: else:
# limit 为 0 时,应用传入的 sort 参数 # limit 为 0 时,应用传入的 sort 参数
if sort: if sort:
query = query.sort(sort) peewee_sort_terms = []
results = list(query) for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
return results return [_model_to_dict(msg) for msg in peewee_results]
except Exception as e: except Exception as e:
log_message = ( log_message = (
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc() + traceback.format_exc()
) )
logger.error(log_message) logger.error(log_message)
@@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int:
根据提供的过滤器计算消息数量。 根据提供的过滤器计算消息数量。
Args: Args:
message_filter: MongoDB 查询过滤器。 message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
Returns: Returns:
符合条件的消息数量,如果出错则返回 0。 符合条件的消息数量,如果出错则返回 0。
""" """
try: try:
count = db.messages.count_documents(message_filter) query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count = query.count()
return count return count
except Exception as e: except Exception as e:
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message) logger.error(log_message)
return 0 return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。

View File

@@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
info_dict = { info_dict = {
"os_type": "Unknown", "os_type": "Unknown",
"py_version": platform.python_version(), "py_version": platform.python_version(),
"mmc_version": global_config.MAI_VERSION, "mmc_version": global_config.MMC_VERSION,
} }
match platform.system(): match platform.system():
@@ -133,10 +133,9 @@ class TelemetryHeartBeatTask(AsyncTask):
async def run(self): async def run(self):
# 发送心跳 # 发送心跳
if global_config.remote_enable: if global_config.telemetry.enable:
if self.client_uuid is None: if self.client_uuid is None and not await self._req_uuid():
if not await self._req_uuid(): logger.error("获取UUID失败跳过此次心跳")
logger.error("获取UUID失败跳过此次心跳") return
return
await self._send_heartbeat() await self._send_heartbeat()

View File

@@ -1,64 +1,68 @@
import os import os
import re from dataclasses import field, dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import tomli
import tomlkit import tomlkit
import shutil import shutil
from datetime import datetime from datetime import datetime
from pathlib import Path
from packaging import version from tomlkit import TOMLDocument
from packaging.version import Version, InvalidVersion from tomlkit.items import Table
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from rich.traceback import install from rich.traceback import install
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
ChatTargetConfig,
PersonalityConfig,
IdentityConfig,
PlatformsConfig,
ChatConfig,
NormalChatConfig,
FocusChatConfig,
EmojiConfig,
MemoryConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
ModelConfig,
)
install(extra_lines=3) install(extra_lines=3)
# 配置主程序日志格式 # 配置主程序日志格式
logger = get_logger("config") logger = get_logger("config")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 CONFIG_DIR = "config"
is_test = True TEMPLATE_DIR = "template"
mai_version_main = "0.6.4"
mai_version_fix = "snapshot-1"
if mai_version_fix: # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
if is_test: # 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
mai_version = f"test-{mai_version_main}-{mai_version_fix}" MMC_VERSION = "0.7.0-snapshot.1"
else:
mai_version = f"{mai_version_main}-{mai_version_fix}"
else:
if is_test:
mai_version = f"test-{mai_version_main}"
else:
mai_version = mai_version_main
def update_config(): def update_config():
# 获取根目录路径 # 获取根目录路径
root_dir = Path(__file__).parent.parent.parent old_config_dir = f"{CONFIG_DIR}/old"
template_dir = root_dir / "template"
config_dir = root_dir / "config"
old_config_dir = config_dir / "old"
# 定义文件路径 # 定义文件路径
template_path = template_dir / "bot_config_template.toml" template_path = f"{TEMPLATE_DIR}/bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml" old_config_path = f"{CONFIG_DIR}/bot_config.toml"
new_config_path = config_dir / "bot_config.toml" new_config_path = f"{CONFIG_DIR}/bot_config.toml"
# 检查配置文件是否存在 # 检查配置文件是否存在
if not old_config_path.exists(): if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置") logger.info("配置文件不存在,从模板创建新配置")
# 创建文件夹 os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
old_config_dir.mkdir(parents=True, exist_ok=True) shutil.copy2(template_path, old_config_path) # 复制模板文件
shutil.copy2(template_path, old_config_path)
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回 # 如果是新创建的配置文件,直接返回
return quit() quit()
# 读取旧配置文件和模板文件 # 读取旧配置文件和模板文件
with open(old_config_path, "r", encoding="utf-8") as f: with open(old_config_path, "r", encoding="utf-8") as f:
@@ -75,13 +79,15 @@ def update_config():
return return
else: else:
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在 # 创建old目录如果不存在
old_config_dir.mkdir(exist_ok=True) os.makedirs(old_config_dir, exist_ok=True)
# 生成带时间戳的新文件名 # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml" old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录 # 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path) shutil.move(old_config_path, old_backup_path)
@@ -91,24 +97,23 @@ def update_config():
shutil.copy2(template_path, new_config_path) shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}") logger.info(f"已创建新配置文件: {new_config_path}")
# 递归更新配置 def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
def update_dict(target, source): """
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items(): for key, value in source.items():
# 跳过version字段的更新 # 跳过version字段的更新
if key == "version": if key == "version":
continue continue
if key in target: if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)): if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value) update_dict(target[key], value)
else: else:
try: try:
# 对数组类型进行特殊处理 # 对数组类型进行特殊处理
if isinstance(value, list): if isinstance(value, list):
# 如果是空数组,确保它保持为空数组 # 如果是空数组,确保它保持为空数组
if not value: target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
target[key] = tomlkit.array()
else:
target[key] = tomlkit.array(value)
else: else:
# 其他类型使用item方法创建新值 # 其他类型使用item方法创建新值
target[key] = tomlkit.item(value) target[key] = tomlkit.item(value)
@@ -123,619 +128,57 @@ def update_config():
# 保存更新后的配置(保留注释和格式) # 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f: with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config)) f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成") logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
quit()
@dataclass @dataclass
class BotConfig: class Config(ConfigBase):
"""机器人配置类""" """配置类"""
INNER_VERSION: Version = None MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
MAI_VERSION: str = mai_version # 硬编码的版本信息
bot: BotConfig
# bot chat_target: ChatTargetConfig
BOT_QQ: Optional[str] = "114514" personality: PersonalityConfig
BOT_NICKNAME: Optional[str] = None identity: IdentityConfig
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它 platforms: PlatformsConfig
chat: ChatConfig
# group normal_chat: NormalChatConfig
talk_allowed_groups = set() focus_chat: FocusChatConfig
talk_frequency_down_groups = set() emoji: EmojiConfig
ban_user_id = set() memory: MemoryConfig
mood: MoodConfig
# personality keyword_reaction: KeywordReactionConfig
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋 chinese_typo: ChineseTypoConfig
personality_sides: List[str] = field( response_splitter: ResponseSplitterConfig
default_factory=lambda: [ telemetry: TelemetryConfig
"用一句话或几句话描述人格的一些侧面", experimental: ExperimentalConfig
"用一句话或几句话描述人格的一些侧面", model: ModelConfig
"用一句话或几句话描述人格的一些侧面",
]
) def load_config(config_path: str) -> Config:
expression_style = "描述麦麦说话的表达风格,表达习惯" """
# identity 加载配置文件
identity_detail: List[str] = field( :param config_path: 配置文件路径
default_factory=lambda: [ :return: Config对象
"身份特点", """
"身份特点", # 读取配置文件
] with open(config_path, "r", encoding="utf-8") as f:
) config_data = tomlkit.load(f)
height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克 # 创建Config对象
age: int = 20 # 年龄 单位岁 try:
gender: str = "" # 性别 return Config.from_dict(config_data)
appearance: str = "用几句话描述外貌特征" # 外貌特征 except Exception as e:
logger.critical("配置文件解析失败")
# chat raise e
allow_focus_mode: bool = True # 是否允许专注聊天状态
base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天
base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天
observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
message_buffer: bool = True # 消息缓冲器
ban_words = set()
ban_msgs_regex = set()
# focus_chat
reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发
default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢
consecutive_no_reply_threshold = 3
compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5
compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除
# normal_chat
model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率
model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率
emoji_chance: float = 0.2 # 发送表情包的基础概率
thinking_timeout: int = 120 # 思考时间
willing_mode: str = "classical" # 意愿模式
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
at_bot_inevitable_reply: bool = False # @bot 必然回复
# emoji
max_emoji_num: int = 200 # 表情包最大数量
max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
save_pic: bool = False # 是否保存图片
save_emoji: bool = False # 是否保存表情包
steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
EMOJI_CHECK: bool = False # 是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
# memory
build_memory_interval: int = 600 # 记忆构建间隔(秒)
memory_build_distribution: list = field(
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
) # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
build_memory_sample_num: int = 10 # 记忆构建采样数量
build_memory_sample_length: int = 20 # 记忆构建采样长度
memory_compress_rate: float = 0.1 # 记忆压缩率
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒)
consolidation_similarity_threshold: float = 0.7 # 相似度阈值
consolidate_memory_percentage: float = 0.01 # 检查节点比例
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
# mood
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
# keywords
keywords_reaction_rules = [] # 关键词回复规则
# chinese_typo
chinese_typo_enable = True # 是否启用中文错别字生成器
chinese_typo_error_rate = 0.03 # 单字替换概率
chinese_typo_min_freq = 7 # 最小字频阈值
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
# response_splitter
enable_kaomoji_protection = False # 是否启用颜文字保护
enable_response_splitter = True # 是否启用回复分割器
response_max_length = 100 # 回复允许的最大长度
response_max_sentence_num = 3 # 回复允许的最大句子数
model_max_output_length: int = 800 # 最大回复长度
# remote
remote_enable: bool = True # 是否启用远程控制
# experimental
enable_friend_chat: bool = False # 是否启用好友聊天
# enable_think_flow: bool = False # 是否启用思考流程
talk_allowed_private = set()
enable_pfc_chatting: bool = False # 是否启用PFC聊天
# 模型配置
llm_reasoning: dict[str, str] = field(default_factory=lambda: {})
# llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {})
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
llm_summary: Dict[str, str] = field(default_factory=lambda: {})
embedding: Dict[str, str] = field(default_factory=lambda: {})
vlm: Dict[str, str] = field(default_factory=lambda: {})
moderation: Dict[str, str] = field(default_factory=lambda: {})
llm_observation: Dict[str, str] = field(default_factory=lambda: {})
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_tool_use: Dict[str, str] = field(default_factory=lambda: {})
llm_plan: Dict[str, str] = field(default_factory=lambda: {})
api_urls: Dict[str, str] = field(default_factory=lambda: {})
@staticmethod
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", ".."))
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@classmethod
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
"""将 字符串 版本表达式转换成 SpecifierSet
Args:
value[str]: 版本表达式(字符串)
Returns:
SpecifierSet
"""
try:
converted = SpecifierSet(value)
except InvalidSpecifier:
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
exit(1)
return converted
@classmethod
def get_config_version(cls, toml: dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
"""
if "inner" in toml:
try:
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
else:
toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"]
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
"本项目在不同的版本下有不同的模板,请注意识别"
)
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
return ver
@classmethod
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
config = cls()
def personality(parent: dict):
personality_config = parent["personality"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.personality_core = personality_config.get("personality_core", config.personality_core)
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
if config.INNER_VERSION in SpecifierSet(">=1.7.0"):
config.expression_style = personality_config.get("expression_style", config.expression_style)
def identity(parent: dict):
identity_config = parent["identity"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
config.height = identity_config.get("height", config.height)
config.weight = identity_config.get("weight", config.weight)
config.age = identity_config.get("age", config.age)
config.gender = identity_config.get("gender", config.gender)
config.appearance = identity_config.get("appearance", config.appearance)
def emoji(parent: dict):
emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
if config.INNER_VERSION in SpecifierSet(">=1.1.1"):
config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num)
config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion)
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
config.save_pic = emoji_config.get("save_pic", config.save_pic)
config.save_emoji = emoji_config.get("save_emoji", config.save_emoji)
config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji)
def bot(parent: dict):
# 机器人基础配置
bot_config = parent["bot"]
bot_qq = bot_config.get("qq")
config.BOT_QQ = str(bot_qq)
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
def chat(parent: dict):
chat_config = parent["chat"]
config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode)
config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num)
config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num)
config.observation_context_size = chat_config.get(
"observation_context_size", config.observation_context_size
)
config.message_buffer = chat_config.get("message_buffer", config.message_buffer)
config.ban_words = chat_config.get("ban_words", config.ban_words)
for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex):
config.ban_msgs_regex.add(re.compile(r))
def normal_chat(parent: dict):
normal_chat_config = parent["normal_chat"]
config.model_reasoning_probability = normal_chat_config.get(
"model_reasoning_probability", config.model_reasoning_probability
)
config.model_normal_probability = normal_chat_config.get(
"model_normal_probability", config.model_normal_probability
)
config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance)
config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout)
config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode)
config.response_willing_amplifier = normal_chat_config.get(
"response_willing_amplifier", config.response_willing_amplifier
)
config.response_interested_rate_amplifier = normal_chat_config.get(
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
)
config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate)
config.emoji_response_penalty = normal_chat_config.get(
"emoji_response_penalty", config.emoji_response_penalty
)
config.mentioned_bot_inevitable_reply = normal_chat_config.get(
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
)
config.at_bot_inevitable_reply = normal_chat_config.get(
"at_bot_inevitable_reply", config.at_bot_inevitable_reply
)
def focus_chat(parent: dict):
focus_chat_config = parent["focus_chat"]
config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length)
config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit)
config.reply_trigger_threshold = focus_chat_config.get(
"reply_trigger_threshold", config.reply_trigger_threshold
)
config.default_decay_rate_per_second = focus_chat_config.get(
"default_decay_rate_per_second", config.default_decay_rate_per_second
)
config.consecutive_no_reply_threshold = focus_chat_config.get(
"consecutive_no_reply_threshold", config.consecutive_no_reply_threshold
)
def model(parent: dict):
# 加载模型配置
model_config: dict = parent["model"]
config_list = [
"llm_reasoning",
# "llm_reasoning_minor",
"llm_normal",
"llm_topic_judge",
"llm_summary",
"vlm",
"embedding",
"llm_tool_use",
"llm_observation",
"llm_sub_heartflow",
"llm_plan",
"llm_heartflow",
"llm_PFC_action_planner",
"llm_PFC_chat",
"llm_PFC_reply_checker",
]
for item in config_list:
if item in model_config:
cfg_item: dict = model_config[item]
# base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY
cfg_target = {
"name": "",
"base_url": "",
"key": "",
"stream": False,
"pri_in": 0,
"pri_out": 0,
"temp": 0.7,
}
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
stable_item = ["name", "pri_in", "pri_out"]
stream_item = ["stream"]
if config.INNER_VERSION in SpecifierSet(">=1.0.1"):
stable_item.append("stream")
pricing_item = ["pri_in", "pri_out"]
# 从配置中原始拷贝稳定字段
for i in stable_item:
# 如果 字段 属于计费项 且获取不到,那默认值是 0
if i in pricing_item and i not in cfg_item:
cfg_target[i] = 0
if i in stream_item and i not in cfg_item:
cfg_target[i] = False
else:
# 没有特殊情况则原样复制
try:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
# 如果配置中有temp参数就使用配置中的值
if "temp" in cfg_item:
cfg_target["temp"] = cfg_item["temp"]
else:
# 如果没有temp参数就删除默认值
cfg_target.pop("temp", None)
provider = cfg_item.get("provider")
if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
cfg_target["base_url"] = f"{provider}_BASE_URL"
cfg_target["key"] = f"{provider}_KEY"
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
setattr(config, item, cfg_target)
else:
logger.error(f"模型 {item} 在config中不存在请检查或尝试更新配置文件")
raise KeyError(f"模型 {item} 在config中不存在请检查或尝试更新配置文件")
def memory(parent: dict):
memory_config = parent["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
config.memory_forget_percentage = memory_config.get(
"memory_forget_percentage", config.memory_forget_percentage
)
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
config.memory_build_distribution = memory_config.get(
"memory_build_distribution", config.memory_build_distribution
)
config.build_memory_sample_num = memory_config.get(
"build_memory_sample_num", config.build_memory_sample_num
)
config.build_memory_sample_length = memory_config.get(
"build_memory_sample_length", config.build_memory_sample_length
)
if config.INNER_VERSION in SpecifierSet(">=1.5.1"):
config.consolidate_memory_interval = memory_config.get(
"consolidate_memory_interval", config.consolidate_memory_interval
)
config.consolidation_similarity_threshold = memory_config.get(
"consolidation_similarity_threshold", config.consolidation_similarity_threshold
)
config.consolidate_memory_percentage = memory_config.get(
"consolidate_memory_percentage", config.consolidate_memory_percentage
)
def remote(parent: dict):
remote_config = parent["remote"]
config.remote_enable = remote_config.get("enable", config.remote_enable)
def mood(parent: dict):
mood_config = parent["mood"]
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate)
config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor)
def keywords_reaction(parent: dict):
keywords_reaction_config = parent["keywords_reaction"]
if keywords_reaction_config.get("enable", False):
config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules)
for rule in config.keywords_reaction_rules:
if rule.get("enable", False) and "regex" in rule:
rule["regex"] = [re.compile(r) for r in rule.get("regex", [])]
def chinese_typo(parent: dict):
chinese_typo_config = parent["chinese_typo"]
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
"tone_error_rate", config.chinese_typo_tone_error_rate
)
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
"word_replace_rate", config.chinese_typo_word_replace_rate
)
def response_splitter(parent: dict):
response_splitter_config = parent["response_splitter"]
config.enable_response_splitter = response_splitter_config.get(
"enable_response_splitter", config.enable_response_splitter
)
config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length)
config.response_max_sentence_num = response_splitter_config.get(
"response_max_sentence_num", config.response_max_sentence_num
)
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
config.enable_kaomoji_protection = response_splitter_config.get(
"enable_kaomoji_protection", config.enable_kaomoji_protection
)
if config.INNER_VERSION in SpecifierSet(">=1.6.0"):
config.model_max_output_length = response_splitter_config.get(
"model_max_output_length", config.model_max_output_length
)
def groups(parent: dict):
groups_config = parent["groups"]
# config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", []))
# config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
config.talk_frequency_down_groups = set(
str(group) for group in groups_config.get("talk_frequency_down", [])
)
# config.ban_user_id = set(groups_config.get("ban_user_id", []))
config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", []))
def experimental(parent: dict):
experimental_config = parent["experimental"]
config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat)
# config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow)
config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", []))
if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool
# 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
# 正常执行程序,但是会看到这条自定义提示
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号:当你做了不兼容的 API 修改,
# 次版本号:当你做了向下兼容的功能性新增,
# 修订号:当你做了向下兼容的问题修正。
# 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。
# 如果你做了break的修改就应该改动主版本号
# 如果做了一个兼容修改,就不应该要求这个选项是必须的!
include_configs = {
"bot": {"func": bot, "support": ">=0.0.0"},
"groups": {"func": groups, "support": ">=0.0.0"},
"personality": {"func": personality, "support": ">=0.0.0"},
"identity": {"func": identity, "support": ">=1.2.4"},
"emoji": {"func": emoji, "support": ">=0.0.0"},
"model": {"func": model, "support": ">=0.0.0"},
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
"mood": {"func": mood, "support": ">=0.0.0"},
"remote": {"func": remote, "support": ">=0.0.10", "necessary": False},
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
"response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False},
"experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False},
"chat": {"func": chat, "support": ">=1.6.0", "necessary": False},
"normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False},
"focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False},
}
# 原地修改,将 字符串版本表达式 转换成 版本对象
for key in include_configs:
item_support = include_configs[key]["support"]
include_configs[key]["support"] = cls.convert_to_specifierset(item_support)
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomli.load(f)
except tomli.TOMLDecodeError as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1)
# 获取配置文件版本
config.INNER_VERSION = cls.get_config_version(toml_dict)
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifierset: SpecifierSet = include_configs[key]["support"]
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifierset:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
include_configs[key]["func"](toml_dict)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifierset}"
)
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
# identity_detail字段非空检查
if not config.identity_detail:
logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
logger.success(f"成功加载配置文件: {config_path}")
return config
# 获取配置文件路径 # 获取配置文件路径
logger.info(f"MaiCore当前版本: {mai_version}") logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config() update_config()
bot_config_floder_path = BotConfig.get_config_dir() logger.info("正在品鉴配置文件...")
logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}") global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml")
logger.info("非常的新鲜,非常的美味!")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
if os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
logger.info(f"异常的新鲜,异常的美味: {bot_config_path}")
else:
# 配置文件不存在
logger.error("配置文件不存在,请检查路径: {bot_config_path}")
raise FileNotFoundError(f"配置文件不存在: {bot_config_path}")
global_config = BotConfig.load_config(config_path=bot_config_path)

116
src/config/config_base.py Normal file
View File

@@ -0,0 +1,116 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args

T = TypeVar("T", bound="ConfigBase")

# Types that can appear directly in a parsed TOML document.
TOML_DICT_TYPE = {
    int,
    float,
    str,
    bool,
    list,
    dict,
}


@dataclass
class ConfigBase:
    """Base class for configuration dataclasses.

    Provides :meth:`from_dict`, which recursively builds a (possibly nested)
    dataclass instance from a plain dictionary such as a parsed TOML table.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
        """Construct an instance of ``cls`` from a dictionary.

        Fields absent from ``data`` fall back to their declared default /
        ``default_factory``; an absent field without a default raises
        ``ValueError``. Fields whose name starts with an underscore are
        never loaded from ``data``.

        :param data: mapping of field names to raw (e.g. TOML) values
        :raises TypeError: if ``data`` is not a dict or a value has the wrong type
        :raises ValueError: if a required field is missing
        :raises RuntimeError: if a value cannot be converted to the target type
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected a dictionary, got {type(data).__name__}")

        init_args: dict[str, Any] = {}

        for f in fields(cls):
            field_name = f.name
            if field_name.startswith("_"):
                # Underscore-prefixed fields are internal; never read from data.
                continue

            if field_name not in data:
                if f.default is not MISSING or f.default_factory is not MISSING:
                    # Absent but optional: let the dataclass default apply.
                    continue
                raise ValueError(f"Missing required field: '{field_name}'")

            value = data[field_name]
            field_type = f.type

            try:
                init_args[field_name] = cls._convert_field(value, field_type)
            except TypeError as e:
                raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e

        return cls(**init_args)

    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
        """Convert ``value`` to ``field_type``.

        1. Nested ``ConfigBase`` subclasses are built recursively via ``from_dict``.
        2. Generic collections (list, set, tuple, dict) convert each element.
        3. Values already of the target type (or typed ``Any``) pass through unchanged.
        4. Anything else is handed to ``field_type(value)``; failure raises ``TypeError``.
        """
        # Nested dataclass: recurse into from_dict.
        if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
            if not isinstance(value, dict):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            return field_type.from_dict(value)

        # Generic collection types (list, set, tuple).
        field_origin_type = get_origin(field_type)
        field_type_args = get_args(field_type)
        if field_origin_type in {list, set, tuple}:
            # TOML arrays always arrive as Python lists.
            # NOTE: use the repr of field_type (not .__name__) in messages,
            # since subscripted typing forms may lack __name__.
            if not isinstance(value, list):
                raise TypeError(f"Expected a list for {field_type}, got {type(value).__name__}")
            if field_origin_type is list:
                return [cls._convert_field(item, field_type_args[0]) for item in value]
            elif field_origin_type is set:
                return {cls._convert_field(item, field_type_args[0]) for item in value}
            elif field_origin_type is tuple:
                # Fixed-length tuple annotation: lengths must match exactly.
                if len(value) != len(field_type_args):
                    raise TypeError(
                        f"Expected {len(field_type_args)} items for {field_type}, got {len(value)}"
                    )
                return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))

        if field_origin_type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected a dictionary for {field_type}, got {type(value).__name__}")
            # A dict annotation must carry exactly a key type and a value type.
            if len(field_type_args) != 2:
                raise TypeError(f"Expected a dictionary with two type arguments for {field_type}")
            key_type, value_type = field_type_args
            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}

        # Plain types (int, str, bool, ...): accept values that already match.
        if field_type is Any or isinstance(value, field_type):
            return value

        # Last resort: attempt a direct conversion such as int("3").
        try:
            return field_type(value)
        except (ValueError, TypeError) as e:
            raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e

    def __str__(self):
        """Return a human-readable dump of all field values."""
        return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,399 @@
from dataclasses import dataclass, field
from typing import Any
from src.config.config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
@dataclass
class BotConfig(ConfigBase):
    """QQ bot account configuration."""

    qq_account: str
    """QQ account number"""

    nickname: str
    """Bot nickname"""

    alias_names: list[str] = field(default_factory=list)
    """Alternative names the bot also answers to"""
@dataclass
class ChatTargetConfig(ConfigBase):
    """Chat target configuration.

    Holds the groups and users the bot is allowed (or restricted) to chat with.
    """

    talk_allowed_groups: set[str] = field(default_factory=set)
    """Groups the bot is allowed to talk in"""

    talk_frequency_down_groups: set[str] = field(default_factory=set)
    """Groups where the reply frequency is lowered"""

    ban_user_id: set[str] = field(default_factory=set)
    """Users the bot must not talk to"""
@dataclass
class PersonalityConfig(ConfigBase):
    """Personality configuration."""

    personality_core: str
    """Core personality description"""

    expression_style: str
    """Speaking / expression style"""

    personality_sides: list[str] = field(default_factory=list)
    """Additional personality facets"""
@dataclass
class IdentityConfig(ConfigBase):
    """Individual identity/trait configuration."""

    height: int = 170
    """Height (centimeters)"""

    weight: float = 50
    """Weight (kilograms)"""

    age: int = 18
    """Age (years)"""

    gender: str = ""
    """Gender (male/female)"""

    appearance: str = "可爱"
    """Appearance description"""

    identity_detail: list[str] = field(default_factory=list)
    """Identity details"""
@dataclass
class PlatformsConfig(ConfigBase):
    """Platform adapter configuration."""

    qq: str
    """Connection URL of the QQ adapter"""
@dataclass
class ChatConfig(ConfigBase):
    """General chat configuration."""

    allow_focus_mode: bool = True
    """Whether focused-chat mode is allowed"""

    base_normal_chat_num: int = 3
    """Maximum number of groups in normal chat at once"""

    base_focused_chat_num: int = 2
    """Maximum number of groups in focused chat at once"""

    observation_context_size: int = 12
    """Longest observable context size; context beyond this is compressed"""

    message_buffer: bool = True
    """Whether the message buffer is enabled"""

    ban_words: set[str] = field(default_factory=set)
    """Filtered words"""

    ban_msgs_regex: set[str] = field(default_factory=set)
    """Filtered message regular expressions"""
@dataclass
class NormalChatConfig(ConfigBase):
    """Normal-chat configuration."""

    reasoning_model_probability: float = 0.3
    """
    Probability (0-1) of choosing the reasoning model when replying;
    the normal model is chosen with probability 1 - reasoning_model_probability
    """

    emoji_chance: float = 0.2
    """Base probability of sending an emoji sticker"""

    thinking_timeout: int = 120
    """Maximum thinking time"""

    willing_mode: str = "classical"
    """Willingness mode"""

    response_willing_amplifier: float = 1.0
    """Reply-willingness amplification factor"""

    response_interested_rate_amplifier: float = 1.0
    """Reply-interest amplification factor"""

    down_frequency_rate: float = 3.0
    """Willingness reduction factor for frequency-lowered groups"""

    emoji_response_penalty: float = 0.0
    """Penalty factor applied when replying to emoji stickers"""

    mentioned_bot_inevitable_reply: bool = False
    """Always reply when the bot is mentioned"""

    at_bot_inevitable_reply: bool = False
    """Always reply when the bot is @-ed"""
@dataclass
class FocusChatConfig(ConfigBase):
    """Focused-chat (heartflow) configuration."""

    reply_trigger_threshold: float = 3.0
    """Heartflow chat trigger threshold; lower values trigger more easily"""

    default_decay_rate_per_second: float = 0.98
    """Default per-second decay rate (multiplier)"""

    consecutive_no_reply_threshold: int = 3
    """Threshold of consecutive non-replies"""

    compressed_length: int = 5
    """Minimum compression length of heartflow context; context longer than the observed window is compressed (minimum 5)"""

    compress_length_limit: int = 5
    """Maximum number of compressed context chunks kept; older ones are deleted"""
@dataclass
class EmojiConfig(ConfigBase):
    """Emoji-sticker configuration."""

    max_reg_num: int = 200
    """Maximum number of registered stickers"""

    do_replace: bool = True
    """Replace old stickers once the maximum registration count is reached"""

    check_interval: int = 120
    """Sticker check interval (minutes)"""

    save_pic: bool = False
    """Whether to save pictures"""

    cache_emoji: bool = True
    """Whether to cache stickers"""

    steal_emoji: bool = True
    """Whether to collect received stickers so the bot can send them itself"""

    content_filtration: bool = False
    """Whether sticker content filtering is enabled"""

    filtration_prompt: str = "符合公序良俗"
    """Sticker filtering requirement prompt"""
@dataclass
class MemoryConfig(ConfigBase):
    """Memory configuration."""

    memory_build_interval: int = 600
    """Memory build interval (seconds)"""

    memory_build_distribution: tuple[float, float, float, float, float, float] = field(
        default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4)
    )
    """Memory build distribution parameters: (mean1, std1, weight1, mean2, std2, weight2)"""

    memory_build_sample_num: int = 8
    """Number of samples per memory build"""

    memory_build_sample_length: int = 40
    """Length of each memory build sample"""

    memory_compress_rate: float = 0.1
    """Memory compression rate"""

    forget_memory_interval: int = 1000
    """Memory forgetting interval (seconds)"""

    memory_forget_time: int = 24
    """Memory forgetting time (hours)"""

    memory_forget_percentage: float = 0.01
    """Fraction of memory forgotten"""

    consolidate_memory_interval: int = 1000
    """Memory consolidation interval (seconds)"""

    consolidation_similarity_threshold: float = 0.7
    """Consolidation similarity threshold"""

    consolidate_memory_percentage: float = 0.01
    """Fraction of nodes checked during consolidation"""

    memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
    """Words that must not be memorized"""
@dataclass
class MoodConfig(ConfigBase):
    """Mood configuration."""

    mood_update_interval: int = 1
    """Mood update interval (seconds)"""

    mood_decay_rate: float = 0.95
    """Mood decay rate"""

    mood_intensity_factor: float = 0.7
    """Mood intensity factor"""
@dataclass
class KeywordRuleConfig(ConfigBase):
    """A single keyword-reaction rule."""

    enable: bool = True
    """Whether this rule is enabled"""

    keywords: list[str] = field(default_factory=list)
    """Keywords that trigger the rule"""

    regex: list[str] = field(default_factory=list)
    """Regular expressions that trigger the rule"""

    reaction: str = ""
    """Reaction produced when the rule fires"""
@dataclass
class KeywordReactionConfig(ConfigBase):
    """Keyword-reaction configuration."""

    enable: bool = True
    """Whether keyword reactions are enabled"""

    rules: list[KeywordRuleConfig] = field(default_factory=list)
    """List of keyword-reaction rules"""
@dataclass
class ChineseTypoConfig(ConfigBase):
    """Chinese typo generator configuration."""

    enable: bool = True
    """Whether the Chinese typo generator is enabled"""

    error_rate: float = 0.01
    """Per-character replacement probability"""

    min_freq: int = 9
    """Minimum character-frequency threshold"""

    tone_error_rate: float = 0.1
    """Tone error probability"""

    word_replace_rate: float = 0.006
    """Whole-word replacement probability"""
@dataclass
class ResponseSplitterConfig(ConfigBase):
    """Reply splitter configuration."""

    enable: bool = True
    """Whether the reply splitter is enabled"""

    max_length: int = 256
    """Maximum allowed reply length"""

    max_sentence_num: int = 3
    """Maximum allowed number of sentences per reply"""

    enable_kaomoji_protection: bool = False
    """Whether kaomoji (emoticon) protection is enabled"""
@dataclass
class TelemetryConfig(ConfigBase):
    """Telemetry configuration."""

    enable: bool = True
    """Whether telemetry is enabled"""
@dataclass
class ExperimentalConfig(ConfigBase):
    """Experimental feature switches."""

    enable_friend_chat: bool = False
    """Whether friend (private) chat is enabled"""

    talk_allowed_private: set[str] = field(default_factory=set)
    """Users allowed to chat privately"""

    pfc_chatting: bool = False
    """Whether PFC chatting is enabled"""
@dataclass
class ModelConfig(ConfigBase):
    """Model configuration."""

    model_max_output_length: int = 800  # Maximum reply length

    reasoning: dict[str, Any] = field(default_factory=dict)
    """Reasoning model settings"""

    normal: dict[str, Any] = field(default_factory=dict)
    """Normal model settings"""

    topic_judge: dict[str, Any] = field(default_factory=dict)
    """Topic-judge model settings"""

    summary: dict[str, Any] = field(default_factory=dict)
    """Summary model settings"""

    vlm: dict[str, Any] = field(default_factory=dict)
    """Vision-language model settings"""

    heartflow: dict[str, Any] = field(default_factory=dict)
    """Heartflow model settings"""

    observation: dict[str, Any] = field(default_factory=dict)
    """Observation model settings"""

    sub_heartflow: dict[str, Any] = field(default_factory=dict)
    """Sub-heartflow model settings"""

    plan: dict[str, Any] = field(default_factory=dict)
    """Planning model settings"""

    embedding: dict[str, Any] = field(default_factory=dict)
    """Embedding model settings"""

    pfc_action_planner: dict[str, Any] = field(default_factory=dict)
    """PFC action-planner model settings"""

    pfc_chat: dict[str, Any] = field(default_factory=dict)
    """PFC chat model settings"""

    pfc_reply_checker: dict[str, Any] = field(default_factory=dict)
    """PFC reply-checker model settings"""

    tool_use: dict[str, Any] = field(default_factory=dict)
    """Tool-use model settings"""

View File

@@ -114,7 +114,7 @@ class ActionPlanner:
request_type="action_planning", request_type="action_planning",
) )
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME self.name = global_config.bot.nickname
self.private_name = private_name self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量 # self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
@@ -140,7 +140,7 @@ class ActionPlanner:
# (这部分逻辑不变) # (这部分逻辑不变)
time_since_last_bot_message_info = "" time_since_last_bot_message_info = ""
try: try:
bot_id = str(global_config.BOT_QQ) bot_id = str(global_config.bot.qq_account)
if hasattr(observation_info, "chat_history") and observation_info.chat_history: if hasattr(observation_info, "chat_history") and observation_info.chat_history:
for i in range(len(observation_info.chat_history) - 1, -1, -1): for i in range(len(observation_info.chat_history) - 1, -1, -1):
msg = observation_info.chat_history[i] msg = observation_info.chat_history[i]

View File

@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
create_new_message_notification, create_new_message_notification,
create_cold_chat_notification, create_cold_chat_notification,
) )
from src.experimental.PFC.message_storage import MongoDBMessageStorage from src.experimental.PFC.message_storage import PeeweeMessageStorage
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -53,7 +53,7 @@ class ChatObserver:
self.stream_id = stream_id self.stream_id = stream_id
self.private_name = private_name self.private_name = private_name
self.message_storage = MongoDBMessageStorage() self.message_storage = PeeweeMessageStorage()
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
@@ -323,7 +323,7 @@ class ChatObserver:
for msg in messages: for msg in messages:
try: try:
user_info = UserInfo.from_dict(msg.get("user_info", {})) user_info = UserInfo.from_dict(msg.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ: if user_info.user_id == global_config.bot.qq_account:
self.update_bot_speak_time(msg["time"]) self.update_bot_speak_time(msg["time"])
else: else:
self.update_user_speak_time(msg["time"]) self.update_user_speak_time(msg["time"])

View File

@@ -42,8 +42,8 @@ class DirectMessageSender:
# 获取麦麦的信息 # 获取麦麦的信息
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.bot.qq_account,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.bot.nickname,
platform=chat_stream.platform, platform=chat_stream.platform,
) )

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any from typing import List, Dict, Any
from src.common.database import db
# from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
class MessageStorage(ABC): class MessageStorage(ABC):
@@ -47,28 +50,35 @@ class MessageStorage(ABC):
pass pass
class MongoDBMessageStorage(MessageStorage): class PeeweeMessageStorage(MessageStorage):
"""MongoDB消息存储实现""" """Peewee消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$gt": message_time}} query = (
# print(f"storage_check_message: {message_time}") Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > message_time))
.order_by(Messages.time.asc())
)
return list(db.messages.find(query).sort("time", 1)) # print(f"storage_check_message: {message_time}")
messages_models = list(query)
return [model_to_dict(msg) for msg in messages_models]
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$lt": time_point}} query = (
Messages.select()
messages = list(db.messages.find(query).sort("time", -1).limit(limit)) .where((Messages.chat_id == chat_id) & (Messages.time < time_point))
.order_by(Messages.time.desc())
.limit(limit)
)
messages_models = list(query)
# 将消息按时间正序排列 # 将消息按时间正序排列
messages.reverse() messages_models.reverse()
return messages return [model_to_dict(msg) for msg in messages_models]
async def has_new_messages(self, chat_id: str, after_time: float) -> bool: async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {"chat_id": chat_id, "time": {"$gt": after_time}} return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()
return db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试 # # 创建一个内存消息存储实现,用于测试

View File

@@ -42,13 +42,14 @@ class GoalAnalyzer:
"""对话目标分析器""" """对话目标分析器"""
def __init__(self, stream_id: str, private_name: str): def __init__(self, stream_id: str, private_name: str):
# TODO: API-Adapter修改标记
self.llm = LLMRequest( self.llm = LLMRequest(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal" model=global_config.model.normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
) )
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME self.name = global_config.bot.nickname
self.nick_name = global_config.BOT_ALIAS_NAMES self.nick_name = global_config.bot.alias_names
self.private_name = private_name self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.chat_observer = ChatObserver.get_instance(stream_id, private_name)

View File

@@ -14,9 +14,10 @@ class KnowledgeFetcher:
"""知识调取器""" """知识调取器"""
def __init__(self, private_name: str): def __init__(self, private_name: str):
# TODO: API-Adapter修改标记
self.llm = LLMRequest( self.llm = LLMRequest(
model=global_config.llm_normal, model=global_config.model.normal,
temperature=global_config.llm_normal["temp"], temperature=global_config.model.normal["temp"],
max_tokens=1000, max_tokens=1000,
request_type="knowledge_fetch", request_type="knowledge_fetch",
) )

View File

@@ -16,7 +16,7 @@ class ReplyChecker:
self.llm = LLMRequest( self.llm = LLMRequest(
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check" model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
) )
self.name = global_config.BOT_NICKNAME self.name = global_config.bot.nickname
self.private_name = private_name self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.max_retries = 3 # 最大重试次数 self.max_retries = 3 # 最大重试次数
@@ -43,7 +43,7 @@ class ReplyChecker:
bot_messages = [] bot_messages = []
for msg in reversed(chat_history): for msg in reversed(chat_history):
user_info = UserInfo.from_dict(msg.get("user_info", {})) user_info = UserInfo.from_dict(msg.get("user_info", {}))
if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串 if str(user_info.user_id) == str(global_config.bot.qq_account): # 确保比较的是字符串
bot_messages.append(msg.get("processed_plain_text", "")) bot_messages.append(msg.get("processed_plain_text", ""))
if len(bot_messages) >= 2: # 只和最近的两条比较 if len(bot_messages) >= 2: # 只和最近的两条比较
break break

View File

@@ -93,7 +93,7 @@ class ReplyGenerator:
request_type="reply_generation", request_type="reply_generation",
) )
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME self.name = global_config.bot.nickname
self.private_name = private_name self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.reply_checker = ReplyChecker(stream_id, private_name) self.reply_checker = ReplyChecker(stream_id, private_name)

View File

@@ -19,7 +19,7 @@ class Waiter:
def __init__(self, stream_id: str, private_name: str): def __init__(self, stream_id: str, private_name: str):
self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.name = global_config.BOT_NICKNAME self.name = global_config.bot.nickname
self.private_name = private_name self.private_name = private_name
# self.wait_accumulated_time = 0 # 不再需要累加计时 # self.wait_accumulated_time = 0 # 不再需要累加计时

View File

@@ -16,7 +16,7 @@ class MessageProcessor:
@staticmethod @staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool: def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词""" """检查消息中是否包含过滤词"""
for word in global_config.ban_words: for word in global_config.chat.ban_words:
if word in text: if word in text:
logger.info( logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
@@ -28,7 +28,7 @@ class MessageProcessor:
@staticmethod @staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool: def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式""" """检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex: for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text): if pattern.search(text):
logger.info( logger.info(
f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"

View File

@@ -40,7 +40,7 @@ class MainSystem:
async def initialize(self): async def initialize(self):
"""初始化系统组件""" """初始化系统组件"""
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") logger.debug(f"正在唤醒{global_config.bot.nickname}......")
# 其他初始化任务 # 其他初始化任务
await asyncio.gather(self._init_components()) await asyncio.gather(self._init_components())
@@ -84,7 +84,7 @@ class MainSystem:
asyncio.create_task(chat_manager._auto_save_task()) asyncio.create_task(chat_manager._auto_save_task())
# 使用HippocampusManager初始化海马体 # 使用HippocampusManager初始化海马体
self.hippocampus_manager.initialize(global_config=global_config) self.hippocampus_manager.initialize()
# await asyncio.sleep(0.5) #防止logger输出飞了 # await asyncio.sleep(0.5) #防止logger输出飞了
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中 # 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
@@ -92,15 +92,15 @@ class MainSystem:
# 初始化个体特征 # 初始化个体特征
self.individuality.initialize( self.individuality.initialize(
bot_nickname=global_config.BOT_NICKNAME, bot_nickname=global_config.bot.nickname,
personality_core=global_config.personality_core, personality_core=global_config.personality.personality_core,
personality_sides=global_config.personality_sides, personality_sides=global_config.personality.personality_sides,
identity_detail=global_config.identity_detail, identity_detail=global_config.identity.identity_detail,
height=global_config.height, height=global_config.identity.height,
weight=global_config.weight, weight=global_config.identity.weight,
age=global_config.age, age=global_config.identity.age,
gender=global_config.gender, gender=global_config.identity.gender,
appearance=global_config.appearance, appearance=global_config.identity.appearance,
) )
logger.success("个体特征初始化成功") logger.success("个体特征初始化成功")
@@ -141,7 +141,7 @@ class MainSystem:
async def build_memory_task(): async def build_memory_task():
"""记忆构建任务""" """记忆构建任务"""
while True: while True:
await asyncio.sleep(global_config.build_memory_interval) await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建") logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory() await HippocampusManager.get_instance().build_memory()
@@ -149,16 +149,18 @@ class MainSystem:
async def forget_memory_task(): async def forget_memory_task():
"""记忆遗忘任务""" """记忆遗忘任务"""
while True: while True:
await asyncio.sleep(global_config.forget_memory_interval) await asyncio.sleep(global_config.memory.forget_memory_interval)
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage) await HippocampusManager.get_instance().forget_memory(
percentage=global_config.memory.memory_forget_percentage
)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@staticmethod @staticmethod
async def consolidate_memory_task(): async def consolidate_memory_task():
"""记忆整合任务""" """记忆整合任务"""
while True: while True:
await asyncio.sleep(global_config.consolidate_memory_interval) await asyncio.sleep(global_config.memory.consolidate_memory_interval)
print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...") print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...")
await HippocampusManager.get_instance().consolidate_memory() await HippocampusManager.get_instance().consolidate_memory()
print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")

View File

@@ -34,14 +34,14 @@ class MoodUpdateTask(AsyncTask):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
task_name="Mood Update Task", task_name="Mood Update Task",
wait_before_start=global_config.mood_update_interval, wait_before_start=global_config.mood.mood_update_interval,
run_interval=global_config.mood_update_interval, run_interval=global_config.mood.mood_update_interval,
) )
# 从配置文件获取衰减率 # 从配置文件获取衰减率
self.decay_rate_valence: float = 1 - global_config.mood_decay_rate self.decay_rate_valence: float = 1 - global_config.mood.mood_decay_rate
"""愉悦度衰减率""" """愉悦度衰减率"""
self.decay_rate_arousal: float = 1 - global_config.mood_decay_rate self.decay_rate_arousal: float = 1 - global_config.mood.mood_decay_rate
"""唤醒度衰减率""" """唤醒度衰减率"""
self.last_update = time.time() self.last_update = time.time()

View File

@@ -44,7 +44,7 @@ class ChangeMoodTool(BaseTool):
_ori_response = ",".join(response_set) _ori_response = ",".join(response_set)
# _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text) # _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text)
emotion = "平静" emotion = "平静"
mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor) mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"} return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"}
except Exception as e: except Exception as e:
logger.error(f"心情改变工具执行失败: {str(e)}") logger.error(f"心情改变工具执行失败: {str(e)}")

View File

@@ -1,8 +1,10 @@
from src.tools.tool_can_use.base_tool import BaseTool from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.common.database import db from src.common.database.database_model import Knowledges # Updated import
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from typing import Any, Union from typing import Any, Union, List # Added List
import json # Added for parsing embedding
import math # Added for cosine similarity
logger = get_logger("get_knowledge_tool") logger = get_logger("get_knowledge_tool")
@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
Returns: Returns:
dict: 工具执行结果 dict: 工具执行结果
""" """
query = "" # Initialize query to ensure it's defined in except block
try: try:
query = function_args.get("query") query = function_args.get("query")
threshold = function_args.get("threshold", 0.4) threshold = function_args.get("threshold", 0.4)
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
logger.error(f"知识库搜索工具执行失败: {str(e)}") logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""计算两个向量之间的余弦相似度"""
dot_product = sum(p * q for p, q in zip(vec1, vec2))
magnitude1 = math.sqrt(sum(p * p for p in vec1))
magnitude2 = math.sqrt(sum(q * q for q in vec2))
if magnitude1 == 0 or magnitude2 == 0:
return 0.0
return dot_product / (magnitude1 * magnitude2)
@staticmethod @staticmethod
def get_info_from_db( def get_info_from_db(
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
"""从数据库中获取相关信息 """从数据库中获取相关信息
@@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool):
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 similar_items = []
pipeline = [ try:
{ all_knowledges = Knowledges.select()
"$addFields": { for item in all_knowledges:
"dotProduct": { try:
"$reduce": { item_embedding_str = item.embedding
"input": {"$range": [0, {"$size": "$embedding"}]}, if not item_embedding_str:
"initialValue": 0, logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
"in": { continue
"$add": [ item_embedding = json.loads(item_embedding_str)
"$$value", if not isinstance(item_embedding, list) or not all(
{ isinstance(x, (int, float)) for x in item_embedding
"$multiply": [ ):
{"$arrayElemAt": ["$embedding", "$$this"]}, logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
{"$arrayElemAt": [query_embedding, "$$this"]}, continue
] except json.JSONDecodeError:
}, logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
] continue
}, except AttributeError:
} logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
}, continue
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline)) similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
logger.debug(f"知识库查询结果数量: {len(results)}")
if similarity >= threshold:
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
# 按相似度降序排序
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
# 应用限制
results = similar_items[:limit]
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
except Exception as e:
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
return "" if not return_raw else []
if not results: if not results:
return "" if not return_raw else [] return "" if not return_raw else []
if return_raw: if return_raw:
return results # Peewee 模型实例不能直接序列化为 JSON如果需要原始模型调用者需要处理
# 这里返回包含内容和相似度的字典列表
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
else: else:
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results) return "\n".join(str(result["content"]) for result in results)

View File

@@ -15,7 +15,7 @@ logger = get_logger("tool_use")
class ToolUser: class ToolUser:
def __init__(self): def __init__(self):
self.llm_model_tool = LLMRequest( self.llm_model_tool = LLMRequest(
model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use" model=global_config.model.tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
) )
@staticmethod @staticmethod
@@ -37,7 +37,7 @@ class ToolUser:
# print(f"intol111111111111111111111111111111111222222222222mid_memory_info{mid_memory_info}") # print(f"intol111111111111111111111111111111111222222222222mid_memory_info{mid_memory_info}")
# 这些信息应该从调用者传入而不是从self获取 # 这些信息应该从调用者传入而不是从self获取
bot_name = global_config.BOT_NICKNAME bot_name = global_config.bot.nickname
prompt = "" prompt = ""
prompt += mid_memory_info prompt += mid_memory_info
prompt += "你正在思考如何回复群里的消息。\n" prompt += "你正在思考如何回复群里的消息。\n"

View File

@@ -1,104 +0,0 @@
[inner.version]
describe = "版本号"
important = true
can_edit = false
[bot.qq]
describe = "机器人的QQ号"
important = true
can_edit = true
[bot.nickname]
describe = "机器人的昵称"
important = true
can_edit = true
[bot.alias_names]
describe = "机器人的别名列表,该选项还在调试中,暂时未生效"
important = false
can_edit = true
[groups.talk_allowed]
describe = "可以回复消息的群号码列表"
important = true
can_edit = true
[groups.talk_frequency_down]
describe = "降低回复频率的群号码列表"
important = false
can_edit = true
[groups.ban_user_id]
describe = "禁止回复和读取消息的QQ号列表"
important = false
can_edit = true
[personality.personality_core]
describe = "用一句话或几句话描述人格的核心特点建议20字以内"
important = true
can_edit = true
[personality.personality_sides]
describe = "用一句话或几句话描述人格的一些细节条数任意不能为0该选项还在调试中"
important = false
can_edit = true
[identity.identity_detail]
describe = "身份特点列表条数任意不能为0该选项还在调试中"
important = false
can_edit = true
[identity.age]
describe = "年龄,单位岁"
important = false
can_edit = true
[identity.gender]
describe = "性别"
important = false
can_edit = true
[identity.appearance]
describe = "外貌特征描述,该选项还在调试中,暂时未生效"
important = false
can_edit = true
[platforms.nonebot-qq]
describe = "nonebot-qq适配器提供的链接"
important = true
can_edit = true
[chat.allow_focus_mode]
describe = "是否允许专注聊天状态"
important = false
can_edit = true
[chat.base_normal_chat_num]
describe = "最多允许多少个群进行普通聊天"
important = false
can_edit = true
[chat.base_focused_chat_num]
describe = "最多允许多少个群进行专注聊天"
important = false
can_edit = true
[chat.observation_context_size]
describe = "观察到的最长上下文大小建议15太短太长都会导致脑袋尖尖"
important = false
can_edit = true
[chat.message_buffer]
describe = "启用消息缓冲器,启用此项以解决消息的拆分问题,但会使麦麦的回复延迟"
important = false
can_edit = true
[chat.ban_words]
describe = "需要过滤的消息列表"
important = false
can_edit = true
[chat.ban_msgs_regex]
describe = "需要过滤的消息原始消息匹配的正则表达式匹配到的消息将被过滤支持CQ码"
important = false
can_edit = true

View File

@@ -1,18 +1,10 @@
[inner] [inner]
version = "1.7.0" version = "2.0.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请在修改后将version的值进行变更 #如果你想要修改配置文件请在修改后将version的值进行变更
#如果新增项目,请在BotConfig类下新增相应的变量 #如果新增项目,请阅读src/config/official_configs.py中的说明
#1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{ #
#"func":memory,
#"support":">=0.0.0", #新的版本号
#"necessary":False #是否必须
#}
#2.如果你修改的是[]下的项目,例如你新增了[memory]下的 memory_ban_words ,那么请在config.py的 load_config函数中的 memory函数下新增版本判断:
# if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
# config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下: # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
# 主版本号:当你做了不兼容的 API 修改, # 主版本号:当你做了不兼容的 API 修改,
# 次版本号:当你做了向下兼容的功能性新增, # 次版本号:当你做了向下兼容的功能性新增,
@@ -21,11 +13,11 @@ version = "1.7.0"
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
[bot] [bot]
qq = 1145141919810 qq_account = 1145141919810
nickname = "麦麦" nickname = "麦麦"
alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效
[groups] [chat_target]
talk_allowed = [ talk_allowed = [
123, 123,
123, 123,
@@ -53,10 +45,13 @@ identity_detail = [
"身份特点", "身份特点",
"身份特点", "身份特点",
]# 条数任意不能为0, 该选项还在调试中 ]# 条数任意不能为0, 该选项还在调试中
#外貌特征 #外貌特征
age = 20 # 年龄 单位岁 age = 18 # 年龄 单位岁
gender = "" # 性别 gender = "" # 性别
appearance = "用几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效 height = "170" # 身高单位cm
weight = "50" # 体重单位kg
appearance = "用一句或几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效
[platforms] # 必填项目,填写每个平台适配器提供的链接 [platforms] # 必填项目,填写每个平台适配器提供的链接
qq="http://127.0.0.1:18002/api/message" qq="http://127.0.0.1:18002/api/message"
@@ -85,11 +80,10 @@ ban_msgs_regex = [
[normal_chat] #普通聊天 [normal_chat] #普通聊天
#一般回复参数 #一般回复参数
model_reasoning_probability = 0.7 # 麦麦回答时选择推理模型 模型的概率 reasoning_model_probability = 0.3 # 麦麦回答时选择推理模型的概率与之相对的普通模型的概率为1 - reasoning_model_probability
model_normal_probability = 0.3 # 麦麦回答时选择一般模型 模型的概率
emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率设置为1让麦麦自己决定发不发 emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率设置为1让麦麦自己决定发不发
thinking_timeout = 100 # 麦麦最长思考时间超过这个时间的思考会放弃往往是api反应太慢 thinking_timeout = 120 # 麦麦最长思考时间超过这个时间的思考会放弃往往是api反应太慢
willing_mode = "classical" # 回复意愿模式 —— 经典模式classicalmxp模式mxp自定义模式custom需要你自己实现 willing_mode = "classical" # 回复意愿模式 —— 经典模式classicalmxp模式mxp自定义模式custom需要你自己实现
response_willing_amplifier = 1 # 麦麦回复意愿放大系数一般为1 response_willing_amplifier = 1 # 麦麦回复意愿放大系数一般为1
@@ -100,8 +94,8 @@ mentioned_bot_inevitable_reply = false # 提及 bot 必然回复
at_bot_inevitable_reply = false # @bot 必然回复 at_bot_inevitable_reply = false # @bot 必然回复
[focus_chat] #专注聊天 [focus_chat] #专注聊天
reply_trigger_threshold = 3.6 # 专注聊天触发阈值,越低越容易进入专注聊天 reply_trigger_threshold = 3.0 # 专注聊天触发阈值,越低越容易进入专注聊天
default_decay_rate_per_second = 0.95 # 默认衰减率,越大衰减越快,越高越难进入专注聊天 default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,越高越难进入专注聊天
consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天 consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天
# 以下选项暂时无效 # 以下选项暂时无效
@@ -110,20 +104,20 @@ compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下
[emoji] [emoji]
max_emoji_num = 40 # 表情包最大数量 max_reg_num = 40 # 表情包最大注册数量
max_reach_deletion = true # 开启则在达到最大数量时删除表情包,关闭则达到最大数量时不删除,只是不会继续收集表情包 do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟) check_interval = 120 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
save_pic = false # 是否保存图片 save_pic = false # 是否保存图片
save_emoji = false # 是否存表情包 cache_emoji = true # 是否存表情包
steal_emoji = true # 是否偷取表情包,让麦麦可以发送她保存的这些表情包 steal_emoji = true # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
enable_check = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存 content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
check_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存 filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存
[memory] [memory]
build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多 memory_build_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
build_memory_distribution = [6.0,3.0,0.6,32.0,12.0,0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重 memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
build_memory_sample_num = 8 # 采样数量,数值越高记忆采样次数越多 memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
build_memory_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富 memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习 forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习
@@ -135,49 +129,45 @@ consolidation_similarity_threshold = 0.7 # 相似度阈值
consolidation_check_percentage = 0.01 # 检查节点比例 consolidation_check_percentage = 0.01 # 检查节点比例
#不希望记忆的词,已经记忆的不会受到影响 #不希望记忆的词,已经记忆的不会受到影响
memory_ban_words = [ memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
# "403","张三"
]
[mood] [mood]
mood_update_interval = 1.0 # 情绪更新间隔 单位秒 mood_update_interval = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate = 0.95 # 情绪衰减率 mood_decay_rate = 0.95 # 情绪衰减率
mood_intensity_factor = 1.0 # 情绪强度因子 mood_intensity_factor = 1.0 # 情绪强度因子
[keywords_reaction] # 针对某个关键词作出反应 [keyword_reaction] # 针对某个关键词作出反应
enable = true # 关键词反应功能的总开关 enable = true # 关键词反应功能的总开关
[[keywords_reaction.rules]] # 如果想要新增多个关键词直接复制本条修改keywords和reaction即可 [[keyword_reaction.rules]] # 如果想要新增多个关键词直接复制本条修改keywords和reaction即可
enable = true # 是否启用此条为了人类在未来AI战争能更好地识别AIbushi默认开启 enable = true # 是否启用此条为了人类在未来AI战争能更好地识别AIbushi默认开启
keywords = ["人机", "bot", "机器", "入机", "robot", "机器人","ai","AI"] # 会触发反应的关键词 keywords = ["人机", "bot", "机器", "入机", "robot", "机器人","ai","AI"] # 会触发反应的关键词
reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" # 触发之后添加的提示词 reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" # 触发之后添加的提示词
[[keywords_reaction.rules]] # 就像这样复制 [[keyword_reaction.rules]] # 就像这样复制
enable = false # 仅作示例,不会触发 enable = false # 仅作示例,不会触发
keywords = ["测试关键词回复","test",""] keywords = ["测试关键词回复","test",""]
reaction = "回答“测试成功”" # 修复错误的引号 reaction = "回答“测试成功”" # 修复错误的引号
[[keywords_reaction.rules]] # 使用正则表达式匹配句式 [[keyword_reaction.rules]] # 使用正则表达式匹配句式
enable = false # 仅作示例,不会触发 enable = false # 仅作示例,不会触发
regex = ["^(?P<n>\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写 regex = ["^(?P<n>\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写
reaction = "请按照以下模板造句:[n]是这样的xx只要xx就可以可是[n]要考虑的事情就很多了比如什么时候xx什么时候xx什么时候xx。请自由发挥替换xx部分只需保持句式结构同时表达一种将[n]过度重视的反讽意味)" reaction = "请按照以下模板造句:[n]是这样的xx只要xx就可以可是[n]要考虑的事情就很多了比如什么时候xx什么时候xx什么时候xx。请自由发挥替换xx部分只需保持句式结构同时表达一种将[n]过度重视的反讽意味)"
[chinese_typo] [chinese_typo]
enable = true # 是否启用中文错别字生成器 enable = true # 是否启用中文错别字生成器
error_rate=0.001 # 单字替换概率 error_rate=0.01 # 单字替换概率
min_freq=9 # 最小字频阈值 min_freq=9 # 最小字频阈值
tone_error_rate=0.1 # 声调错误概率 tone_error_rate=0.1 # 声调错误概率
word_replace_rate=0.006 # 整词替换概率 word_replace_rate=0.006 # 整词替换概率
[response_splitter] [response_splitter]
enable_response_splitter = true # 是否启用回复分割器 enable = true # 是否启用回复分割器
response_max_length = 256 # 回复允许的最大长度 max_length = 256 # 回复允许的最大长度
response_max_sentence_num = 4 # 回复允许的最大句子数 max_sentence_num = 4 # 回复允许的最大句子数
enable_kaomoji_protection = false # 是否启用颜文字保护 enable_kaomoji_protection = false # 是否启用颜文字保护
model_max_output_length = 256 # 模型单次返回的最大token数 [telemetry] #发送统计信息,主要是看全球有多少只麦麦
[remote] #发送统计信息,主要是看全球有多少只麦麦
enable = true enable = true
[experimental] #实验性功能 [experimental] #实验性功能
@@ -194,14 +184,17 @@ pfc_chatting = false # 是否启用PFC聊天该功能仅作用于私聊
# stream = <true|false> : 用于指定模型是否是使用流式输出 # stream = <true|false> : 用于指定模型是否是使用流式输出
# 如果不指定,则该项是 False # 如果不指定,则该项是 False
[model]
model_max_output_length = 800 # 模型单次返回的最大token数
#这个模型必须是推理模型 #这个模型必须是推理模型
[model.llm_reasoning] # 一般聊天模式的推理回复模型 [model.reasoning] # 一般聊天模式的推理回复模型
name = "Pro/deepseek-ai/DeepSeek-R1" name = "Pro/deepseek-ai/DeepSeek-R1"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 1.0 #模型的输入价格(非必填,可以记录消耗) pri_in = 1.0 #模型的输入价格(非必填,可以记录消耗)
pri_out = 4.0 #模型的输出价格(非必填,可以记录消耗) pri_out = 4.0 #模型的输出价格(非必填,可以记录消耗)
[model.llm_normal] #V3 回复模型 专注和一般聊天模式共用的回复模型 [model.normal] #V3 回复模型 专注和一般聊天模式共用的回复模型
name = "Pro/deepseek-ai/DeepSeek-V3" name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 2 #模型的输入价格(非必填,可以记录消耗) pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
@@ -209,13 +202,13 @@ pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
#默认temp 0.2 如果你使用的是老V3或者其他模型请自己修改temp参数 #默认temp 0.2 如果你使用的是老V3或者其他模型请自己修改temp参数
temp = 0.2 #模型的温度新V3建议0.1-0.3 temp = 0.2 #模型的温度新V3建议0.1-0.3
[model.llm_topic_judge] #主题判断模型建议使用qwen2.5 7b [model.topic_judge] #主题判断模型建议使用qwen2.5 7b
name = "Pro/Qwen/Qwen2.5-7B-Instruct" name = "Pro/Qwen/Qwen2.5-7B-Instruct"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 0.35 pri_in = 0.35
pri_out = 0.35 pri_out = 0.35
[model.llm_summary] #概括模型建议使用qwen2.5 32b 及以上 [model.summary] #概括模型建议使用qwen2.5 32b 及以上
name = "Qwen/Qwen2.5-32B-Instruct" name = "Qwen/Qwen2.5-32B-Instruct"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 1.26 pri_in = 1.26
@@ -227,27 +220,27 @@ provider = "SILICONFLOW"
pri_in = 0.35 pri_in = 0.35
pri_out = 0.35 pri_out = 0.35
[model.llm_heartflow] # 用于控制麦麦是否参与聊天的模型 [model.heartflow] # 用于控制麦麦是否参与聊天的模型
name = "Qwen/Qwen2.5-32B-Instruct" name = "Qwen/Qwen2.5-32B-Instruct"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 1.26 pri_in = 1.26
pri_out = 1.26 pri_out = 1.26
[model.llm_observation] #观察模型,压缩聊天内容,建议用免费的 [model.observation] #观察模型,压缩聊天内容,建议用免费的
# name = "Pro/Qwen/Qwen2.5-7B-Instruct" # name = "Pro/Qwen/Qwen2.5-7B-Instruct"
name = "Qwen/Qwen2.5-7B-Instruct" name = "Qwen/Qwen2.5-7B-Instruct"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 0 pri_in = 0
pri_out = 0 pri_out = 0
[model.llm_sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 [model.sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型
name = "Pro/deepseek-ai/DeepSeek-V3" name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 2 pri_in = 2
pri_out = 8 pri_out = 8
temp = 0.3 #模型的温度新V3建议0.1-0.3 temp = 0.3 #模型的温度新V3建议0.1-0.3
[model.llm_plan] #决策:认真水群时,负责决定麦麦该做什么 [model.plan] #决策:认真水群时,负责决定麦麦该做什么
name = "Pro/deepseek-ai/DeepSeek-V3" name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 2 pri_in = 2
@@ -265,7 +258,7 @@ pri_out = 0
#私聊PFC需要开启PFC功能默认三个模型均为硅基流动v3如果需要支持多人同时私聊或频繁调用建议把其中的一个或两个换成官方v3或其它模型以免撞到429 #私聊PFC需要开启PFC功能默认三个模型均为硅基流动v3如果需要支持多人同时私聊或频繁调用建议把其中的一个或两个换成官方v3或其它模型以免撞到429
#PFC决策模型 #PFC决策模型
[model.llm_PFC_action_planner] [model.pfc_action_planner]
name = "Pro/deepseek-ai/DeepSeek-V3" name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW" provider = "SILICONFLOW"
temp = 0.3 temp = 0.3
@@ -273,7 +266,7 @@ pri_in = 2
pri_out = 8 pri_out = 8
#PFC聊天模型 #PFC聊天模型
[model.llm_PFC_chat] [model.pfc_chat]
name = "Pro/deepseek-ai/DeepSeek-V3" name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW" provider = "SILICONFLOW"
temp = 0.3 temp = 0.3
@@ -281,7 +274,7 @@ pri_in = 2
pri_out = 8 pri_out = 8
#PFC检查模型 #PFC检查模型
[model.llm_PFC_reply_checker] [model.pfc_reply_checker]
name = "Pro/deepseek-ai/DeepSeek-V3" name = "Pro/deepseek-ai/DeepSeek-V3"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 2 pri_in = 2
@@ -294,7 +287,7 @@ pri_out = 8
#以下模型暂时没有使用!! #以下模型暂时没有使用!!
#以下模型暂时没有使用!! #以下模型暂时没有使用!!
[model.llm_tool_use] #工具调用模型需要使用支持工具调用的模型建议使用qwen2.5 32b [model.tool_use] #工具调用模型需要使用支持工具调用的模型建议使用qwen2.5 32b
name = "Qwen/Qwen2.5-32B-Instruct" name = "Qwen/Qwen2.5-32B-Instruct"
provider = "SILICONFLOW" provider = "SILICONFLOW"
pri_in = 1.26 pri_in = 1.26

7
tests/test_config.py Normal file
View File

@@ -0,0 +1,7 @@
from src.config.config import global_config
class TestConfig:
def test_load(self):
config = global_config
print(config)