diff --git a/README.md b/README.md index f07e7d57f..17a8da37b 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,18 @@ # 麦麦!MaiCore-MaiMBot (编辑中)
-
- - ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) - ![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) - ![Status](https://img.shields.io/badge/状态-开发中-yellow) - ![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者) - ![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数) - ![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数) - ![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot) +
+![Python Version](https://img.shields.io/badge/Python-3.10+-blue) +![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) +![Status](https://img.shields.io/badge/状态-开发中-yellow) +![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者) +![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数) +![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数) +![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot)
-

+

Logo @@ -21,8 +21,8 @@ 画师:略nd -

MaiBot(麦麦)

-

+

MaiBot(麦麦)

+

一款专注于 群组聊天 的赛博网友
探索本项目的文档 » diff --git a/src/api/reload_config.py b/src/api/reload_config.py index a5f36e3db..1772800b6 100644 --- a/src/api/reload_config.py +++ b/src/api/reload_config.py @@ -1,6 +1,6 @@ from fastapi import HTTPException from rich.traceback import install -from src.config.config import BotConfig +from src.config.config import Config from src.common.logger_manager import get_logger import os @@ -14,8 +14,8 @@ async def reload_config(): from src.config import config as config_module logger.debug("正在重载配置文件...") - bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") - config_module.global_config = BotConfig.load_config(config_path=bot_config_path) + bot_config_path = os.path.join(Config.get_config_dir(), "bot_config.toml") + config_module.global_config = Config.load_config(config_path=bot_config_path) logger.debug("配置文件重载成功") return {"status": "reloaded"} except FileNotFoundError as e: diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 5d800866f..fda0a63fd 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -5,12 +5,15 @@ import os import random import time import traceback -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Any from PIL import Image import io import re -from ...common.database import db +# from gradio_client import file + +from ...common.database.database_model import Emoji +from ...common.database.database import db as peewee_db from ...config.config import global_config from ..utils.utils_image import image_path_to_base64, image_manager from ..models.utils_model import LLMRequest @@ -51,7 +54,7 @@ class MaiEmoji: self.is_deleted = False # 标记是否已被删除 self.format = "" - async def initialize_hash_format(self): + async def initialize_hash_format(self) -> Optional[bool]: """从文件创建表情包实例, 计算哈希值和格式""" try: # 使用 full_path 检查文件是否存在 @@ -104,7 +107,7 @@ class MaiEmoji: self.is_deleted = True return None - async def register_to_db(self): + async def register_to_db(self) -> bool: """ 注册表情包 将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下 @@ -143,22 +146,22 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - emoji_record = { - "filename": self.filename, - "path": self.path, # 存储目录路径 - "full_path": self.full_path, # 存储完整文件路径 - "embedding": self.embedding, - "description": self.description, - "emotion": self.emotion, - "hash": self.hash, - "format": self.format, - "timestamp": int(self.register_time), - "usage_count": self.usage_count, - "last_used_time": self.last_used_time, - } + emotion_str = ",".join(self.emotion) if self.emotion else "" - # 使用upsert确保记录存在或被更新 - db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) + Emoji.create( + hash=self.hash, + full_path=self.full_path, + format=self.format, + description=self.description, + emotion=emotion_str, # Store as comma-separated string + query_count=0, # Default value + is_registered=True, + is_banned=False, # Default value + record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time + register_time=self.register_time, + usage_count=self.usage_count, + last_used_time=self.last_used_time, + ) logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") @@ -166,14 +169,6 @@ class MaiEmoji: except Exception as db_error: logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") - # 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误 - # 可以考虑在这里尝试删除已移动的文件,避免残留 - try: - if os.path.exists(self.full_path): # full_path 此时是目标路径 - os.remove(self.full_path) - logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}") - except Exception as remove_error: - logger.error(f"[错误] 回滚删除文件失败: {remove_error}") return False except Exception as e: @@ -181,7 +176,7 @@ class MaiEmoji: logger.error(traceback.format_exc()) return False - async def delete(self): + async def delete(self) -> bool: """删除表情包 删除表情包的文件和数据库记录 @@ -201,10 +196,14 @@ class MaiEmoji: # 文件删除失败,但仍然尝试删除数据库记录 # 2. 删除数据库记录 - result = db.emoji.delete_one({"hash": self.hash}) - deleted_in_db = result.deleted_count > 0 + try: + will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) + result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. + except Emoji.DoesNotExist: + logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") + result = 0 # Indicate no DB record was deleted - if deleted_in_db: + if result > 0: logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") # 3. 标记对象已被删除 self.is_deleted = True @@ -224,7 +223,7 @@ class MaiEmoji: return False -def _emoji_objects_to_readable_list(emoji_objects): +def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]: """将表情包对象列表转换为可读的字符串列表 参数: @@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects): return emoji_info_list -def _to_emoji_objects(data): +def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: emoji_objects = [] load_errors = 0 + # data is now an iterable of Peewee Emoji model instances emoji_data_list = list(data) - for emoji_data in emoji_data_list: - full_path = emoji_data.get("full_path") + for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance + full_path = emoji_data.full_path if not full_path: - logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") + logger.warning( + f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}" + ) load_errors += 1 - continue # 跳过缺少 full_path 的记录 + continue try: - # 使用 full_path 初始化 MaiEmoji 对象 emoji = MaiEmoji(full_path=full_path) - # 设置从数据库加载的属性 - emoji.hash = emoji_data.get("hash", "") - # 如果 hash 为空,也跳过?取决于业务逻辑 + emoji.hash = emoji_data.emoji_hash if not emoji.hash: logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") load_errors += 1 continue - emoji.description = emoji_data.get("description", "") - emoji.emotion = emoji_data.get("emotion", []) - emoji.usage_count = emoji_data.get("usage_count", 0) - # 优先使用 last_used_time,否则用 timestamp,最后用当前时间 - last_used = emoji_data.get("last_used_time") - timestamp = emoji_data.get("timestamp") - emoji.last_used_time = ( - last_used if last_used is not None else (timestamp if timestamp is not None else time.time()) - ) - emoji.register_time = timestamp if timestamp is not None else time.time() - emoji.format = emoji_data.get("format", "") # 加载格式 + emoji.description = emoji_data.description + # Deserialize emotion string from DB to list + emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else [] + emoji.usage_count = emoji_data.usage_count - # 不需要再手动设置 path 和 filename,__init__ 会自动处理 + db_last_used_time = emoji_data.last_used_time + db_register_time = emoji_data.register_time + + # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time + emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time + # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) + emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time + + emoji.format = emoji_data.format emoji_objects.append(emoji) - except ValueError as ve: # 捕获 __init__ 可能的错误 + except ValueError as ve: logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") load_errors += 1 except Exception as e: @@ -292,13 +292,13 @@ def _to_emoji_objects(data): return emoji_objects, load_errors -def _ensure_emoji_dir(): +def _ensure_emoji_dir() -> None: """确保表情存储目录存在""" os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) -async def clear_temp_emoji(): +async def clear_temp_emoji() -> None: """清理临时表情包 清理/data/emoji和/data/image目录下的所有文件 当目录中文件数超过100时,会全部删除 @@ -320,7 +320,7 @@ async def clear_temp_emoji(): logger.success("[清理] 完成") -async def clean_unused_emojis(emoji_dir, emoji_objects): +async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None: """清理指定目录中未被 emoji_objects 追踪的表情包文件""" if not os.path.exists(emoji_dir): logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") @@ -360,74 +360,52 @@ async def clean_unused_emojis(emoji_dir, emoji_objects): class EmojiManager: _instance = None - def __new__(cls): + def __new__(cls) -> "EmojiManager": if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - def __init__(self): + def __init__(self) -> None: self._initialized = None self._scan_task = None - self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") + + self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") self.llm_emotion_judge = LLMRequest( - model=global_config.llm_normal, max_tokens=600, request_type="emoji" + model=global_config.model.normal, max_tokens=600, request_type="emoji" ) # 更高的温度,更少的token(后续可以根据情绪来调整温度) self.emoji_num = 0 - self.emoji_num_max = global_config.max_emoji_num - self.emoji_num_max_reach_deletion = global_config.max_reach_deletion + self.emoji_num_max = global_config.emoji.max_reg_num + self.emoji_num_max_reach_deletion = global_config.emoji.do_replace self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型 logger.info("启动表情包管理器") - def initialize(self): + def initialize(self) -> None: """初始化数据库连接和表情目录""" - if not self._initialized: - try: - self._ensure_emoji_collection() - _ensure_emoji_dir() - self._initialized = True - # 更新表情包数量 - # 启动时执行一次完整性检查 - # await self.check_emoji_file_integrity() - except Exception as e: - logger.exception(f"初始化表情管理器失败: {e}") + peewee_db.connect(reuse_if_open=True) + if peewee_db.is_closed(): + raise RuntimeError("数据库连接失败") + _ensure_emoji_dir() + Emoji.create_table(safe=True) # Ensures table exists - def _ensure_db(self): + def _ensure_db(self) -> None: """确保数据库已初始化""" if not self._initialized: self.initialize() if not self._initialized: raise RuntimeError("EmojiManager not initialized") - @staticmethod - def _ensure_emoji_collection(): - """确保emoji集合存在并创建索引 - - 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 - - 索引的作用是加快数据库查询速度: - - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 - - tags字段的普通索引: 加快按标签搜索表情包的速度 - - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 - - 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 - """ - if "emoji" not in db.list_collection_names(): - db.create_collection("emoji") - db.emoji.create_index([("embedding", "2dsphere")]) - db.emoji.create_index([("filename", 1)], unique=True) - - def record_usage(self, emoji_hash: str): + def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) - for emoji in self.emoji_objects: - if emoji.hash == emoji_hash: - emoji.usage_count += 1 - break - + emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash) + emoji_update.usage_count += 1 + emoji_update.last_used_time = time.time() # Update last used time + emoji_update.save() # Persist changes to DB + except Emoji.DoesNotExist: + logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -447,7 +425,6 @@ class EmojiManager: if not all_emojis: logger.warning("内存中没有任何表情包对象") - # 可以考虑再查一次数据库?或者依赖定期任务更新 return None # 计算每个表情包与输入文本的最大情感相似度 @@ -463,40 +440,38 @@ class EmojiManager: # 计算与每个emotion标签的相似度,取最大值 max_similarity = 0 - best_matching_emotion = "" # 记录最匹配的 emotion 喵~ + best_matching_emotion = "" for emotion in emotions: # 使用编辑距离计算相似度 distance = self._levenshtein_distance(text_emotion, emotion) max_len = max(len(text_emotion), len(emotion)) similarity = 1 - (distance / max_len if max_len > 0 else 0) - if similarity > max_similarity: # 如果找到更相似的喵~ + if similarity > max_similarity: max_similarity = similarity - best_matching_emotion = emotion # 就记下这个 emotion 喵~ + best_matching_emotion = emotion - if best_matching_emotion: # 确保有匹配的情感才添加喵~ - emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~ + if best_matching_emotion: + emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 按相似度降序排序 emoji_similarities.sort(key=lambda x: x[1], reverse=True) # 获取前10个最相似的表情包 - top_emojis = ( - emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities - ) # 改个名字,更清晰喵~ + top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities if not top_emojis: logger.warning("未找到匹配的表情包") return None # 从前几个中随机选择一个 - selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~ + selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 更新使用次数 - self.record_usage(selected_emoji.hash) + self.record_usage(selected_emoji.emoji_hash) _time_end = time.time() - logger.info( # 使用匹配到的 emotion 记录日志喵~ + logger.info( f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}" ) # 返回完整文件路径和描述 @@ -534,7 +509,7 @@ class EmojiManager: return previous_row[-1] - async def check_emoji_file_integrity(self): + async def check_emoji_file_integrity(self) -> None: """检查表情包文件完整性 遍历self.emoji_objects中的所有对象,检查文件是否存在 如果文件已被删除,则执行对象的删除方法并从列表中移除 @@ -599,7 +574,7 @@ class EmojiManager: logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") logger.error(traceback.format_exc()) - async def start_periodic_check_register(self): + async def start_periodic_check_register(self) -> None: """定期检查表情包完整性和数量""" await self.get_all_emoji_from_db() while True: @@ -613,18 +588,18 @@ class EmojiManager: logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}") os.makedirs(EMOJI_DIR, exist_ok=True) logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}") - await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) + await asyncio.sleep(global_config.emoji.check_interval * 60) continue # 检查目录是否为空 files = os.listdir(EMOJI_DIR) if not files: logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}") - await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) + await asyncio.sleep(global_config.emoji.check_interval * 60) continue # 检查是否需要处理表情包(数量超过最大值或不足) - if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or ( + if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or ( self.emoji_num < self.emoji_num_max ): try: @@ -651,15 +626,16 @@ class EmojiManager: except Exception as e: logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") - await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) + await asyncio.sleep(global_config.emoji.check_interval * 60) - async def get_all_emoji_from_db(self): + async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: self._ensure_db() - logger.info("[数据库] 开始加载所有表情包记录...") + logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...") - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) + emoji_peewee_instances = Emoji.select() + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) # 更新内存中的列表和数量 self.emoji_objects = emoji_objects @@ -674,7 +650,7 @@ class EmojiManager: self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash=None): + async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -686,15 +662,16 @@ class EmojiManager: try: self._ensure_db() - query = {} if emoji_hash: - query = {"hash": emoji_hash} + query = Emoji.select().where(Emoji.emoji_hash == emoji_hash) else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) + query = Emoji.select() - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) + emoji_peewee_instances = query + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) if load_errors > 0: logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") @@ -705,7 +682,7 @@ class EmojiManager: logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") return [] - async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]: + async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: """从内存中的 emoji_objects 列表获取表情包 参数: @@ -758,7 +735,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def replace_a_emoji(self, new_emoji: MaiEmoji): + async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: """替换一个表情包 Args: @@ -788,7 +765,7 @@ class EmojiManager: # 构建提示词 prompt = ( - f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})," + f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})," f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n" f"新表情包信息:\n" f"描述: {new_emoji.description}\n\n" @@ -819,7 +796,7 @@ class EmojiManager: # 删除选定的表情包 logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") - delete_success = await self.delete_emoji(emoji_to_delete.hash) + delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash) if delete_success: # 修复:等待异步注册完成 @@ -847,7 +824,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]: + async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: """获取表情包描述和情感列表 Args: @@ -871,10 +848,10 @@ class EmojiManager: description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) # 审核表情包 - if global_config.EMOJI_CHECK: + if global_config.emoji.content_filtration: prompt = f''' 这是一个表情包,请对这个表情包进行审核,标准如下: - 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求 + 1. 必须符合"{global_config.emoji.filtration_prompt}"的要求 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 3. 不能是任何形式的截图,聊天记录或视频截图 4. 不要出现5个以上文字 diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 37b634b37..ccbc1ca56 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -76,9 +76,10 @@ def init_prompt(): class DefaultExpressor: def __init__(self, chat_id: str): self.log_prefix = "expressor" + # TODO: API-Adapter修改标记 self.express_model = LLMRequest( - model=global_config.llm_normal, - temperature=global_config.llm_normal["temp"], + model=global_config.model.normal, + temperature=global_config.model.normal["temp"], max_tokens=256, request_type="response_heartflow", ) @@ -102,8 +103,8 @@ class DefaultExpressor: messageinfo = anchor_message.message_info thinking_time_point = parse_thinking_id_to_timestamp(thinking_id) bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=messageinfo.platform, ) # logger.debug(f"创建思考消息:{anchor_message}") @@ -192,7 +193,7 @@ class DefaultExpressor: try: # 1. 获取情绪影响因子并调整模型温度 arousal_multiplier = mood_manager.get_arousal_multiplier() - current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier + current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier self.express_model.params["temperature"] = current_temp # 动态调整温度 # 2. 获取信息捕捉器 @@ -231,6 +232,7 @@ class DefaultExpressor: try: with Timer("LLM生成", {}): # 内部计时器,可选保留 + # TODO: API-Adapter修改标记 # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") content, reasoning_content, model_name = await self.express_model.generate_response(prompt) @@ -482,8 +484,8 @@ class DefaultExpressor: """构建单个发送消息""" bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=self.chat_stream.platform, ) diff --git a/src/chat/focus_chat/expressors/exprssion_learner.py b/src/chat/focus_chat/expressors/exprssion_learner.py index 942162bc8..7766fde56 100644 --- a/src/chat/focus_chat/expressors/exprssion_learner.py +++ b/src/chat/focus_chat/expressors/exprssion_learner.py @@ -77,8 +77,9 @@ def init_prompt() -> None: class ExpressionLearner: def __init__(self) -> None: + # TODO: API-Adapter修改标记 self.express_learn_model: LLMRequest = LLMRequest( - model=global_config.llm_normal, + model=global_config.model.normal, temperature=0.1, max_tokens=256, request_type="response_heartflow", @@ -289,7 +290,7 @@ class ExpressionLearner: # 构建prompt prompt = await global_prompt_manager.format_prompt( "personality_expression_prompt", - personality=global_config.expression_style, + personality=global_config.personality.expression_style, ) # logger.info(f"个性表达方式提取prompt: {prompt}") diff --git a/src/chat/focus_chat/heartflow_processor.py b/src/chat/focus_chat/heartflow_processor.py index bbfa4ce46..a4cf360a5 100644 --- a/src/chat/focus_chat/heartflow_processor.py +++ b/src/chat/focus_chat/heartflow_processor.py @@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool: Returns: bool: 是否包含过滤词 """ - for word in global_config.ban_words: + for word in global_config.chat.ban_words: if word in text: chat_name = chat.group_info.group_name if chat.group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") @@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool: Returns: bool: 是否匹配过滤正则 """ - for pattern in global_config.ban_msgs_regex: + for pattern in global_config.chat.ban_msgs_regex: if pattern.search(text): chat_name = chat.group_info.group_name if chat.group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 74bac0a1f..af526eb88 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -13,6 +13,9 @@ from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager import random +import json +import math +from src.common.database.database_model import Knowledges logger = get_logger("prompt") @@ -45,7 +48,7 @@ def init_prompt(): 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1}, 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 +请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。 {moderation_prompt} 不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""", "reasoning_prompt_main", @@ -110,7 +113,7 @@ class PromptBuilder: who_chat_in_group = get_recent_group_speaker( chat_stream.stream_id, (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None, - limit=global_config.observation_context_size, + limit=global_config.chat.observation_context_size, ) elif chat_stream.user_info: who_chat_in_group.append( @@ -158,7 +161,7 @@ class PromptBuilder: message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), - limit=global_config.observation_context_size, + limit=global_config.chat.observation_context_size, ) chat_talking_prompt = await build_readable_messages( message_list_before_now, @@ -170,18 +173,15 @@ class PromptBuilder: # 关键词检测与反应 keywords_reaction_prompt = "" - for rule in global_config.keywords_reaction_rules: - if rule.get("enable", False): - if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): - logger.info( - f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" - ) - keywords_reaction_prompt += rule.get("reaction", "") + "," + for rule in global_config.keyword_reaction.rules: + if rule.enable: + if any(keyword in message_txt for keyword in rule.keywords): + logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}") + keywords_reaction_prompt += f"{rule.reaction}," else: - for pattern in rule.get("regex", []): - result = pattern.search(message_txt) - if result: - reaction = rule.get("reaction", "") + for pattern in rule.regex: + if result := pattern.search(message_txt): + reaction = rule.reaction for name, content in result.groupdict().items(): reaction = reaction.replace(f"[{name}]", content) logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}") @@ -227,8 +227,8 @@ class PromptBuilder: chat_target_2=chat_target_2, chat_talking_prompt=chat_talking_prompt, message_txt=message_txt, - bot_name=global_config.BOT_NICKNAME, - bot_other_names="/".join(global_config.BOT_ALIAS_NAMES), + bot_name=global_config.bot.nickname, + bot_other_names="/".join(global_config.bot.alias_names), prompt_personality=prompt_personality, mood_prompt=mood_prompt, reply_style1=reply_style1_chosen, @@ -249,8 +249,8 @@ class PromptBuilder: prompt_info=prompt_info, chat_talking_prompt=chat_talking_prompt, message_txt=message_txt, - bot_name=global_config.BOT_NICKNAME, - bot_other_names="/".join(global_config.BOT_ALIAS_NAMES), + bot_name=global_config.bot.nickname, + bot_other_names="/".join(global_config.bot.alias_names), prompt_personality=prompt_personality, mood_prompt=mood_prompt, reply_style1=reply_style1_chosen, @@ -269,30 +269,6 @@ class PromptBuilder: logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") # 1. 先从LLM获取主题,类似于记忆系统的做法 topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") # 如果无法提取到主题,直接使用整个消息 if not topics: @@ -402,8 +378,6 @@ class PromptBuilder: for _i, result in enumerate(results, 1): _similarity = result["similarity"] content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" related_info += f"{content}\n" related_info += "\n" @@ -432,14 +406,14 @@ class PromptBuilder: return related_info else: logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索") - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") return related_info except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") try: - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug( f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" @@ -455,70 +429,70 @@ class PromptBuilder: ) -> Union[str, list]: if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + results_with_similarity = [] + try: + # Fetch all knowledge entries + # This might be inefficient for very large databases. + # Consider strategies like FAISS or other vector search libraries if performance becomes an issue. + all_knowledges = Knowledges.select() - if not results: + if not all_knowledges: + return [] if return_raw else "" + + query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) + if query_embedding_magnitude == 0: # Avoid division by zero + return "" if not return_raw else [] + + for knowledge_item in all_knowledges: + try: + db_embedding_str = knowledge_item.embedding + db_embedding = json.loads(db_embedding_str) + + if len(db_embedding) != len(query_embedding): + logger.warning( + f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping." + ) + continue + + # Calculate Cosine Similarity + dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding)) + db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding)) + + if db_embedding_magnitude == 0: # Avoid division by zero + similarity = 0.0 + else: + similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude) + + if similarity >= threshold: + results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity}) + except json.JSONDecodeError: + logger.error( + f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}" + ) + except Exception as e: + logger.error(f"Error processing knowledge item: {e}") + + # Sort by similarity in descending order + results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True) + + # Limit results + limited_results = results_with_similarity[:limit] + + logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}") + + if not limited_results: + return "" if not return_raw else [] + + if return_raw: + return limited_results + else: + return "\n".join(str(result["content"]) for result in limited_results) + + except Exception as e: + logger.error(f"Error querying Knowledges with Peewee: {e}") return "" if not return_raw else [] - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - init_prompt() prompt_builder = PromptBuilder() diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index bb565ee7e..8d1eb9793 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor): def __init__(self): """初始化观察处理器""" super().__init__() + # TODO: API-Adapter修改标记 self.llm_summary = LLMRequest( - model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation" + model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) async def process_info( @@ -110,12 +111,12 @@ class ChattingInfoProcessor(BaseProcessor): "created_at": datetime.now().timestamp(), } - obs.mid_memorys.append(mid_memory) - if len(obs.mid_memorys) > obs.max_mid_memory_len: - obs.mid_memorys.pop(0) # 移除最旧的 + obs.mid_memories.append(mid_memory) + if len(obs.mid_memories) > obs.max_mid_memory_len: + obs.mid_memories.pop(0) # 移除最旧的 mid_memory_str = "之前聊天的内容概述是:\n" - for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分 + for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分 time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60) mid_memory_str += ( f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}):{mid_memory_item['theme']}\n" diff --git a/src/chat/focus_chat/info_processors/mind_processor.py b/src/chat/focus_chat/info_processors/mind_processor.py index 09228174c..afd7921d4 100644 --- a/src/chat/focus_chat/info_processors/mind_processor.py +++ b/src/chat/focus_chat/info_processors/mind_processor.py @@ -71,8 +71,8 @@ class MindProcessor(BaseProcessor): self.subheartflow_id = subheartflow_id self.llm_model = LLMRequest( - model=global_config.llm_sub_heartflow, - temperature=global_config.llm_sub_heartflow["temp"], + model=global_config.model.sub_heartflow, + temperature=global_config.model.sub_heartflow["temp"], max_tokens=800, request_type="sub_heart_flow", ) diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index 92c1b607a..de9a9a216 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -49,7 +49,7 @@ class ToolProcessor(BaseProcessor): self.subheartflow_id = subheartflow_id self.log_prefix = f"[{subheartflow_id}:ToolExecutor] " self.llm_model = LLMRequest( - model=global_config.llm_tool_use, + model=global_config.model.tool_use, max_tokens=500, request_type="tool_execution", ) diff --git a/src/chat/focus_chat/memory_activator.py b/src/chat/focus_chat/memory_activator.py index dae310c06..4fcd37302 100644 --- a/src/chat/focus_chat/memory_activator.py +++ b/src/chat/focus_chat/memory_activator.py @@ -34,8 +34,9 @@ def init_prompt(): class MemoryActivator: def __init__(self): + # TODO: API-Adapter修改标记 self.summary_model = LLMRequest( - model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation" + model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation" ) self.running_memory = [] diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index ad876bcf0..748c8331e 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -35,8 +35,9 @@ class Heartflow: self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state) # LLM模型配置 + # TODO: API-Adapter修改标记 self.llm_model = LLMRequest( - model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" + model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" ) # 外部依赖模块 diff --git a/src/chat/heart_flow/interest_chatting.py b/src/chat/heart_flow/interest_chatting.py index 45f7fe952..bce372b5c 100644 --- a/src/chat/heart_flow/interest_chatting.py +++ b/src/chat/heart_flow/interest_chatting.py @@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1 class InterestChatting: def __init__( self, - decay_rate=global_config.default_decay_rate_per_second, + decay_rate=global_config.focus_chat.default_decay_rate_per_second, max_interest=MAX_INTEREST, - trigger_threshold=global_config.reply_trigger_threshold, + trigger_threshold=global_config.focus_chat.reply_trigger_threshold, max_probability=MAX_REPLY_PROBABILITY, ): # 基础属性初始化 diff --git a/src/chat/heart_flow/mai_state_manager.py b/src/chat/heart_flow/mai_state_manager.py index 7dea910e9..017656ad2 100644 --- a/src/chat/heart_flow/mai_state_manager.py +++ b/src/chat/heart_flow/mai_state_manager.py @@ -18,19 +18,14 @@ enable_unlimited_hfc_chat = True # 调试用:无限专注聊天 prevent_offline_state = True # 目前默认不启用OFFLINE状态 -# 不同状态下普通聊天的最大消息数 -base_normal_chat_num = global_config.base_normal_chat_num -base_focused_chat_num = global_config.base_focused_chat_num - - -MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2) -MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num -MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1 +MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2) +MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num +MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1 # 不同状态下专注聊天的最大消息数 -MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2) -MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num -MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2 +MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2) +MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num +MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2 # -- 状态定义 -- diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 6bb72bca0..9ea18b471 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -55,19 +55,20 @@ class ChattingObservation(Observation): self.talking_message = [] self.talking_message_str = "" self.talking_message_str_truncate = "" - self.name = global_config.BOT_NICKNAME - self.nick_name = global_config.BOT_ALIAS_NAMES - self.max_now_obs_len = global_config.observation_context_size - self.overlap_len = global_config.compressed_length - self.mid_memorys = [] - self.max_mid_memory_len = global_config.compress_length_limit + self.name = global_config.bot.nickname + self.nick_name = global_config.bot.alias_names + self.max_now_obs_len = global_config.chat.observation_context_size + self.overlap_len = global_config.focus_chat.compressed_length + self.mid_memories = [] + self.max_mid_memory_len = global_config.focus_chat.compress_length_limit self.mid_memory_info = "" self.person_list = [] self.oldest_messages = [] self.oldest_messages_str = "" self.compressor_prompt = "" + # TODO: API-Adapter修改标记 self.llm_summary = LLMRequest( - model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation" + model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) async def initialize(self): @@ -85,7 +86,7 @@ class ChattingObservation(Observation): for id in ids: print(f"id:{id}") try: - for mid_memory in self.mid_memorys: + for mid_memory in self.mid_memories: if mid_memory["id"] == id: mid_memory_by_id = mid_memory msg_str = "" @@ -103,7 +104,7 @@ class ChattingObservation(Observation): else: mid_memory_str = "之前的聊天内容:\n" - for mid_memory in self.mid_memorys: + for mid_memory in self.mid_memories: mid_memory_str += f"{mid_memory['theme']}\n" return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py index a4bff8338..bf4ddf7e1 100644 --- a/src/chat/heart_flow/subheartflow_manager.py +++ b/src/chat/heart_flow/subheartflow_manager.py @@ -76,8 +76,9 @@ class SubHeartflowManager: # 为 LLM 状态评估创建一个 LLMRequest 实例 # 使用与 Heartflow 相同的模型和参数 + # TODO: API-Adapter修改标记 self.llm_state_evaluator = LLMRequest( - model=global_config.llm_heartflow, # 与 Heartflow 一致 + model=global_config.model.heartflow, # 与 Heartflow 一致 temperature=0.6, # 与 Heartflow 一致 max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) request_type="subheartflow_state_eval", # 保留特定的请求类型 @@ -278,7 +279,7 @@ class SubHeartflowManager: focused_limit = current_state.get_focused_chat_max_num() # --- 新增:检查是否允许进入 FOCUS 模式 --- # - if not global_config.allow_focus_mode: + if not global_config.chat.allow_focus_mode: if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏 logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)") return # 如果不允许,直接返回 @@ -766,7 +767,7 @@ class SubHeartflowManager: focused_limit = current_mai_state.get_focused_chat_max_num() # --- 检查是否允许 FOCUS 模式 --- # - if not global_config.allow_focus_mode: + if not global_config.chat.allow_focus_mode: # Log less frequently to avoid spam # if int(time.time()) % 60 == 0: # logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态") diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 70eb679c9..2de769205 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -10,7 +10,7 @@ import jieba import networkx as nx import numpy as np from collections import Counter -from ...common.database import db +from ...common.database.database import memory_db as db from ...chat.models.utils_model import LLMRequest from src.common.logger_manager import get_logger from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 @@ -19,9 +19,10 @@ from ..utils.chat_message_builder import ( build_readable_messages, ) # 导入 build_readable_messages from ..utils.utils import translate_timestamp_to_human_readable -from .memory_config import MemoryConfig from rich.traceback import install +from ...config.config import global_config + install(extra_lines=3) @@ -195,18 +196,16 @@ class Hippocampus: self.llm_summary = None self.entorhinal_cortex = None self.parahippocampal_gyrus = None - self.config = None - def initialize(self, global_config): - # 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法 - self.config = MemoryConfig.from_global_config(global_config) + def initialize(self): # 初始化子组件 self.entorhinal_cortex = EntorhinalCortex(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self) # 从数据库加载记忆图 self.entorhinal_cortex.sync_memory_from_db() - self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory") - self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory") + # TODO: API-Adapter修改标记 + self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory") + self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory") def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" @@ -792,7 +791,6 @@ class EntorhinalCortex: def __init__(self, hippocampus: Hippocampus): self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - self.config = hippocampus.config def get_memory_sample(self): """从数据库获取记忆样本""" @@ -801,13 +799,13 @@ class EntorhinalCortex: # 创建双峰分布的记忆调度器 sample_scheduler = MemoryBuildScheduler( - n_hours1=self.config.memory_build_distribution[0], - std_hours1=self.config.memory_build_distribution[1], - weight1=self.config.memory_build_distribution[2], - n_hours2=self.config.memory_build_distribution[3], - std_hours2=self.config.memory_build_distribution[4], - weight2=self.config.memory_build_distribution[5], - total_samples=self.config.build_memory_sample_num, + n_hours1=global_config.memory.memory_build_distribution[0], + std_hours1=global_config.memory.memory_build_distribution[1], + weight1=global_config.memory.memory_build_distribution[2], + n_hours2=global_config.memory.memory_build_distribution[3], + std_hours2=global_config.memory.memory_build_distribution[4], + weight2=global_config.memory.memory_build_distribution[5], + total_samples=global_config.memory.memory_build_sample_num, ) timestamps = sample_scheduler.get_timestamp_array() @@ -818,7 +816,7 @@ class EntorhinalCortex: for timestamp in timestamps: # 调用修改后的 random_get_msg_snippet messages = self.random_get_msg_snippet( - timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg + timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg ) if messages: time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 @@ -1099,7 +1097,6 @@ class ParahippocampalGyrus: def __init__(self, hippocampus: Hippocampus): self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - self.config = hippocampus.config async def memory_compress(self, messages: list, compress_rate=0.1): """压缩和总结消息内容,生成记忆主题和摘要。 @@ -1159,7 +1156,7 @@ class ParahippocampalGyrus: # 3. 过滤掉包含禁用关键词的topic filtered_topics = [ - topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words) + topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words) ] logger.debug(f"过滤后话题: {filtered_topics}") @@ -1222,7 +1219,7 @@ class ParahippocampalGyrus: bar = "█" * filled_length + "-" * (bar_length - filled_length) logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") - compress_rate = self.config.memory_compress_rate + compress_rate = global_config.memory.memory_compress_rate try: compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) except Exception as e: @@ -1322,7 +1319,7 @@ class ParahippocampalGyrus: edge_data = self.memory_graph.G[source][target] last_modified = edge_data.get("last_modified") - if current_time - last_modified > 3600 * self.config.memory_forget_time: + if current_time - last_modified > 3600 * global_config.memory.memory_forget_time: current_strength = edge_data.get("strength", 1) new_strength = current_strength - 1 @@ -1430,8 +1427,8 @@ class ParahippocampalGyrus: async def operation_consolidate_memory(self): """整合记忆:合并节点内相似的记忆项""" start_time = time.time() - percentage = self.config.consolidate_memory_percentage - similarity_threshold = self.config.consolidation_similarity_threshold + percentage = global_config.memory.consolidate_memory_percentage + similarity_threshold = global_config.memory.consolidation_similarity_threshold logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}") # 获取所有至少有2条记忆项的节点 @@ -1544,7 +1541,6 @@ class ParahippocampalGyrus: class HippocampusManager: _instance = None _hippocampus = None - _global_config = None _initialized = False @classmethod @@ -1559,19 +1555,15 @@ class HippocampusManager: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") return cls._hippocampus - def initialize(self, global_config): + def initialize(self): """初始化海马体实例""" if self._initialized: return self._hippocampus - self._global_config = global_config self._hippocampus = Hippocampus() - self._hippocampus.initialize(global_config) + self._hippocampus.initialize() self._initialized = True - # 输出记忆系统参数信息 - config = self._hippocampus.config - # 输出记忆图统计信息 memory_graph = self._hippocampus.memory_graph.G node_count = len(memory_graph.nodes()) @@ -1579,9 +1571,9 @@ class HippocampusManager: logger.success(f"""-------------------------------- 记忆系统参数配置: - 构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate} - 记忆构建分布: {config.memory_build_distribution} - 遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后 + 构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate} + 记忆构建分布: {global_config.memory.memory_build_distribution} + 遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后 记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count} --------------------------------""") # noqa: E501 diff --git a/src/chat/memory_system/debug_memory.py b/src/chat/memory_system/debug_memory.py index baf745409..b09e703a1 100644 --- a/src/chat/memory_system/debug_memory.py +++ b/src/chat/memory_system/debug_memory.py @@ -7,7 +7,6 @@ import os # 添加项目根目录到系统路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) from src.chat.memory_system.Hippocampus import HippocampusManager -from src.config.config import global_config from rich.traceback import install install(extra_lines=3) @@ -19,7 +18,7 @@ async def test_memory_system(): # 初始化记忆系统 print("开始初始化记忆系统...") hippocampus_manager = HippocampusManager.get_instance() - hippocampus_manager.initialize(global_config=global_config) + hippocampus_manager.initialize() print("记忆系统初始化完成") # 测试记忆构建 diff --git a/src/chat/memory_system/manually_alter_memory.py b/src/chat/memory_system/manually_alter_memory.py index ce5abbba7..9bbf59f5b 100644 --- a/src/chat/memory_system/manually_alter_memory.py +++ b/src/chat/memory_system/manually_alter_memory.py @@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) from src.common.logger import get_module_logger # noqa E402 -from src.common.database import db # noqa E402 +from common.database.database import db # noqa E402 logger = get_module_logger("mem_alter") console = Console() diff --git a/src/chat/memory_system/memory_config.py b/src/chat/memory_system/memory_config.py deleted file mode 100644 index b82e54ec1..000000000 --- a/src/chat/memory_system/memory_config.py +++ /dev/null @@ -1,48 +0,0 @@ -from dataclasses import dataclass -from typing import List - - -@dataclass -class MemoryConfig: - """记忆系统配置类""" - - # 记忆构建相关配置 - memory_build_distribution: List[float] # 记忆构建的时间分布参数 - build_memory_sample_num: int # 每次构建记忆的样本数量 - build_memory_sample_length: int # 每个样本的消息长度 - memory_compress_rate: float # 记忆压缩率 - - # 记忆遗忘相关配置 - memory_forget_time: int # 记忆遗忘时间(小时) - - # 记忆过滤相关配置 - memory_ban_words: List[str] # 记忆过滤词列表 - - # 新增:记忆整合相关配置 - consolidation_similarity_threshold: float # 相似度阈值 - consolidate_memory_percentage: float # 检查节点比例 - consolidate_memory_interval: int # 记忆整合间隔 - - llm_topic_judge: str # 话题判断模型 - llm_summary: str # 话题总结模型 - - @classmethod - def from_global_config(cls, global_config): - """从全局配置创建记忆系统配置""" - # 使用 getattr 提供默认值,防止全局配置缺少这些项 - return cls( - memory_build_distribution=getattr( - global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5) - ), # 添加默认值 - build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5), - build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30), - memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1), - memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7), - memory_ban_words=getattr(global_config, "memory_ban_words", []), - # 新增加载整合配置,并提供默认值 - consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7), - consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01), - consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000), - llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名 - llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名 - ) diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 3c9e4420c..0e35f6f6e 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -41,7 +41,7 @@ class ChatBot: chat_id = str(message.chat_stream.stream_id) private_name = str(message.message_info.user_info.user_nickname) - if global_config.enable_pfc_chatting: + if global_config.experimental.enable_pfc_chatting: await self.pfc_manager.get_or_create_conversation(chat_id, private_name) except Exception as e: @@ -78,19 +78,19 @@ class ChatBot: userinfo = message.message_info.user_info # 用户黑名单拦截 - if userinfo.user_id in global_config.ban_user_id: + if userinfo.user_id in global_config.chat_target.ban_user_id: logger.debug(f"用户{userinfo.user_id}被禁止回复") return if groupinfo is None: logger.trace("检测到私聊消息,检查") # 好友黑名单拦截 - if userinfo.user_id not in global_config.talk_allowed_private: + if userinfo.user_id not in global_config.experimental.talk_allowed_private: logger.debug(f"用户{userinfo.user_id}没有私聊权限") return # 群聊黑名单拦截 - if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups: + if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: logger.trace(f"群{groupinfo.group_id}被禁止回复") return @@ -112,7 +112,7 @@ class ChatBot: if groupinfo is None: logger.trace("检测到私聊消息") # 是否在配置信息中开启私聊模式 - if global_config.enable_friend_chat: + if global_config.experimental.enable_friend_chat: logger.trace("私聊模式已启用") # 是否进入PFC if global_config.enable_pfc_chatting: diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 53ebd5026..723d6da47 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -5,7 +5,8 @@ import copy from typing import Dict, Optional -from ...common.database import db +from ...common.database.database import db +from ...common.database.database_model import ChatStreams # 新增导入 from maim_message import GroupInfo, UserInfo from src.common.logger_manager import get_logger @@ -82,7 +83,13 @@ class ChatManager: def __init__(self): if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self._ensure_collection() + try: + db.connect(reuse_if_open=True) + # 确保 ChatStreams 表存在 + db.create_tables([ChatStreams], safe=True) + except Exception as e: + logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") + self._initialized = True # 在事件循环中启动初始化 # asyncio.create_task(self._initialize()) @@ -107,15 +114,6 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {str(e)}") - @staticmethod - def _ensure_collection(): - """确保数据库集合存在并创建索引""" - if "chat_streams" not in db.list_collection_names(): - db.create_collection("chat_streams") - # 创建索引 - db.chat_streams.create_index([("stream_id", 1)], unique=True) - db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]) - @staticmethod def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: """生成聊天流唯一ID""" @@ -151,16 +149,43 @@ class ChatManager: stream = self.streams[stream_id] # 更新用户信息和群组信息 stream.update_active_time() - stream = copy.deepcopy(stream) + stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream.user_info = user_info if group_info: stream.group_info = group_info return stream # 检查数据库中是否存在 - data = db.chat_streams.find_one({"stream_id": stream_id}) - if data: - stream = ChatStream.from_dict(data) + def _db_find_stream_sync(s_id: str): + return ChatStreams.get_or_none(ChatStreams.stream_id == s_id) + + model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) + + if model_instance: + # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息 + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + stream = ChatStream.from_dict(data_for_from_dict) # 更新用户信息和群组信息 stream.user_info = user_info if group_info: @@ -175,7 +200,7 @@ class ChatManager: group_info=group_info, ) except Exception as e: - logger.error(f"创建聊天流失败: {e}") + logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e # 保存到内存和数据库 @@ -205,15 +230,38 @@ class ChatManager: elif stream.user_info and stream.user_info.user_nickname: return f"{stream.user_info.user_nickname}的私聊" else: - # 如果没有群名或用户昵称,返回 None 或其他默认值 return None @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) - stream.saved = True + stream_data_dict = stream.to_dict() + + def _db_save_stream_sync(s_data_dict: dict): + user_info_d = s_data_dict.get("user_info") + group_info_d = s_data_dict.get("group_info") + + fields_to_save = { + "platform": s_data_dict["platform"], + "create_time": s_data_dict["create_time"], + "last_active_time": s_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + } + + ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() + + try: + await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) + stream.saved = True + except Exception as e: + logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" @@ -222,10 +270,44 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = db.chat_streams.find({}) - for data in all_streams: - stream = ChatStream.from_dict(data) - self.streams[stream.stream_id] = stream + + def _db_load_all_streams_sync(): + loaded_streams_data = [] + for model_instance in ChatStreams.select(): + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + loaded_streams_data.append(data_for_from_dict) + return loaded_streams_data + + try: + all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) + self.streams.clear() + for data in all_streams_data_list: + stream = ChatStream.from_dict(data) + stream.saved = True + self.streams[stream.stream_id] = stream + except Exception as e: + logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) # 创建全局单例 diff --git a/src/chat/message_receive/message_buffer.py b/src/chat/message_receive/message_buffer.py index f3cf63d0a..2df256ce5 100644 --- a/src/chat/message_receive/message_buffer.py +++ b/src/chat/message_receive/message_buffer.py @@ -38,7 +38,7 @@ class MessageBuffer: async def start_caching_messages(self, message: MessageRecv): """添加消息,启动缓冲""" - if not global_config.message_buffer: + if not global_config.chat.message_buffer: person_id = person_info_manager.get_person_id( message.message_info.user_info.platform, message.message_info.user_info.user_id ) @@ -107,7 +107,7 @@ class MessageBuffer: async def query_buffer_result(self, message: MessageRecv) -> bool: """查询缓冲结果,并清理""" - if not global_config.message_buffer: + if not global_config.chat.message_buffer: return True person_id_ = self.get_person_id_( message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info diff --git a/src/chat/message_receive/message_sender.py b/src/chat/message_receive/message_sender.py index 5db34fdea..cf5877989 100644 --- a/src/chat/message_receive/message_sender.py +++ b/src/chat/message_receive/message_sender.py @@ -279,7 +279,7 @@ class MessageManager: ) # 检查是否超时 - if thinking_time > global_config.thinking_timeout: + if thinking_time > global_config.normal_chat.thinking_timeout: logger.warning( f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}" ) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index cae029a11..d0041cd51 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,9 +1,10 @@ import re from typing import Union -from ...common.database import db +# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance from .message import MessageSending, MessageRecv from .chat_stream import ChatStream +from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models from src.common.logger import get_module_logger logger = get_module_logger("message_storage") @@ -29,42 +30,66 @@ class MessageStorage: else: filtered_detailed_plain_text = "" - message_data = { - "message_id": message.message_info.message_id, - "time": message.message_info.time, - "chat_id": chat_stream.stream_id, - "chat_info": chat_stream.to_dict(), - "user_info": message.message_info.user_info.to_dict(), - # 使用过滤后的文本 - "processed_plain_text": filtered_processed_plain_text, - "detailed_plain_text": filtered_detailed_plain_text, - "memorized_times": message.memorized_times, - } - db.messages.insert_one(message_data) + chat_info_dict = chat_stream.to_dict() + user_info_dict = message.message_info.user_info.to_dict() + + # message_id 现在是 TextField,直接使用字符串值 + msg_id = message.message_info.message_id + + # 安全地获取 group_info, 如果为 None 则视为空字典 + group_info_from_chat = chat_info_dict.get("group_info") or {} + # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) + user_info_from_chat = chat_info_dict.get("user_info") or {} + + Messages.create( + message_id=msg_id, + time=float(message.message_info.time), + chat_id=chat_stream.stream_id, + # Flattened chat_info + chat_info_stream_id=chat_info_dict.get("stream_id"), + chat_info_platform=chat_info_dict.get("platform"), + chat_info_user_platform=user_info_from_chat.get("platform"), + chat_info_user_id=user_info_from_chat.get("user_id"), + chat_info_user_nickname=user_info_from_chat.get("user_nickname"), + chat_info_user_cardname=user_info_from_chat.get("user_cardname"), + chat_info_group_platform=group_info_from_chat.get("platform"), + chat_info_group_id=group_info_from_chat.get("group_id"), + chat_info_group_name=group_info_from_chat.get("group_name"), + chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), + chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), + # Flattened user_info (message sender) + user_platform=user_info_dict.get("platform"), + user_id=user_info_dict.get("user_id"), + user_nickname=user_info_dict.get("user_nickname"), + user_cardname=user_info_dict.get("user_cardname"), + # Text content + processed_plain_text=filtered_processed_plain_text, + detailed_plain_text=filtered_detailed_plain_text, + memorized_times=message.memorized_times, + ) except Exception: logger.exception("存储消息失败") @staticmethod async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None: """存储撤回消息到数据库""" - if "recalled_messages" not in db.list_collection_names(): - db.create_collection("recalled_messages") - else: - try: - message_data = { - "message_id": message_id, - "time": time, - "stream_id": chat_stream.stream_id, - } - db.recalled_messages.insert_one(message_data) - except Exception: - logger.exception("存储撤回消息失败") + # Table creation is handled by initialize_database in database_model.py + try: + RecalledMessages.create( + message_id=message_id, + time=float(time), # Assuming time is a string representing a float timestamp + stream_id=chat_stream.stream_id, + ) + except Exception: + logger.exception("存储撤回消息失败") @staticmethod async def remove_recalled_message(time: str) -> None: """删除撤回消息""" try: - db.recalled_messages.delete_many({"time": {"$lt": time - 300}}) + # Assuming input 'time' is a string timestamp that can be converted to float + current_time_float = float(time) + RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() except Exception: logger.exception("删除撤回消息失败") diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index e662a8e33..f6528856d 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -12,7 +12,8 @@ import base64 from PIL import Image import io import os -from ...common.database import db +from src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from ...config.config import global_config from rich.traceback import install @@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"{image_base64[:10]}...{image_base64[-10:]}" ) - # if isinstance(content, str) and len(content) > 100: - # payload["messages"][0]["content"] = content[:100] return payload @@ -111,8 +110,8 @@ class LLMRequest: def __init__(self, model: dict, **kwargs): # 将大写的配置键转换为小写并从config中获取实际值 try: - self.api_key = os.environ[model["key"]] - self.base_url = os.environ[model["base_url"]] + self.api_key = os.environ[f"{model['provider']}_KEY"] + self.base_url = os.environ[f"{model['provider']}_BASE_URL"] except AttributeError as e: logger.error(f"原始 model dict 信息:{model}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") @@ -134,13 +133,11 @@ class LLMRequest: def _init_database(): """初始化数据库集合""" try: - # 创建llm_usage集合的索引 - db.llm_usage.create_index([("timestamp", 1)]) - db.llm_usage.create_index([("model_name", 1)]) - db.llm_usage.create_index([("user_id", 1)]) - db.llm_usage.create_index([("request_type", 1)]) + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + logger.debug("LLMUsage 表已初始化/确保存在。") except Exception as e: - logger.error(f"创建数据库索引失败: {str(e)}") + logger.error(f"创建 LLMUsage 表失败: {str(e)}") def _record_usage( self, @@ -165,19 +162,19 @@ class LLMRequest: request_type = self.request_type try: - usage_data = { - "model_name": self.model_name, - "user_id": user_id, - "request_type": request_type, - "endpoint": endpoint, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "cost": self._calculate_cost(prompt_tokens, completion_tokens), - "status": "success", - "timestamp": datetime.now(), - } - db.llm_usage.insert_one(usage_data) + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=self.model_name, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost(prompt_tokens, completion_tokens), + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) logger.trace( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " @@ -500,11 +497,11 @@ class LLMRequest: logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}") # 对全局配置进行更新 - if global_config.llm_normal.get("name") == old_model_name: - global_config.llm_normal["name"] = self.model_name + if global_config.model.normal.get("name") == old_model_name: + global_config.model.normal["name"] = self.model_name logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") - if global_config.llm_reasoning.get("name") == old_model_name: - global_config.llm_reasoning["name"] = self.model_name + if global_config.model.reasoning.get("name") == old_model_name: + global_config.model.reasoning["name"] = self.model_name logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") if payload and "model" in payload: @@ -636,7 +633,7 @@ class LLMRequest: **params_copy, } if "max_tokens" not in payload and "max_completion_tokens" not in payload: - payload["max_tokens"] = global_config.model_max_output_length + payload["max_tokens"] = global_config.model.model_max_output_length # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: payload["max_completion_tokens"] = payload.pop("max_tokens") diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 9dc2454ff..96cc2b8cb 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -73,8 +73,8 @@ class NormalChat: messageinfo = message.message_info bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=messageinfo.platform, ) @@ -121,8 +121,8 @@ class NormalChat: message_id=thinking_id, chat_stream=self.chat_stream, # 使用 self.chat_stream bot_user_info=UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=message.message_info.platform, ), sender_info=message.message_info.user_info, @@ -147,7 +147,7 @@ class NormalChat: # 改为实例方法 async def _handle_emoji(self, message: MessageRecv, response: str): """处理表情包""" - if random() < global_config.emoji_chance: + if random() < global_config.normal_chat.emoji_chance: emoji_raw = await emoji_manager.get_emoji_for_text(response) if emoji_raw: emoji_path, description = emoji_raw @@ -160,8 +160,8 @@ class NormalChat: message_id="mt" + str(thinking_time_point), chat_stream=self.chat_stream, # 使用 self.chat_stream bot_user_info=UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=message.message_info.platform, ), sender_info=message.message_info.user_info, @@ -186,7 +186,7 @@ class NormalChat: label=emotion, stance=stance, # 使用 self.chat_stream ) - self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor) + self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor) async def _reply_interested_message(self) -> None: """ @@ -430,7 +430,7 @@ class NormalChat: def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: """检查消息中是否包含过滤词""" stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id - for word in global_config.ban_words: + for word in global_config.chat.ban_words: if word in text: logger.info( f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]" @@ -445,7 +445,7 @@ class NormalChat: def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: """检查消息是否匹配过滤正则表达式""" stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id - for pattern in global_config.ban_msgs_regex: + for pattern in global_config.chat.ban_msgs_regex: if pattern.search(text): logger.info( f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]" diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py index aec65ed1d..631f7baa5 100644 --- a/src/chat/normal_chat/normal_chat_generator.py +++ b/src/chat/normal_chat/normal_chat_generator.py @@ -15,21 +15,22 @@ logger = get_logger("llm") class NormalChatGenerator: def __init__(self): + # TODO: API-Adapter修改标记 self.model_reasoning = LLMRequest( - model=global_config.llm_reasoning, + model=global_config.model.reasoning, temperature=0.7, max_tokens=3000, request_type="response_reasoning", ) self.model_normal = LLMRequest( - model=global_config.llm_normal, - temperature=global_config.llm_normal["temp"], + model=global_config.model.normal, + temperature=global_config.model.normal["temp"], max_tokens=256, request_type="response_reasoning", ) self.model_sum = LLMRequest( - model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation" + model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation" ) self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" @@ -37,7 +38,7 @@ class NormalChatGenerator: async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]: """根据当前模型类型选择对应的生成函数""" # 从global_config中获取模型概率值并选择模型 - if random.random() < global_config.model_reasoning_probability: + if random.random() < global_config.normal_chat.reasoning_model_probability: self.current_model_type = "深深地" current_model = self.model_reasoning else: @@ -51,7 +52,7 @@ class NormalChatGenerator: model_response = await self._generate_response_with_model(message, current_model, thinking_id) if model_response: - logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}") + logger.info(f"{global_config.bot.nickname}的回复是:{model_response}") model_response = await self._process_response(model_response) return model_response @@ -113,7 +114,7 @@ class NormalChatGenerator: - "中立":不表达明确立场或无关回应 2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签 3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒" - 4. 考虑回复者的人格设定为{global_config.personality_core} + 4. 考虑回复者的人格设定为{global_config.personality.personality_core} 对话示例: 被回复:「A就是笨」 diff --git a/src/chat/normal_chat/willing/mode_classical.py b/src/chat/normal_chat/willing/mode_classical.py index e96aa77a7..a9f04273a 100644 --- a/src/chat/normal_chat/willing/mode_classical.py +++ b/src/chat/normal_chat/willing/mode_classical.py @@ -1,18 +1,20 @@ import asyncio + +from src.config.config import global_config from .willing_manager import BaseWillingManager class ClassicalWillingManager(BaseWillingManager): def __init__(self): super().__init__() - self._decay_task: asyncio.Task = None + self._decay_task: asyncio.Task | None = None async def _decay_reply_willing(self): """定期衰减回复意愿""" while True: await asyncio.sleep(1) for chat_id in self.chat_reply_willing: - self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9) + self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9) async def async_task_starter(self): if self._decay_task is None: @@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager): chat_id = willing_info.chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) - interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier + interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier if interested_rate > 0.4: current_willing += interested_rate - 0.3 - if willing_info.is_mentioned_bot and current_willing < 1.0: - current_willing += 1 - elif willing_info.is_mentioned_bot: - current_willing += 0.05 + if willing_info.is_mentioned_bot: + current_willing += 1 if current_willing < 1.0 else 0.05 is_emoji_not_reply = False if willing_info.is_emoji: - if self.global_config.emoji_response_penalty != 0: - current_willing *= self.global_config.emoji_response_penalty + if global_config.normal_chat.emoji_response_penalty != 0: + current_willing *= global_config.normal_chat.emoji_response_penalty else: is_emoji_not_reply = True self.chat_reply_willing[chat_id] = min(current_willing, 3.0) reply_probability = min( - max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1 + max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1 ) # 检查群组权限(如果是群聊) if ( willing_info.group_info - and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups + and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups ): - reply_probability = reply_probability / self.global_config.down_frequency_rate + reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate if is_emoji_not_reply: reply_probability = 0 @@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager): async def before_generate_reply_handle(self, message_id): chat_id = self.ongoing_messages[message_id].chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) - self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8) + self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8) async def after_generate_reply_handle(self, message_id): if message_id not in self.ongoing_messages: @@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager): chat_id = self.ongoing_messages[message_id].chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) if current_willing < 1: - self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4) + self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4) async def bombing_buffer_message_handle(self, message_id): return await super().bombing_buffer_message_handle(message_id) diff --git a/src/chat/normal_chat/willing/mode_mxp.py b/src/chat/normal_chat/willing/mode_mxp.py index 78120ac53..1e7d5856d 100644 --- a/src/chat/normal_chat/willing/mode_mxp.py +++ b/src/chat/normal_chat/willing/mode_mxp.py @@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助 下下策是询问一个菜鸟(@梦溪畔) """ +from src.config.config import global_config from .willing_manager import BaseWillingManager from typing import Dict import asyncio @@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager): self.mention_willing_gain = 0.6 # 提及意愿增益 self.interest_willing_gain = 0.3 # 兴趣意愿增益 - self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚 - self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数 self.single_chat_gain = 0.12 # 单聊增益 self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int) @@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager): probability = self._willing_to_probability(current_willing) if w_info.is_emoji: - probability *= self.emoji_response_penalty + probability *= global_config.normal_chat.emoji_response_penalty - if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups: - probability /= self.down_frequency_rate + if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups: + probability /= global_config.normal_chat.down_frequency_rate self.temporary_willing = current_willing diff --git a/src/chat/normal_chat/willing/willing_manager.py b/src/chat/normal_chat/willing/willing_manager.py index 37e623d11..bbc5dcc0a 100644 --- a/src/chat/normal_chat/willing/willing_manager.py +++ b/src/chat/normal_chat/willing/willing_manager.py @@ -1,6 +1,6 @@ from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger from dataclasses import dataclass -from src.config.config import global_config, BotConfig +from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.message import MessageRecv from src.chat.person_info.person_info import person_info_manager, PersonInfoManager @@ -93,7 +93,6 @@ class BaseWillingManager(ABC): self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id) self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id) self.lock = asyncio.Lock() - self.global_config: BotConfig = global_config self.logger: LoguruLogger = logger def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): @@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager: Returns: 对应mode的WillingManager实例 """ - mode = global_config.willing_mode.lower() + mode = global_config.normal_chat.willing_mode.lower() return BaseWillingManager.create(mode) diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index c8394a195..562cdc235 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -1,5 +1,6 @@ from src.common.logger_manager import get_logger -from ...common.database import db +from ...common.database.database import db +from ...common.database.database_model import PersonInfo # 新增导入 import copy import hashlib from typing import Any, Callable, Dict @@ -16,7 +17,7 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt from pathlib import Path import pandas as pd -import json +import json # 新增导入 import re @@ -38,47 +39,49 @@ logger = get_logger("person_info") person_info_default = { "person_id": None, - "person_name": None, + "person_name": None, # 模型中已设为 null=True,此默认值OK "name_reason": None, - "platform": None, - "user_id": None, - "nickname": None, - # "age" : 0, + "platform": "unknown", # 提供非None的默认值 + "user_id": "unknown", # 提供非None的默认值 + "nickname": "Unknown", # 提供非None的默认值 "relationship_value": 0, - # "saved" : True, - # "impression" : None, - # "gender" : Unkown, - "konw_time": 0, + "know_time": 0, # 修正拼写:konw_time -> know_time "msg_interval": 2000, - "msg_interval_list": [], - "user_cardname": None, # 添加群名片 - "user_avatar": None, # 添加头像信息(例如URL或标识符) -} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 + "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField + "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中 + "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中 +} class PersonInfoManager: def __init__(self): self.person_name_list = {} + # TODO: API-Adapter修改标记 self.qv_name_llm = LLMRequest( - model=global_config.llm_normal, + model=global_config.model.normal, max_tokens=256, request_type="qv_name", ) - if "person_info" not in db.list_collection_names(): - db.create_collection("person_info") - db.person_info.create_index("person_id", unique=True) + try: + db.connect(reuse_if_open=True) + db.create_tables([PersonInfo], safe=True) + except Exception as e: + logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") # 初始化时读取所有person_name - cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) - for doc in cursor: - if doc.get("person_name"): - self.person_name_list[doc["person_id"]] = doc["person_name"] - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") + try: + for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where( + PersonInfo.person_name.is_null(False) + ): + if record.person_name: + self.person_name_list[record.person_id] = record.person_name + logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") + except Exception as e: + logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") @staticmethod def get_person_id(platform: str, user_id: int): """获取唯一id""" - # 如果platform中存在-,就截取-后面的部分 if "-" in platform: platform = platform.split("-")[1] @@ -86,13 +89,17 @@ class PersonInfoManager: key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() - def is_person_known(self, platform: str, user_id: int): + async def is_person_known(self, platform: str, user_id: int): """判断是否认识某人""" person_id = self.get_person_id(platform, user_id) - document = db.person_info.find_one({"person_id": person_id}) - if document: - return True - else: + + def _db_check_known_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None + + try: + return await asyncio.to_thread(_db_check_known_sync, person_id) + except Exception as e: + logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") return False def get_person_id_by_person_name(self, person_name: str): @@ -111,73 +118,111 @@ class PersonInfoManager: return _person_info_default = copy.deepcopy(person_info_default) - _person_info_default["person_id"] = person_id + model_fields = PersonInfo._meta.fields.keys() + + final_data = {"person_id": person_id} if data: - for key in _person_info_default: - if key != "person_id" and key in data: - _person_info_default[key] = data[key] + for key, value in data.items(): + if key in model_fields: + final_data[key] = value - db.person_info.insert_one(_person_info_default) + for key, default_value in _person_info_default.items(): + if key in model_fields and key not in final_data: + final_data[key] = default_value + + if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list): + final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"]) + elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields: + final_data["msg_interval_list"] = json.dumps([]) + + def _db_create_sync(p_data: dict): + try: + PersonInfo.create(**p_data) + return True + except Exception as e: + logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}") + return False + + await asyncio.to_thread(_db_create_sync, final_data) async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): """更新某一个字段,会补全""" - if field_name not in person_info_default.keys(): - logger.debug(f"更新'{field_name}'失败,未定义的字段") + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。") + return + logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") return - document = db.person_info.find_one({"person_id": person_id}) + def _db_update_sync(p_id: str, f_name: str, val): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + if f_name == "msg_interval_list" and isinstance(val, list): + setattr(record, f_name, json.dumps(val)) + else: + setattr(record, f_name, val) + record.save() + return True, False + return False, True - if document: - db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}}) - else: - data[field_name] = value - logger.debug(f"更新时{person_id}不存在,已新建") - await self.create_person_info(person_id, data) + found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value) + + if needs_creation: + logger.debug(f"更新时 {person_id} 不存在,将新建。") + creation_data = data if data is not None else {} + creation_data[field_name] = value + if "platform" not in creation_data or "user_id" not in creation_data: + logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。") + + await self.create_person_info(person_id, creation_data) @staticmethod async def has_one_field(person_id: str, field_name: str): """判断是否存在某一个字段""" - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) - if document: - return True - else: + if field_name not in PersonInfo._meta.fields: + logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") + return False + + def _db_has_field_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + return True + return False + + try: + return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) + except Exception as e: + logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}") return False @staticmethod def _extract_json_from_text(text: str) -> dict: """从文本中提取JSON数据的高容错方法""" try: - # 尝试直接解析 parsed_json = json.loads(text) - # 如果解析结果是列表,尝试取第一个元素 if isinstance(parsed_json, list): - if parsed_json: # 检查列表是否为空 + if parsed_json: parsed_json = parsed_json[0] - else: # 如果列表为空,重置为 None,走后续逻辑 + else: parsed_json = None - # 确保解析结果是字典 if isinstance(parsed_json, dict): return parsed_json except json.JSONDecodeError: - # 解析失败,继续尝试其他方法 pass except Exception as e: logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") - pass # 继续尝试其他方法 + pass - # 如果直接解析失败或结果不是字典 try: - # 尝试找到JSON对象格式的部分 json_pattern = r"\{[^{}]*\}" matches = re.findall(json_pattern, text) if matches: parsed_obj = json.loads(matches[0]) - if isinstance(parsed_obj, dict): # 确保是字典 + if isinstance(parsed_obj, dict): return parsed_obj - # 如果上面都失败了,尝试提取键值对 nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"' @@ -192,7 +237,6 @@ class PersonInfoManager: except Exception as e: logger.error(f"后备JSON提取失败: {str(e)}") - # 如果所有方法都失败了,返回默认字典 logger.warning(f"无法从文本中提取有效的JSON字典: {text}") return {"nickname": "", "reason": ""} @@ -207,9 +251,11 @@ class PersonInfoManager: old_name = await self.get_value(person_id, "person_name") old_reason = await self.get_value(person_id, "name_reason") - max_retries = 5 # 最大重试次数 + max_retries = 5 current_try = 0 - existing_names = "" + existing_names_str = "" + current_name_set = set(self.person_name_list.values()) + while current_try < max_retries: individuality = Individuality.get_instance() prompt_personality = individuality.get_prompt(x_person=2, level=1) @@ -224,45 +270,58 @@ class PersonInfoManager: qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}," qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" - qv_name_prompt += ( "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" ) - if existing_names: - qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n" + + if existing_names_str: + qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n" + + if len(current_name_set) < 50 and current_name_set: + qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n" + qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:" qv_name_prompt += """{ "nickname": "昵称", "reason": "理由" }""" - # logger.debug(f"取名提示词:{qv_name_prompt}") response = await self.qv_name_llm.generate_response(qv_name_prompt) logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response[0]) - if not result["nickname"]: - logger.error("生成的昵称为空,重试中...") + if not result or not result.get("nickname"): + logger.error("生成的昵称为空或结果格式不正确,重试中...") current_try += 1 continue - # 检查生成的昵称是否已存在 - if result["nickname"] not in self.person_name_list.values(): - # 更新数据库和内存中的列表 - await self.update_one_field(person_id, "person_name", result["nickname"]) - # await self.update_one_field(person_id, "nickname", user_nickname) - # await self.update_one_field(person_id, "avatar", user_avatar) - await self.update_one_field(person_id, "name_reason", result["reason"]) + generated_nickname = result["nickname"] - self.person_name_list[person_id] = result["nickname"] - # logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") + is_duplicate = False + if generated_nickname in current_name_set: + is_duplicate = True + else: + + def _db_check_name_exists_sync(name_to_check): + return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists() + + if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): + is_duplicate = True + current_name_set.add(generated_nickname) + + if not is_duplicate: + await self.update_one_field(person_id, "person_name", generated_nickname) + await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由")) + + self.person_name_list[person_id] = generated_nickname return result else: - existing_names += f"{result['nickname']}、" + if existing_names_str: + existing_names_str += "、" + existing_names_str += generated_nickname + logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...") + current_try += 1 - logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") - current_try += 1 - - logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称") + logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}") return None @staticmethod @@ -272,30 +331,56 @@ class PersonInfoManager: logger.debug("删除失败:person_id 不能为空") return - result = db.person_info.delete_one({"person_id": person_id}) - if result.deleted_count > 0: - logger.debug(f"删除成功:person_id={person_id}") + def _db_delete_sync(p_id: str): + try: + query = PersonInfo.delete().where(PersonInfo.person_id == p_id) + deleted_count = query.execute() + return deleted_count + except Exception as e: + logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}") + return 0 + + deleted_count = await asyncio.to_thread(_db_delete_sync, person_id) + + if deleted_count > 0: + logger.debug(f"删除成功:person_id={person_id} (Peewee)") else: - logger.debug(f"删除失败:未找到 person_id={person_id}") + logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)") @staticmethod async def get_value(person_id: str, field_name: str): """获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值""" if not person_id: logger.debug("get_value获取失败:person_id不能为空") + return person_info_default.get(field_name) + + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。") + return copy.deepcopy(person_info_default[field_name]) + logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") return None - if field_name not in person_info_default: - logger.debug(f"get_value获取失败:字段'{field_name}'未定义") + def _db_get_value_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + val = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(val, str): + try: + return json.loads(val) + except json.JSONDecodeError: + logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}") + return copy.deepcopy(person_info_default.get(f_name, [])) + return val return None - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) + value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) - if document and field_name in document: - return document[field_name] + if value is not None: + return value else: - default_value = copy.deepcopy(person_info_default[field_name]) - logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}") + default_value = copy.deepcopy(person_info_default.get(field_name)) + logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)") return default_value @staticmethod @@ -305,93 +390,84 @@ class PersonInfoManager: logger.debug("get_values获取失败:person_id不能为空") return {} - # 检查所有字段是否有效 - for field in field_names: - if field not in person_info_default: - logger.debug(f"get_values获取失败:字段'{field}'未定义") - return {} - - # 构建查询投影(所有字段都有效才会执行到这里) - projection = {field: 1 for field in field_names} - - document = db.person_info.find_one({"person_id": person_id}, projection) - result = {} - for field in field_names: - result[field] = copy.deepcopy( - document.get(field, person_info_default[field]) if document else person_info_default[field] - ) + + def _db_get_record_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) + + record = await asyncio.to_thread(_db_get_record_sync, person_id) + + for field_name in field_names: + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + result[field_name] = copy.deepcopy(person_info_default[field_name]) + logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。") + else: + logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") + result[field_name] = None + continue + + if record: + value = getattr(record, field_name) + if field_name == "msg_interval_list" and isinstance(value, str): + try: + result[field_name] = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}") + result[field_name] = copy.deepcopy(person_info_default.get(field_name, [])) + elif value is not None: + result[field_name] = value + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result @staticmethod async def del_all_undefined_field(): - """删除所有项里的未定义字段""" - # 获取所有已定义的字段名 - defined_fields = set(person_info_default.keys()) - - try: - # 遍历集合中的所有文档 - for document in db.person_info.find({}): - # 找出文档中未定义的字段 - undefined_fields = set(document.keys()) - defined_fields - {"_id"} - - if undefined_fields: - # 构建更新操作,使用$unset删除未定义字段 - update_result = db.person_info.update_one( - {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}} - ) - - if update_result.modified_count > 0: - logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}") - - return - - except Exception as e: - logger.error(f"清理未定义字段时出错: {e}") - return + """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。""" + logger.info( + "del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。" + ) + return @staticmethod async def get_specific_value_list( field_name: str, - way: Callable[[Any], bool], # 接受任意类型值 + way: Callable[[Any], bool], ) -> Dict[str, Any]: """ 获取满足条件的字段值字典 - - Args: - field_name: 目标字段名 - way: 判断函数 (value: Any) -> bool - - Returns: - {person_id: value} | {} - - Example: - # 查找所有nickname包含"admin"的用户 - result = manager.specific_value_list( - "nickname", - lambda x: "admin" in x.lower() - ) """ - if field_name not in person_info_default: - logger.error(f"字段检查失败:'{field_name}'未定义") + if field_name not in PersonInfo._meta.fields: + logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") return {} + def _db_get_specific_sync(f_name: str): + found_results = {} + try: + for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)): + value = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(value, str): + try: + processed_value = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}") + continue + else: + processed_value = value + + if way(processed_value): + found_results[record.person_id] = processed_value + except Exception as e_query: + logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True) + return found_results + try: - result = {} - for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}): - try: - value = doc[field_name] - if way(value): - result[doc["person_id"]] = value - except (KeyError, TypeError, ValueError) as e: - logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}") - continue - - return result - + return await asyncio.to_thread(_db_get_specific_sync, field_name) except Exception as e: - logger.error(f"数据库查询失败: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True) return {} async def personal_habit_deduction(self): @@ -399,35 +475,31 @@ class PersonInfoManager: try: while 1: await asyncio.sleep(600) - current_time = datetime.datetime.now() - logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + current_time_dt = datetime.datetime.now() + logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}") - # "msg_interval"推断 - msg_interval_map = False - msg_interval_lists = await self.get_specific_value_list( + msg_interval_map_generated = False + msg_interval_lists_map = await self.get_specific_value_list( "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 ) - for person_id, msg_interval_list_ in msg_interval_lists.items(): + + for person_id, actual_msg_interval_list in msg_interval_lists_map.items(): await asyncio.sleep(0.3) try: time_interval = [] - for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): + for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]): delta = t2 - t1 if delta > 0: time_interval.append(delta) time_interval = [t for t in time_interval if 200 <= t <= 8000] - # --- 修改后的逻辑 --- - # 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断) - if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条 - time_interval.sort() - # 画图(log) - 这部分保留 - msg_interval_map = True + if len(time_interval) >= 30 + 10: + time_interval.sort() + msg_interval_map_generated = True log_dir = Path("logs/person_info") log_dir.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(10, 6)) - # 使用截断前的数据画图,更能反映原始分布 time_series_original = pd.Series(time_interval) plt.hist( time_series_original, @@ -449,34 +521,29 @@ class PersonInfoManager: img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" plt.savefig(img_path) plt.close() - # 画图结束 - # 去掉头尾各 5 个数据点 trimmed_interval = time_interval[5:-5] - - # 计算截断后数据的 37% 分位数 - if trimmed_interval: # 确保截断后列表不为空 - msg_interval = int(round(np.percentile(trimmed_interval, 37))) - # 更新数据库 - await self.update_one_field(person_id, "msg_interval", msg_interval) - logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}") + if trimmed_interval: + msg_interval_val = int(round(np.percentile(trimmed_interval, 37))) + await self.update_one_field(person_id, "msg_interval", msg_interval_val) + logger.trace( + f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}" + ) else: logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval") else: logger.trace( f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" ) - # --- 修改结束 --- - except Exception as e: - logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}") + except Exception as e_inner: + logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}") continue - # 其他... - - if msg_interval_map: + if msg_interval_map_generated: logger.trace("已保存分布图到: logs/person_info") - current_time = datetime.datetime.now() - logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + + current_time_dt_end = datetime.datetime.now() + logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}") await asyncio.sleep(86400) except Exception as e: @@ -489,41 +556,27 @@ class PersonInfoManager: """ 根据 platform 和 user_id 获取 person_id。 如果对应的用户不存在,则使用提供的可选信息创建新用户。 - - Args: - platform: 平台标识 - user_id: 用户在该平台上的ID - nickname: 用户的昵称 (可选,用于创建新用户) - user_cardname: 用户的群名片 (可选,用于创建新用户) - user_avatar: 用户的头像信息 (可选,用于创建新用户) - - Returns: - 对应的 person_id。 """ person_id = self.get_person_id(platform, user_id) - # 检查用户是否已存在 - # 使用静态方法 get_person_id,因此可以直接调用 db - document = db.person_info.find_one({"person_id": person_id}) + def _db_check_exists_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) - if document is None: - logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") + record = await asyncio.to_thread(_db_check_exists_sync, person_id) + + if record is None: + logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") initial_data = { "platform": platform, - "user_id": user_id, + "user_id": str(user_id), "nickname": nickname, - "konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 - # 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中 - # 如果需要存储它们,需要先在 person_info_default 中定义 + "know_time": int(datetime.datetime.now().timestamp()), # 修正拼写:konw_time -> know_time } - # 过滤掉值为 None 的初始数据 - initial_data = {k: v for k, v in initial_data.items() if v is not None} + model_fields = PersonInfo._meta.fields.keys() + filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - # 注意:create_person_info 是静态方法 - await PersonInfoManager.create_person_info(person_id, data=initial_data) - # 创建后,可以考虑立即为其取名,但这可能会增加延迟 - # await self.qv_person_name(person_id, nickname, user_cardname, user_avatar) - logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}") + await self.create_person_info(person_id, data=filtered_initial_data) + logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") return person_id @@ -533,35 +586,55 @@ class PersonInfoManager: logger.debug("get_person_info_by_name 获取失败:person_name 不能为空") return None - # 优先从内存缓存查找 person_id found_person_id = None - for pid, name in self.person_name_list.items(): - if name == person_name: + for pid, name_in_cache in self.person_name_list.items(): + if name_in_cache == person_name: found_person_id = pid - break # 找到第一个匹配就停止 + break if not found_person_id: - # 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载) - document = db.person_info.find_one({"person_name": person_name}) - if document: - found_person_id = document.get("person_id") - else: - logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户") - return None # 数据库也找不到 - # 根据找到的 person_id 获取所需信息 - if found_person_id: - required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"] - person_data = await self.get_values(found_person_id, required_fields) - if person_data: # 确保 get_values 成功返回 - return person_data + def _db_find_by_name_sync(p_name_to_find: str): + return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find) + + record = await asyncio.to_thread(_db_find_by_name_sync, person_name) + if record: + found_person_id = record.person_id + if ( + found_person_id not in self.person_name_list + or self.person_name_list[found_person_id] != person_name + ): + self.person_name_list[found_person_id] = person_name else: - logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") + logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)") return None - else: - # 这理论上不应该发生,因为上面已经处理了找不到的情况 - logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") - return None + + if found_person_id: + required_fields = [ + "person_id", + "platform", + "user_id", + "nickname", + "user_cardname", + "user_avatar", + "person_name", + "name_reason", + ] + valid_fields_to_get = [ + f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default + ] + + person_data = await self.get_values(found_person_id, valid_fields_to_get) + + if person_data: + final_result = {key: person_data.get(key) for key in required_fields} + return final_result + else: + logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)") + return None + + logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)") + return None person_info_manager = PersonInfoManager() diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index e5ccd82a7..d3a062680 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -190,8 +190,8 @@ async def _build_readable_messages_internal( person_id = person_info_manager.get_person_id(platform, user_id) # 根据 replace_bot_name 参数决定是否替换机器人名称 - if replace_bot_name and user_id == global_config.BOT_QQ: - person_name = f"{global_config.BOT_NICKNAME}(你)" + if replace_bot_name and user_id == global_config.bot.qq_account: + person_name = f"{global_config.bot.nickname}(你)" else: person_name = await person_info_manager.get_value(person_id, "person_name") @@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: output_lines = [] def get_anon_name(platform, user_id): - if user_id == global_config.BOT_QQ: + if user_id == global_config.bot.qq_account: return "SELF" person_id = person_info_manager.get_person_id(platform, user_id) if person_id not in person_map: @@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: def reply_replacer(match, platform=platform): # aaa = match.group(1) bbb = match.group(2) - anon_reply = get_anon_name(platform, bbb) + anon_reply = get_anon_name(platform, bbb) # noqa return f"回复 {anon_reply}" content = re.sub(reply_pattern, reply_replacer, content, count=1) @@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: def at_replacer(match, platform=platform): # aaa = match.group(1) bbb = match.group(2) - anon_at = get_anon_name(platform, bbb) + anon_at = get_anon_name(platform, bbb) # noqa return f"@{anon_at}" content = re.sub(at_pattern, at_replacer, content) @@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: user_id = user_info.get("user_id") # 检查必要信息是否存在 且 不是机器人自己 - if not all([platform, user_id]) or user_id == global_config.BOT_QQ: + if not all([platform, user_id]) or user_id == global_config.bot.qq_account: continue person_id = person_info_manager.get_person_id(platform, user_id) diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index 174bb5b49..93cda5113 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -1,15 +1,15 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecv, MessageSending, Message -from src.common.database import db +from src.common.database.database_model import Messages, ThinkingLog import time import traceback from typing import List +import json class InfoCatcher: def __init__(self): self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~ - self.context_length = global_config.observation_context_size self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~ self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~ @@ -60,8 +60,6 @@ class InfoCatcher: def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 self.timing_results["sub_heartflow_observe_time"] = obs_duration - # def catch_shf - def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): self.timing_results["sub_heartflow_step_time"] = step_duration if len(past_mind) > 1: @@ -72,25 +70,10 @@ class InfoCatcher: self.heartflow_data["sub_heartflow_now"] = current_mind def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): - # if self.response_mode == "heart_flow": # 条件判断不需要了喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name - # elif self.response_mode == "reasoning": # 条件判断不需要了喵~ - # self.reasoning_data["thinking_log"] = reasoning_content - # self.reasoning_data["prompt"] = prompt - # self.reasoning_data["response"] = response - # self.reasoning_data["model"] = model_name - - # 直接记录信息喵~ self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["prompt"] = prompt self.reasoning_data["response"] = response self.reasoning_data["model"] = model_name - # 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name self.response_text = response @@ -102,6 +85,7 @@ class InfoCatcher: ): self.timing_results["make_response_time"] = response_duration self.response_time = time.time() + self.response_messages = [] for msg in response_message: self.response_messages.append(msg) @@ -112,107 +96,112 @@ class InfoCatcher: @staticmethod def get_message_from_db_between_msgs(message_start: Message, message_end: Message): try: - # 从数据库中获取消息的时间戳 time_start = message_start.message_info.time time_end = message_end.message_info.time chat_id = message_start.chat_stream.stream_id print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") - # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 - messages_between = db.messages.find( - {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} - ).sort("time", -1) + messages_between_query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end)) + .order_by(Messages.time.desc()) + ) - result = list(messages_between) + result = list(messages_between_query) print(f"查询结果数量: {len(result)}") if result: - print(f"第一条消息时间: {result[0]['time']}") - print(f"最后一条消息时间: {result[-1]['time']}") + print(f"第一条消息时间: {result[0].time}") + print(f"最后一条消息时间: {result[-1].time}") return result except Exception as e: print(f"获取消息时出错: {str(e)}") + print(traceback.format_exc()) return [] def get_message_from_db_before_msg(self, message: MessageRecv): - # 从数据库中获取消息 - message_id = message.message_info.message_id - chat_id = message.chat_stream.stream_id + message_id_val = message.message_info.message_id + chat_id_val = message.chat_stream.stream_id - # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 - messages_before = ( - db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) - .sort("time", -1) - .limit(self.context_length * 3) - ) # 获取更多历史信息 + messages_before_query = ( + Messages.select() + .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val)) + .order_by(Messages.time.desc()) + .limit(global_config.chat.observation_context_size * 3) + ) - return list(messages_before) + return list(messages_before_query) def message_list_to_dict(self, message_list): - # 存储简化的聊天记录 result = [] - for message in message_list: - if not isinstance(message, dict): - message = self.message_to_dict(message) - # print(message) + for msg_item in message_list: + processed_msg_item = msg_item + if not isinstance(msg_item, dict): + processed_msg_item = self.message_to_dict(msg_item) + + if not processed_msg_item: + continue lite_message = { - "time": message["time"], - "user_nickname": message["user_info"]["user_nickname"], - "processed_plain_text": message["processed_plain_text"], + "time": processed_msg_item.get("time"), + "user_nickname": processed_msg_item.get("user_nickname"), + "processed_plain_text": processed_msg_item.get("processed_plain_text"), } result.append(lite_message) - return result @staticmethod - def message_to_dict(message): - if not message: + def message_to_dict(msg_obj): + if not msg_obj: return None - if isinstance(message, dict): - return message - return { - # "message_id": message.message_info.message_id, - "time": message.message_info.time, - "user_id": message.message_info.user_info.user_id, - "user_nickname": message.message_info.user_info.user_nickname, - "processed_plain_text": message.processed_plain_text, - # "detailed_plain_text": message.detailed_plain_text - } + if isinstance(msg_obj, dict): + return msg_obj - def done_catch(self): - """将收集到的信息存储到数据库的 thinking_log 集合中喵~""" - try: - # 将消息对象转换为可序列化的字典喵~ - - thinking_log_data = { - "chat_id": self.chat_id, - "trigger_text": self.trigger_response_text, - "response_text": self.response_text, - "trigger_info": { - "time": self.trigger_response_time, - "message": self.message_to_dict(self.trigger_response_message), - }, - "response_info": { - "time": self.response_time, - "message": self.response_messages, - }, - "timing_results": self.timing_results, - "chat_history": self.message_list_to_dict(self.chat_history), - "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), - "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response), - "heartflow_data": self.heartflow_data, - "reasoning_data": self.reasoning_data, + if isinstance(msg_obj, Messages): + return { + "time": msg_obj.time, + "user_id": msg_obj.user_id, + "user_nickname": msg_obj.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, } - # 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ - # if self.response_mode == "heart_flow": - # thinking_log_data["mode_specific_data"] = self.heartflow_data - # elif self.response_mode == "reasoning": - # thinking_log_data["mode_specific_data"] = self.reasoning_data + if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"): + return { + "time": msg_obj.message_info.time, + "user_id": msg_obj.message_info.user_info.user_id, + "user_nickname": msg_obj.message_info.user_info.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, + } - # 将数据插入到 thinking_log 集合中喵~ - db.thinking_log.insert_one(thinking_log_data) + print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}") + return {} + + def done_catch(self): + """将收集到的信息存储到数据库的 thinking_log 表中喵~""" + try: + trigger_info_dict = self.message_to_dict(self.trigger_response_message) + response_info_dict = { + "time": self.response_time, + "message": self.response_messages, + } + chat_history_list = self.message_list_to_dict(self.chat_history) + chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking) + chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response) + + log_entry = ThinkingLog( + chat_id=self.chat_id, + trigger_text=self.trigger_response_text, + response_text=self.response_text, + trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None, + response_info_json=json.dumps(response_info_dict), + timing_results_json=json.dumps(self.timing_results), + chat_history_json=json.dumps(chat_history_list), + chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list), + chat_history_after_response_json=json.dumps(chat_history_after_response_list), + heartflow_data_json=json.dumps(self.heartflow_data), + reasoning_data_json=json.dumps(self.reasoning_data), + ) + log_entry.save() return True except Exception as e: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 3f9832926..a657ae85b 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -2,10 +2,12 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Any, Dict, Tuple, List + from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask -from ...common.database import db +from ...common.database.database import db # This db is the Peewee database instance +from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask): def __init__(self): super().__init__(task_name="Online Time Record Task", run_interval=60) - self.record_id: str | None = None + self.record_id: int | None = None # Changed to int for Peewee's default ID """记录ID""" self._init_database() # 初始化数据库 @@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask): @staticmethod def _init_database(): """初始化数据库""" - if "online_time" not in db.list_collection_names(): - # 初始化数据库(在线时长) - db.create_collection("online_time") - # 创建索引 - if ("end_timestamp", 1) not in db.online_time.list_indexes(): - db.online_time.create_index([("end_timestamp", 1)]) + with db.atomic(): # Use atomic operations for schema changes + OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model async def run(self): try: + current_time = datetime.now() + extended_end_time = current_time + timedelta(minutes=1) + if self.record_id: # 如果有记录,则更新结束时间 - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": datetime.now() + timedelta(minutes=1), - } - }, - ) - else: + query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) + updated_rows = query.execute() + if updated_rows == 0: + # Record might have been deleted or ID is stale, try to find/create + self.record_id = None # Reset record_id to trigger find/create logic below + + if not self.record_id: # Check again if record_id was reset or initially None # 如果没有记录,检查一分钟以内是否已有记录 - current_time = datetime.now() - if recent_record := db.online_time.find_one( - {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} - ): + # Look for a record whose end_timestamp is recent enough to be considered ongoing + recent_record = ( + OnlineTime.select() + .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) + .order_by(OnlineTime.end_timestamp.desc()) + .first() + ) + + if recent_record: # 如果有记录,则更新结束时间 - self.record_id = recent_record["_id"] - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": current_time + timedelta(minutes=1), - } - }, - ) + self.record_id = recent_record.id + recent_record.end_timestamp = extended_end_time + recent_record.save() else: # 若没有记录,则插入新的在线时间记录 - self.record_id = db.online_time.insert_one( - { - "start_timestamp": current_time, - "end_timestamp": current_time + timedelta(minutes=1), - } - ).inserted_id + new_record = OnlineTime.create( + timestamp=current_time.timestamp(), # 添加此行 + start_timestamp=current_time, + end_timestamp=extended_end_time, + duration=5, # 初始时长为5分钟 + ) + self.record_id = new_record.id except Exception as e: logger.error(f"在线时间记录失败,错误信息:{e}") @@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + # 排序-按照时间段开始时间降序排列(最晚的时间段在前) + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 总LLM请求数 TOTAL_REQ_CNT: 0, - # 请求次数统计 REQ_CNT_BY_TYPE: defaultdict(int), REQ_CNT_BY_USER: defaultdict(int), REQ_CNT_BY_MODEL: defaultdict(int), - # 输入Token数 IN_TOK_BY_TYPE: defaultdict(int), IN_TOK_BY_USER: defaultdict(int), IN_TOK_BY_MODEL: defaultdict(int), - # 输出Token数 OUT_TOK_BY_TYPE: defaultdict(int), OUT_TOK_BY_USER: defaultdict(int), OUT_TOK_BY_MODEL: defaultdict(int), - # 总Token数 TOTAL_TOK_BY_TYPE: defaultdict(int), TOTAL_TOK_BY_USER: defaultdict(int), TOTAL_TOK_BY_MODEL: defaultdict(int), - # 总开销 TOTAL_COST: 0.0, - # 请求开销统计 COST_BY_TYPE: defaultdict(float), COST_BY_USER: defaultdict(float), COST_BY_MODEL: defaultdict(float), @@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask): } # 以最早的时间戳为起始时间获取记录 - for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}): - record_timestamp = record.get("timestamp") + # Assuming LLMUsage.timestamp is a DateTimeField + query_start_time = collect_period[-1][1] + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_REQ_CNT] += 1 - request_type = record.get("request_type", "unknown") # 请求类型 - user_id = str(record.get("user_id", "unknown")) # 用户ID - model_name = record.get("model_name", "unknown") # 模型名称 + request_type = record.request_type or "unknown" + user_id = record.user_id or "unknown" # user_id is TextField, already string + model_name = record.model_name or "unknown" stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 stats[period_key][REQ_CNT_BY_USER][user_id] += 1 stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 - prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数 - completion_tokens = record.get("completion_tokens", 0) # 输出Token数 - total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数 + prompt_tokens = record.prompt_tokens or 0 + completion_tokens = record.completion_tokens or 0 + total_tokens = prompt_tokens + completion_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens @@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask): stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens - cost = record.get("cost", 0.0) + cost = record.cost or 0.0 stats[period_key][TOTAL_COST] += cost stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_MODEL][model_name] += cost - break # 取消更早时间段的判断 - + break return stats @staticmethod @@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 在线时间统计 ONLINE_TIME: 0.0, } for period_key, _ in collect_period } - # 统计在线时间 - for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}): - end_timestamp: datetime = record.get("end_timestamp") - for idx, (_, period_start) in enumerate(collect_period): - if end_timestamp >= period_start: - # 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间 - end_timestamp = min(end_timestamp, now) - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 - for period_key, _period_start in collect_period[idx:]: - start_timestamp: datetime = record.get("start_timestamp") - if start_timestamp < _period_start: - # 如果开始时间在查询边界之前,则使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds() - else: - # 否则,使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds() - break # 取消更早时间段的判断 + query_start_time = collect_period[-1][1] + # Assuming OnlineTime.end_timestamp is a DateTimeField + for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): + # record.end_timestamp and record.start_timestamp are datetime objects + record_end_timestamp = record.end_timestamp + record_start_timestamp = record.start_timestamp + for idx, (_, period_boundary_start) in enumerate(collect_period): + if record_end_timestamp >= period_boundary_start: + # Calculate effective end time for this record in relation to 'now' + effective_end_time = min(record_end_timestamp, now) + + for period_key, current_period_start_time in collect_period[idx:]: + # Determine the portion of the record that falls within this specific statistical period + overlap_start = max(record_start_timestamp, current_period_start_time) + overlap_end = effective_end_time # Already capped by 'now' and record's own end + + if overlap_end > overlap_start: + stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() + break return stats def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: @@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 消息统计 TOTAL_MSG_CNT: 0, MSG_CNT_BY_CHAT: defaultdict(int), } for period_key, _ in collect_period } - # 统计消息量 - for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}): - chat_info = message.get("chat_info", None) # 聊天信息 - user_info = message.get("user_info", None) # 用户信息(消息发送人) - message_time = message.get("time", 0) # 消息时间 + query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) + for message in Messages.select().where(Messages.time >= query_start_timestamp): + message_time_ts = message.time # This is a float timestamp - group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息 - if group_info is not None: - # 若有群聊信息 - chat_id = f"g{group_info.get('group_id')}" - chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}") - elif user_info: - # 若没有群聊信息,则尝试获取用户信息 - chat_id = f"u{user_info['user_id']}" - chat_name = user_info["user_nickname"] + chat_id = None + chat_name = None + + # Logic based on Peewee model structure, aiming to replicate original intent + if message.chat_info_group_id: + chat_id = f"g{message.chat_info_group_id}" + chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" + elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat + # This uses the message SENDER's ID as per original logic's fallback + chat_id = f"u{message.user_id}" # SENDER's user_id + chat_name = message.user_nickname # SENDER's nickname else: - continue # 如果没有群组信息也没有用户信息,则跳过 + # If neither group_id nor sender_id is available for chat identification + logger.warning( + f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats." + ) + continue + if not chat_id: # Should not happen if above logic is correct + continue + + # Update name_mapping if chat_id in self.name_mapping: - if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]: - # 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称 - self.name_mapping[chat_id] = (chat_name, message_time) + if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]: + self.name_mapping[chat_id] = (chat_name, message_time_ts) else: - self.name_mapping[chat_id] = (chat_name, message_time) + self.name_mapping[chat_id] = (chat_name, message_time_ts) - for idx, (_, period_start) in enumerate(collect_period): - if message_time >= period_start.timestamp(): - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 + for idx, (_, period_start_dt) in enumerate(collect_period): + if message_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 break - return stats def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 8fe8334b8..c400a9948 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager from ..message_receive.message import MessageRecv from ..models.utils_model import LLMRequest from .typo_generator import ChineseTypoGenerator -from ...common.database import db +from ...common.database.database import db from ...config.config import global_config logger = get_module_logger("chat_utils") @@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str: def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: """检查消息是否提到了机器人""" - keywords = [global_config.BOT_NICKNAME] - nicknames = global_config.BOT_ALIAS_NAMES + keywords = [global_config.bot.nickname] + nicknames = global_config.bot.alias_names reply_probability = 0.0 is_at = False is_mentioned = False @@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: ) # 判断是否被@ - if re.search(f"@[\s\S]*?(id:{global_config.BOT_QQ})", message.processed_plain_text): + if re.search(f"@[\s\S]*?(id:{global_config.bot.qq_account})", message.processed_plain_text): is_at = True is_mentioned = True - if is_at and global_config.at_bot_inevitable_reply: + if is_at and global_config.normal_chat.at_bot_inevitable_reply: reply_probability = 1.0 logger.info("被@,回复概率设置为100%") else: if not is_mentioned: # 判断是否被回复 if re.match( - f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\):[\s\S]*?],说:", message.processed_plain_text + f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\):[\s\S]*?],说:", message.processed_plain_text ): is_mentioned = True else: @@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: for nickname in nicknames: if nickname in message_content: is_mentioned = True - if is_mentioned and global_config.mentioned_bot_inevitable_reply: + if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply: reply_probability = 1.0 logger.info("被提及,回复概率设置为100%") return is_mentioned, reply_probability @@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: async def get_embedding(text, request_type="embedding"): """获取文本的embedding向量""" - llm = LLMRequest(model=global_config.embedding, request_type=request_type) + # TODO: API-Adapter修改标记 + llm = LLMRequest(model=global_config.model.embedding, request_type=request_type) # return llm.get_embedding_sync(text) try: embedding = await llm.get_embedding(text) @@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li user_info = UserInfo.from_dict(msg_db_data["user_info"]) if ( (user_info.platform, user_info.user_id) != sender - and user_info.user_id != global_config.BOT_QQ + and user_info.user_id != global_config.bot.qq_account and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group and len(who_chat_in_group) < 5 ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目 @@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str: def process_llm_response(text: str) -> list[str]: # 先保护颜文字 - if global_config.enable_kaomoji_protection: + if global_config.response_splitter.enable_kaomoji_protection: protected_text, kaomoji_mapping = protect_kaomoji(text) logger.trace(f"保护颜文字后的文本: {protected_text}") else: @@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]: logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}") # 对清理后的文本进行进一步处理 - max_length = global_config.response_max_length * 2 - max_sentence_num = global_config.response_max_sentence_num + max_length = global_config.response_splitter.max_length * 2 + max_sentence_num = global_config.response_splitter.max_sentence_num # 如果基本上是中文,则进行长度过滤 if get_western_ratio(cleaned_text) < 0.1: if len(cleaned_text) > max_length: @@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]: return ["懒得说"] typo_generator = ChineseTypoGenerator( - error_rate=global_config.chinese_typo_error_rate, - min_freq=global_config.chinese_typo_min_freq, - tone_error_rate=global_config.chinese_typo_tone_error_rate, - word_replace_rate=global_config.chinese_typo_word_replace_rate, + error_rate=global_config.chinese_typo.error_rate, + min_freq=global_config.chinese_typo.min_freq, + tone_error_rate=global_config.chinese_typo.tone_error_rate, + word_replace_rate=global_config.chinese_typo.word_replace_rate, ) - if global_config.enable_response_splitter: + if global_config.response_splitter.enable: split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) else: split_sentences = [cleaned_text] sentences = [] for sentence in split_sentences: - if global_config.chinese_typo_enable: + if global_config.chinese_typo.enable: typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence) sentences.append(typoed_text) if typo_corrections: @@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]: if len(sentences) > max_sentence_num: logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") - return [f"{global_config.BOT_NICKNAME}不知道哦"] + return [f"{global_config.bot.nickname}不知道哦"] # if extracted_contents: # for content in extracted_contents: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 455038246..c317fbbd6 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -8,7 +8,8 @@ import io import numpy as np -from ...common.database import db +from ...common.database.database import db +from ...common.database.database_model import Images, ImageDescriptions from ...config.config import global_config from ..models.utils_model import LLMRequest @@ -32,40 +33,23 @@ class ImageManager: def __init__(self): if not self._initialized: - self._ensure_image_collection() - self._ensure_description_collection() self._ensure_image_dir() + + self._initialized = True + self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") + + try: + db.connect(reuse_if_open=True) + db.create_tables([Images, ImageDescriptions], safe=True) + except Exception as e: + logger.error(f"数据库连接或表创建失败: {e}") + self._initialized = True - self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image") def _ensure_image_dir(self): """确保图像存储目录存在""" os.makedirs(self.IMAGE_DIR, exist_ok=True) - @staticmethod - def _ensure_image_collection(): - """确保images集合存在并创建索引""" - if "images" not in db.list_collection_names(): - db.create_collection("images") - - # 删除旧索引 - db.images.drop_indexes() - # 创建新的复合索引 - db.images.create_index([("hash", 1), ("type", 1)], unique=True) - db.images.create_index([("url", 1)]) - db.images.create_index([("path", 1)]) - - @staticmethod - def _ensure_description_collection(): - """确保image_descriptions集合存在并创建索引""" - if "image_descriptions" not in db.list_collection_names(): - db.create_collection("image_descriptions") - - # 删除旧索引 - db.image_descriptions.drop_indexes() - # 创建新的复合索引 - db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True) - @staticmethod def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -77,8 +61,14 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) - return result["description"] if result else None + try: + record = ImageDescriptions.get_or_none( + (ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type) + ) + return record.description if record else None + except Exception as e: + logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}") + return None @staticmethod def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: @@ -90,20 +80,17 @@ class ImageManager: description_type: 描述类型 ('emoji' 或 'image') """ try: - db.image_descriptions.update_one( - {"hash": image_hash, "type": description_type}, - { - "$set": { - "description": description, - "timestamp": int(time.time()), - "hash": image_hash, # 确保hash字段存在 - "type": description_type, # 确保type字段存在 - } - }, - upsert=True, + current_timestamp = time.time() + defaults = {"description": description, "timestamp": current_timestamp} + desc_obj, created = ImageDescriptions.get_or_create( + hash=image_hash, type=description_type, defaults=defaults ) + if not created: # 如果记录已存在,则更新 + desc_obj.description = description + desc_obj.timestamp = current_timestamp + desc_obj.save() except Exception as e: - logger.error(f"保存描述到数据库失败: {str(e)}") + logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") async def get_emoji_description(self, image_base64: str) -> str: """获取表情包描述,带查重和保存功能""" @@ -116,51 +103,64 @@ class ImageManager: # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: - # logger.debug(f"缓存表情包描述: {cached_description}") return f"[表情包,含义看起来是:{cached_description}]" # 调用AI获取描述 if image_format == "gif" or image_format == "GIF": - image_base64 = self.transform_gif(image_base64) + image_base64_processed = self.transform_gif(image_base64) + if image_base64_processed is None: + logger.warning("GIF转换失败,无法获取描述") + return "[表情包(GIF处理失败)]" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") + description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg") else: prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + if description is None: + logger.warning("AI未能生成表情包描述") + return "[表情包(描述生成失败)]" + + # 再次检查缓存,防止并发写入时重复生成 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") return f"[表情包,含义看起来是:{cached_description}]" # 根据配置决定是否保存图片 - if global_config.save_emoji: + if global_config.emoji.save_emoji: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): - os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) - file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + emoji_dir = os.path.join(self.IMAGE_DIR, "emoji") + os.makedirs(emoji_dir, exist_ok=True) + file_path = os.path.join(emoji_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "emoji", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存表情包: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="emoji", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存表情包元数据: {file_path}") except Exception as e: - logger.error(f"保存表情包文件失败: {str(e)}") + logger.error(f"保存表情包文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "emoji") return f"[表情包:{description}]" @@ -188,6 +188,11 @@ class ImageManager: ) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + if description is None: + logger.warning("AI未能生成图片描述") + return "[图片(描述生成失败)]" + + # 再次检查缓存 cached_description = self._get_description_from_db(image_hash, "image") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") @@ -195,38 +200,40 @@ class ImageManager: logger.debug(f"描述是{description}") - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片]" - # 根据配置决定是否保存图片 - if global_config.save_pic: + if global_config.emoji.save_pic: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): - os.makedirs(os.path.join(self.IMAGE_DIR, "image")) - file_path = os.path.join(self.IMAGE_DIR, "image", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + image_dir = os.path.join(self.IMAGE_DIR, "image") + os.makedirs(image_dir, exist_ok=True) + file_path = os.path.join(image_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "image", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存图片: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="image", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存图片元数据: {file_path}") except Exception as e: - logger.error(f"保存图片文件失败: {str(e)}") + logger.error(f"保存图片文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "image") return f"[图片:{description}]" diff --git a/src/chat/zhishi/knowledge_library.py b/src/chat/zhishi/knowledge_library.py index 6fa1d3e1a..0068a153c 100644 --- a/src/chat/zhishi/knowledge_library.py +++ b/src/chat/zhishi/knowledge_library.py @@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) # 现在可以导入src模块 -from src.common.database import db # noqa E402 +from common.database.database import db # noqa E402 # 加载根目录下的env.edv文件 diff --git a/src/common/database.py b/src/common/database/database.py similarity index 81% rename from src/common/database.py rename to src/common/database/database.py index 752f746db..a2dab739d 100644 --- a/src/common/database.py +++ b/src/common/database/database.py @@ -1,5 +1,6 @@ import os from pymongo import MongoClient +from peewee import SqliteDatabase from pymongo.database import Database from rich.traceback import install @@ -57,4 +58,15 @@ class DBWrapper: # 全局数据库访问点 -db: Database = DBWrapper() +memory_db: Database = DBWrapper() + +# 定义数据库文件路径 +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +_DB_DIR = os.path.join(ROOT_PATH, "data") +_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db") + +# 确保数据库目录存在 +os.makedirs(_DB_DIR, exist_ok=True) + +# 全局 Peewee SQLite 数据库访问点 +db = SqliteDatabase(_DB_FILE) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py new file mode 100644 index 000000000..bd7a2d319 --- /dev/null +++ b/src/common/database/database_model.py @@ -0,0 +1,358 @@ +from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField +from .database import db +import datetime +from ..logger_manager import get_logger + +logger = get_logger("database_model") +# 请在此处定义您的数据库实例。 +# 您需要取消注释并配置适合您的数据库的部分。 +# 例如,对于 SQLite: +# db = SqliteDatabase('MaiBot.db') +# +# 对于 PostgreSQL: +# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password', +# host='localhost', port=5432) +# +# 对于 MySQL: +# db = MySQLDatabase('your_db_name', user='your_user', password='your_password', +# host='localhost', port=3306) + + +# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。 +# 这允许您在一个地方为所有模型指定数据库。 +class BaseModel(Model): + class Meta: + # 将下面的 'db' 替换为您实际的数据库实例变量名。 + database = db # 例如: database = my_actual_db_instance + pass # 在用户定义数据库实例之前,此处为占位符 + + +class ChatStreams(BaseModel): + """ + 用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。 + """ + + # stream_id: "a544edeb1a9b73e3e1d77dff36e41264" + # 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。 + stream_id = TextField(unique=True, index=True) + + # create_time: 1746096761.4490178 (时间戳,精确到小数点后7位) + # DoubleField 用于存储浮点数,适合此类时间戳。 + create_time = DoubleField() + + # group_info 字段: + # platform: "qq" + # group_id: "941657197" + # group_name: "测试" + group_platform = TextField() + group_id = TextField() + group_name = TextField() + + # last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位) + last_active_time = DoubleField() + + # platform: "qq" (顶层平台字段) + platform = TextField() + + # user_info 字段: + # platform: "qq" + # user_id: "1787882683" + # user_nickname: "墨梓柒(IceSakurary)" + # user_cardname: "" + user_platform = TextField() + user_id = TextField() + user_nickname = TextField() + # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 + user_cardname = TextField(null=True) + + class Meta: + # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 + # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, + # 请取消注释并在下面设置数据库实例: + # database = db + table_name = "chat_streams" # 可选:明确指定数据库中的表名 + + +class LLMUsage(BaseModel): + """ + 用于存储 API 使用日志数据的模型。 + """ + + model_name = TextField(index=True) # 添加索引 + user_id = TextField(index=True) # 添加索引 + request_type = TextField(index=True) # 添加索引 + endpoint = TextField() + prompt_tokens = IntegerField() + completion_tokens = IntegerField() + total_tokens = IntegerField() + cost = DoubleField() + status = TextField() + timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 + + class Meta: + # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 + # database = db + table_name = "llm_usage" + + +class Emoji(BaseModel): + """表情包""" + + full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名) + format = TextField() # 图片格式 + emoji_hash = TextField(index=True) # 表情包的哈希值 + description = TextField() # 表情包的描述 + query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数) + is_registered = BooleanField(default=False) # 是否已注册 + is_banned = BooleanField(default=False) # 是否被禁止注册 + # emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化 + emotion = TextField(null=True) + record_time = FloatField() # 记录时间(被创建的时间) + register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间) + usage_count = IntegerField(default=0) # 使用次数(被使用的次数) + last_used_time = FloatField(null=True) # 上次使用时间 + + class Meta: + # database = db # 继承自 BaseModel + table_name = "emoji" + + +class Messages(BaseModel): + """ + 用于存储消息数据的模型。 + """ + + message_id = TextField(index=True) # 消息 ID (更改自 IntegerField) + time = DoubleField() # 消息时间戳 + + chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + + # 从 chat_info 扁平化而来的字段 + chat_info_stream_id = TextField() + chat_info_platform = TextField() + chat_info_user_platform = TextField() + chat_info_user_id = TextField() + chat_info_user_nickname = TextField() + chat_info_user_cardname = TextField(null=True) + chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在 + chat_info_group_id = TextField(null=True) + chat_info_group_name = TextField(null=True) + chat_info_create_time = DoubleField() + chat_info_last_active_time = DoubleField() + + # 从顶层 user_info 扁平化而来的字段 (消息发送者信息) + user_platform = TextField() + user_id = TextField() + user_nickname = TextField() + user_cardname = TextField(null=True) + + processed_plain_text = TextField(null=True) # 处理后的纯文本消息 + detailed_plain_text = TextField(null=True) # 详细的纯文本消息 + memorized_times = IntegerField(default=0) # 被记忆的次数 + + class Meta: + # database = db # 继承自 BaseModel + table_name = "messages" + + +class Images(BaseModel): + """ + 用于存储图像信息的模型。 + """ + + emoji_hash = TextField(index=True) # 图像的哈希值 + description = TextField(null=True) # 图像的描述 + path = TextField(unique=True) # 图像文件的路径 + timestamp = FloatField() # 时间戳 + type = TextField() # 图像类型,例如 "emoji" + + class Meta: + # database = db # 继承自 BaseModel + table_name = "images" + + +class ImageDescriptions(BaseModel): + """ + 用于存储图像描述信息的模型。 + """ + + type = TextField() # 类型,例如 "emoji" + image_description_hash = TextField(index=True) # 图像的哈希值 + description = TextField() # 图像的描述 + timestamp = FloatField() # 时间戳 + + class Meta: + # database = db # 继承自 BaseModel + table_name = "image_descriptions" + + +class OnlineTime(BaseModel): + """ + 用于存储在线时长记录的模型。 + """ + + # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) + timestamp = TextField(default=datetime.datetime.now) # 时间戳 + duration = IntegerField() # 时长,单位分钟 + start_timestamp = DateTimeField(default=datetime.datetime.now) + end_timestamp = DateTimeField(index=True) + + class Meta: + # database = db # 继承自 BaseModel + table_name = "online_time" + + +class PersonInfo(BaseModel): + """ + 用于存储个人信息数据的模型。 + """ + + person_id = TextField(unique=True, index=True) # 个人唯一ID + person_name = TextField(null=True) # 个人名称 (允许为空) + name_reason = TextField(null=True) # 名称设定的原因 + platform = TextField() # 平台 + user_id = TextField(index=True) # 用户ID + nickname = TextField() # 用户昵称 + relationship_value = IntegerField(default=0) # 关系值 + know_time = FloatField() # 认识时间 (时间戳) + msg_interval = IntegerField() # 消息间隔 + # msg_interval_list: 存储为 JSON 字符串的列表 + msg_interval_list = TextField(null=True) + + class Meta: + # database = db # 继承自 BaseModel + table_name = "person_info" + + +class Knowledges(BaseModel): + """ + 用于存储知识库条目的模型。 + """ + + content = TextField() # 知识内容的文本 + embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 + # 可以添加其他元数据字段,如 source, create_time 等 + + class Meta: + # database = db # 继承自 BaseModel + table_name = "knowledges" + + +class ThinkingLog(BaseModel): + chat_id = TextField(index=True) + trigger_text = TextField(null=True) + response_text = TextField(null=True) + + # Store complex dicts/lists as JSON strings + trigger_info_json = TextField(null=True) + response_info_json = TextField(null=True) + timing_results_json = TextField(null=True) + chat_history_json = TextField(null=True) + chat_history_in_thinking_json = TextField(null=True) + chat_history_after_response_json = TextField(null=True) + heartflow_data_json = TextField(null=True) + reasoning_data_json = TextField(null=True) + + # Add a timestamp for the log entry itself + # Ensure you have: from peewee import DateTimeField + # And: import datetime + created_at = DateTimeField(default=datetime.datetime.now) + + class Meta: + table_name = "thinking_logs" + + +class RecalledMessages(BaseModel): + """ + 用于存储撤回消息记录的模型。 + """ + + message_id = TextField(index=True) # 被撤回的消息 ID + time = DoubleField() # 撤回操作发生的时间戳 + stream_id = TextField() # 对应的 ChatStreams stream_id + + class Meta: + table_name = "recalled_messages" + + +def create_tables(): + """ + 创建所有在模型中定义的数据库表。 + """ + with db: + db.create_tables( + [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog, + RecalledMessages, # 添加新模型 + ] + ) + + +def initialize_database(): + """ + 检查所有定义的表是否存在,如果不存在则创建它们。 + 检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。 + """ + import sys + + models = [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog, + RecalledMessages, # 添加新模型 + ] + + needs_creation = False + try: + with db: # 管理 table_exists 检查的连接 + for model in models: + table_name = model._meta.table_name + if not db.table_exists(model): + logger.warning(f"表 '{table_name}' 未找到。") + needs_creation = True + break # 一个表丢失,无需进一步检查。 + if not needs_creation: + # 检查字段 + for model in models: + table_name = model._meta.table_name + cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") + existing_columns = {row[1] for row in cursor.fetchall()} + model_fields = model._meta.fields + for field_name in model_fields: + if field_name not in existing_columns: + logger.error(f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。") + sys.exit(1) + except Exception as e: + logger.exception(f"检查表或字段是否存在时出错: {e}") + # 如果检查失败(例如数据库不可用),则退出 + return + + if needs_creation: + logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...") + try: + create_tables() # 此函数有其自己的 'with db:' 上下文管理。 + logger.info("数据库表创建过程完成。") + except Exception as e: + logger.exception(f"创建表期间出错: {e}") + else: + logger.info("所有数据库表及字段均已存在。") + + +# 模块加载时调用初始化函数 +initialize_database() diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 03f192cea..ee69b22b0 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,11 +1,19 @@ -from src.common.database import db +from src.common.database.database_model import Messages # 更改导入 from src.common.logger import get_module_logger import traceback from typing import List, Any, Optional +from peewee import Model # 添加 Peewee Model 导入 logger = get_module_logger(__name__) +def _model_to_dict(model_instance: Model) -> dict[str, Any]: + """ + 将 Peewee 模型实例转换为字典。 + """ + return model_instance.__data__ + + def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, @@ -16,39 +24,84 @@ def find_messages( 根据提供的过滤器、排序和限制条件查找消息。 Args: - message_filter: MongoDB 查询过滤器。 - sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). + sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 Returns: - 消息文档列表,如果出错则返回空列表。 + 消息字典列表,如果出错则返回空列表。 """ try: - query = db.messages.find(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") + else: + # 直接相等比较 + conditions.append(field == value) + else: + logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + if conditions: + query = query.where(*conditions) if limit > 0: if limit_mode == "earliest": # 获取时间最早的 limit 条记录,已经是正序 - query = query.sort([("time", 1)]).limit(limit) - results = list(query) + query = query.order_by(Messages.time.asc()).limit(limit) + peewee_results = list(query) else: # 默认为 'latest' # 获取时间最晚的 limit 条记录 - query = query.sort([("time", -1)]).limit(limit) - latest_results = list(query) + query = query.order_by(Messages.time.desc()).limit(limit) + latest_results_peewee = list(query) # 将结果按时间正序排列 - # 假设消息文档中总是有 'time' 字段且可排序 - results = sorted(latest_results, key=lambda msg: msg.get("time")) + peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time) else: # limit 为 0 时,应用传入的 sort 参数 if sort: - query = query.sort(sort) - results = list(query) + peewee_sort_terms = [] + for field_name, direction in sort: + if hasattr(Messages, field_name): + field = getattr(Messages, field_name) + if direction == 1: # ASC + peewee_sort_terms.append(field.asc()) + elif direction == -1: # DESC + peewee_sort_terms.append(field.desc()) + else: + logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") + else: + logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") + if peewee_sort_terms: + query = query.order_by(*peewee_sort_terms) + peewee_results = list(query) - return results + return [_model_to_dict(msg) for msg in peewee_results] except Exception as e: log_message = ( - f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) @@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int: 根据提供的过滤器计算消息数量。 Args: - message_filter: MongoDB 查询过滤器。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). Returns: 符合条件的消息数量,如果出错则返回 0。 """ try: - count = db.messages.count_documents(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning( + f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" + ) + else: + # 直接相等比较 + conditions.append(field == value) + else: + logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + if conditions: + query = query.where(*conditions) + + count = query.count() return count except Exception as e: - log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() + log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" logger.error(log_message) return 0 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 +# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。 +# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。 diff --git a/src/common/remote.py b/src/common/remote.py index 1d26df01b..b1108be9c 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask): info_dict = { "os_type": "Unknown", "py_version": platform.python_version(), - "mmc_version": global_config.MAI_VERSION, + "mmc_version": global_config.MMC_VERSION, } match platform.system(): @@ -133,10 +133,9 @@ class TelemetryHeartBeatTask(AsyncTask): async def run(self): # 发送心跳 - if global_config.remote_enable: - if self.client_uuid is None: - if not await self._req_uuid(): - logger.error("获取UUID失败,跳过此次心跳") - return + if global_config.telemetry.enable: + if self.client_uuid is None and not await self._req_uuid(): + logger.error("获取UUID失败,跳过此次心跳") + return await self._send_heartbeat() diff --git a/src/config/config.py b/src/config/config.py index b186f3b83..e6b7c5326 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,64 +1,68 @@ import os -import re -from dataclasses import dataclass, field -from typing import Dict, List, Optional +from dataclasses import field, dataclass -import tomli import tomlkit import shutil from datetime import datetime -from pathlib import Path -from packaging import version -from packaging.version import Version, InvalidVersion -from packaging.specifiers import SpecifierSet, InvalidSpecifier + +from tomlkit import TOMLDocument +from tomlkit.items import Table from src.common.logger_manager import get_logger from rich.traceback import install +from src.config.config_base import ConfigBase +from src.config.official_configs import ( + BotConfig, + ChatTargetConfig, + PersonalityConfig, + IdentityConfig, + PlatformsConfig, + ChatConfig, + NormalChatConfig, + FocusChatConfig, + EmojiConfig, + MemoryConfig, + MoodConfig, + KeywordReactionConfig, + ChineseTypoConfig, + ResponseSplitterConfig, + TelemetryConfig, + ExperimentalConfig, + ModelConfig, +) + install(extra_lines=3) # 配置主程序日志格式 logger = get_logger("config") -# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 -is_test = True -mai_version_main = "0.6.4" -mai_version_fix = "snapshot-1" +CONFIG_DIR = "config" +TEMPLATE_DIR = "template" -if mai_version_fix: - if is_test: - mai_version = f"test-{mai_version_main}-{mai_version_fix}" - else: - mai_version = f"{mai_version_main}-{mai_version_fix}" -else: - if is_test: - mai_version = f"test-{mai_version_main}" - else: - mai_version = mai_version_main +# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 +# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ +MMC_VERSION = "0.7.0-snapshot.1" def update_config(): # 获取根目录路径 - root_dir = Path(__file__).parent.parent.parent - template_dir = root_dir / "template" - config_dir = root_dir / "config" - old_config_dir = config_dir / "old" + old_config_dir = f"{CONFIG_DIR}/old" # 定义文件路径 - template_path = template_dir / "bot_config_template.toml" - old_config_path = config_dir / "bot_config.toml" - new_config_path = config_dir / "bot_config.toml" + template_path = f"{TEMPLATE_DIR}/bot_config_template.toml" + old_config_path = f"{CONFIG_DIR}/bot_config.toml" + new_config_path = f"{CONFIG_DIR}/bot_config.toml" # 检查配置文件是否存在 - if not old_config_path.exists(): + if not os.path.exists(old_config_path): logger.info("配置文件不存在,从模板创建新配置") - # 创建文件夹 - old_config_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(template_path, old_config_path) + os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 + shutil.copy2(template_path, old_config_path) # 复制模板文件 logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") # 如果是新创建的配置文件,直接返回 - return quit() + quit() # 读取旧配置文件和模板文件 with open(old_config_path, "r", encoding="utf-8") as f: @@ -75,13 +79,15 @@ def update_config(): return else: logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") + else: + logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") # 创建old目录(如果不存在) - old_config_dir.mkdir(exist_ok=True) + os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml" + old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml" # 移动旧配置文件到old目录 shutil.move(old_config_path, old_backup_path) @@ -91,24 +97,23 @@ def update_config(): shutil.copy2(template_path, new_config_path) logger.info(f"已创建新配置文件: {new_config_path}") - # 递归更新配置 - def update_dict(target, source): + def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ for key, value in source.items(): # 跳过version字段的更新 if key == "version": continue if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)): + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): update_dict(target[key], value) else: try: # 对数组类型进行特殊处理 if isinstance(value, list): # 如果是空数组,确保它保持为空数组 - if not value: - target[key] = tomlkit.array() - else: - target[key] = tomlkit.array(value) + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() else: # 其他类型使用item方法创建新值 target[key] = tomlkit.item(value) @@ -123,619 +128,57 @@ def update_config(): # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) - logger.info("配置文件更新完成") + logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + quit() @dataclass -class BotConfig: - """机器人配置类""" - - INNER_VERSION: Version = None - MAI_VERSION: str = mai_version # 硬编码的版本信息 - - # bot - BOT_QQ: Optional[str] = "114514" - BOT_NICKNAME: Optional[str] = None - BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它 - - # group - talk_allowed_groups = set() - talk_frequency_down_groups = set() - ban_user_id = set() - - # personality - personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋 - personality_sides: List[str] = field( - default_factory=lambda: [ - "用一句话或几句话描述人格的一些侧面", - "用一句话或几句话描述人格的一些侧面", - "用一句话或几句话描述人格的一些侧面", - ] - ) - expression_style = "描述麦麦说话的表达风格,表达习惯" - # identity - identity_detail: List[str] = field( - default_factory=lambda: [ - "身份特点", - "身份特点", - ] - ) - height: int = 170 # 身高 单位厘米 - weight: int = 50 # 体重 单位千克 - age: int = 20 # 年龄 单位岁 - gender: str = "男" # 性别 - appearance: str = "用几句话描述外貌特征" # 外貌特征 - - # chat - allow_focus_mode: bool = True # 是否允许专注聊天状态 - - base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天 - base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天 - - observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩 - - message_buffer: bool = True # 消息缓冲器 - - ban_words = set() - ban_msgs_regex = set() - - # focus_chat - reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发 - default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢 - consecutive_no_reply_threshold = 3 - - compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 - compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除 - - # normal_chat - model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率 - model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率 - - emoji_chance: float = 0.2 # 发送表情包的基础概率 - thinking_timeout: int = 120 # 思考时间 - - willing_mode: str = "classical" # 意愿模式 - response_willing_amplifier: float = 1.0 # 回复意愿放大系数 - response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数 - down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数 - emoji_response_penalty: float = 0.0 # 表情包回复惩罚 - mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复 - at_bot_inevitable_reply: bool = False # @bot 必然回复 - - # emoji - max_emoji_num: int = 200 # 表情包最大数量 - max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包 - EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) - - save_pic: bool = False # 是否保存图片 - save_emoji: bool = False # 是否保存表情包 - steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包 - - EMOJI_CHECK: bool = False # 是否开启过滤 - EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求 - - # memory - build_memory_interval: int = 600 # 记忆构建间隔(秒) - memory_build_distribution: list = field( - default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4] - ) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 - build_memory_sample_num: int = 10 # 记忆构建采样数量 - build_memory_sample_length: int = 20 # 记忆构建采样长度 - memory_compress_rate: float = 0.1 # 记忆压缩率 - - forget_memory_interval: int = 600 # 记忆遗忘间隔(秒) - memory_forget_time: int = 24 # 记忆遗忘时间(小时) - memory_forget_percentage: float = 0.01 # 记忆遗忘比例 - - consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒) - consolidation_similarity_threshold: float = 0.7 # 相似度阈值 - consolidate_memory_percentage: float = 0.01 # 检查节点比例 - - memory_ban_words: list = field( - default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"] - ) # 添加新的配置项默认值 - - # mood - mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒 - mood_decay_rate: float = 0.95 # 情绪衰减率 - mood_intensity_factor: float = 0.7 # 情绪强度因子 - - # keywords - keywords_reaction_rules = [] # 关键词回复规则 - - # chinese_typo - chinese_typo_enable = True # 是否启用中文错别字生成器 - chinese_typo_error_rate = 0.03 # 单字替换概率 - chinese_typo_min_freq = 7 # 最小字频阈值 - chinese_typo_tone_error_rate = 0.2 # 声调错误概率 - chinese_typo_word_replace_rate = 0.02 # 整词替换概率 - - # response_splitter - enable_kaomoji_protection = False # 是否启用颜文字保护 - enable_response_splitter = True # 是否启用回复分割器 - response_max_length = 100 # 回复允许的最大长度 - response_max_sentence_num = 3 # 回复允许的最大句子数 - - model_max_output_length: int = 800 # 最大回复长度 - - # remote - remote_enable: bool = True # 是否启用远程控制 - - # experimental - enable_friend_chat: bool = False # 是否启用好友聊天 - # enable_think_flow: bool = False # 是否启用思考流程 - talk_allowed_private = set() - enable_pfc_chatting: bool = False # 是否启用PFC聊天 - - # 模型配置 - llm_reasoning: dict[str, str] = field(default_factory=lambda: {}) - # llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {}) - llm_normal: Dict[str, str] = field(default_factory=lambda: {}) - llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {}) - llm_summary: Dict[str, str] = field(default_factory=lambda: {}) - embedding: Dict[str, str] = field(default_factory=lambda: {}) - vlm: Dict[str, str] = field(default_factory=lambda: {}) - moderation: Dict[str, str] = field(default_factory=lambda: {}) - - llm_observation: Dict[str, str] = field(default_factory=lambda: {}) - llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {}) - llm_heartflow: Dict[str, str] = field(default_factory=lambda: {}) - llm_tool_use: Dict[str, str] = field(default_factory=lambda: {}) - llm_plan: Dict[str, str] = field(default_factory=lambda: {}) - - api_urls: Dict[str, str] = field(default_factory=lambda: {}) - - @staticmethod - def get_config_dir() -> str: - """获取配置文件目录""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, "..", "..")) - config_dir = os.path.join(root_dir, "config") - if not os.path.exists(config_dir): - os.makedirs(config_dir) - return config_dir - - @classmethod - def convert_to_specifierset(cls, value: str) -> SpecifierSet: - """将 字符串 版本表达式转换成 SpecifierSet - Args: - value[str]: 版本表达式(字符串) - Returns: - SpecifierSet - """ - - try: - converted = SpecifierSet(value) - except InvalidSpecifier: - logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码") - exit(1) - - return converted - - @classmethod - def get_config_version(cls, toml: dict) -> Version: - """提取配置文件的 SpecifierSet 版本数据 - Args: - toml[dict]: 输入的配置文件字典 - Returns: - Version - """ - - if "inner" in toml: - try: - config_version: str = toml["inner"]["version"] - except KeyError as e: - logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件") - raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e - else: - toml["inner"] = {"version": "0.0.0"} - config_version = toml["inner"]["version"] - - try: - ver = version.parse(config_version) - except InvalidVersion as e: - logger.error( - "配置文件中 inner段 的 version 键是错误的版本描述\n" - "请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n" - "本项目在不同的版本下有不同的模板,请注意识别" - ) - raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e - - return ver - - @classmethod - def load_config(cls, config_path: str = None) -> "BotConfig": - """从TOML配置文件加载配置""" - config = cls() - - def personality(parent: dict): - personality_config = parent["personality"] - if config.INNER_VERSION in SpecifierSet(">=1.2.4"): - config.personality_core = personality_config.get("personality_core", config.personality_core) - config.personality_sides = personality_config.get("personality_sides", config.personality_sides) - if config.INNER_VERSION in SpecifierSet(">=1.7.0"): - config.expression_style = personality_config.get("expression_style", config.expression_style) - - def identity(parent: dict): - identity_config = parent["identity"] - if config.INNER_VERSION in SpecifierSet(">=1.2.4"): - config.identity_detail = identity_config.get("identity_detail", config.identity_detail) - config.height = identity_config.get("height", config.height) - config.weight = identity_config.get("weight", config.weight) - config.age = identity_config.get("age", config.age) - config.gender = identity_config.get("gender", config.gender) - config.appearance = identity_config.get("appearance", config.appearance) - - def emoji(parent: dict): - emoji_config = parent["emoji"] - config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) - config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT) - config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK) - if config.INNER_VERSION in SpecifierSet(">=1.1.1"): - config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num) - config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion) - if config.INNER_VERSION in SpecifierSet(">=1.4.2"): - config.save_pic = emoji_config.get("save_pic", config.save_pic) - config.save_emoji = emoji_config.get("save_emoji", config.save_emoji) - config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji) - - def bot(parent: dict): - # 机器人基础配置 - bot_config = parent["bot"] - bot_qq = bot_config.get("qq") - config.BOT_QQ = str(bot_qq) - config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME) - config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES) - - def chat(parent: dict): - chat_config = parent["chat"] - config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode) - config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num) - config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num) - config.observation_context_size = chat_config.get( - "observation_context_size", config.observation_context_size - ) - config.message_buffer = chat_config.get("message_buffer", config.message_buffer) - config.ban_words = chat_config.get("ban_words", config.ban_words) - for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex): - config.ban_msgs_regex.add(re.compile(r)) - - def normal_chat(parent: dict): - normal_chat_config = parent["normal_chat"] - config.model_reasoning_probability = normal_chat_config.get( - "model_reasoning_probability", config.model_reasoning_probability - ) - config.model_normal_probability = normal_chat_config.get( - "model_normal_probability", config.model_normal_probability - ) - config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance) - config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout) - - config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode) - config.response_willing_amplifier = normal_chat_config.get( - "response_willing_amplifier", config.response_willing_amplifier - ) - config.response_interested_rate_amplifier = normal_chat_config.get( - "response_interested_rate_amplifier", config.response_interested_rate_amplifier - ) - config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate) - config.emoji_response_penalty = normal_chat_config.get( - "emoji_response_penalty", config.emoji_response_penalty - ) - - config.mentioned_bot_inevitable_reply = normal_chat_config.get( - "mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply - ) - config.at_bot_inevitable_reply = normal_chat_config.get( - "at_bot_inevitable_reply", config.at_bot_inevitable_reply - ) - - def focus_chat(parent: dict): - focus_chat_config = parent["focus_chat"] - config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length) - config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit) - config.reply_trigger_threshold = focus_chat_config.get( - "reply_trigger_threshold", config.reply_trigger_threshold - ) - config.default_decay_rate_per_second = focus_chat_config.get( - "default_decay_rate_per_second", config.default_decay_rate_per_second - ) - config.consecutive_no_reply_threshold = focus_chat_config.get( - "consecutive_no_reply_threshold", config.consecutive_no_reply_threshold - ) - - def model(parent: dict): - # 加载模型配置 - model_config: dict = parent["model"] - - config_list = [ - "llm_reasoning", - # "llm_reasoning_minor", - "llm_normal", - "llm_topic_judge", - "llm_summary", - "vlm", - "embedding", - "llm_tool_use", - "llm_observation", - "llm_sub_heartflow", - "llm_plan", - "llm_heartflow", - "llm_PFC_action_planner", - "llm_PFC_chat", - "llm_PFC_reply_checker", - ] - - for item in config_list: - if item in model_config: - cfg_item: dict = model_config[item] - - # base_url 的例子: SILICONFLOW_BASE_URL - # key 的例子: SILICONFLOW_KEY - cfg_target = { - "name": "", - "base_url": "", - "key": "", - "stream": False, - "pri_in": 0, - "pri_out": 0, - "temp": 0.7, - } - - if config.INNER_VERSION in SpecifierSet("<=0.0.0"): - cfg_target = cfg_item - - elif config.INNER_VERSION in SpecifierSet(">=0.0.1"): - stable_item = ["name", "pri_in", "pri_out"] - - stream_item = ["stream"] - if config.INNER_VERSION in SpecifierSet(">=1.0.1"): - stable_item.append("stream") - - pricing_item = ["pri_in", "pri_out"] - - # 从配置中原始拷贝稳定字段 - for i in stable_item: - # 如果 字段 属于计费项 且获取不到,那默认值是 0 - if i in pricing_item and i not in cfg_item: - cfg_target[i] = 0 - - if i in stream_item and i not in cfg_item: - cfg_target[i] = False - - else: - # 没有特殊情况则原样复制 - try: - cfg_target[i] = cfg_item[i] - except KeyError as e: - logger.error(f"{item} 中的必要字段不存在,请检查") - raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e - - # 如果配置中有temp参数,就使用配置中的值 - if "temp" in cfg_item: - cfg_target["temp"] = cfg_item["temp"] - else: - # 如果没有temp参数,就删除默认值 - cfg_target.pop("temp", None) - - provider = cfg_item.get("provider") - if provider is None: - logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查") - raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查") - - cfg_target["base_url"] = f"{provider}_BASE_URL" - cfg_target["key"] = f"{provider}_KEY" - - # 如果 列表中的项目在 model_config 中,利用反射来设置对应项目 - setattr(config, item, cfg_target) - else: - logger.error(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件") - raise KeyError(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件") - - def memory(parent: dict): - memory_config = parent["memory"] - config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval) - config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval) - config.memory_ban_words = set(memory_config.get("memory_ban_words", [])) - config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time) - config.memory_forget_percentage = memory_config.get( - "memory_forget_percentage", config.memory_forget_percentage - ) - config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate) - if config.INNER_VERSION in SpecifierSet(">=0.0.11"): - config.memory_build_distribution = memory_config.get( - "memory_build_distribution", config.memory_build_distribution - ) - config.build_memory_sample_num = memory_config.get( - "build_memory_sample_num", config.build_memory_sample_num - ) - config.build_memory_sample_length = memory_config.get( - "build_memory_sample_length", config.build_memory_sample_length - ) - if config.INNER_VERSION in SpecifierSet(">=1.5.1"): - config.consolidate_memory_interval = memory_config.get( - "consolidate_memory_interval", config.consolidate_memory_interval - ) - config.consolidation_similarity_threshold = memory_config.get( - "consolidation_similarity_threshold", config.consolidation_similarity_threshold - ) - config.consolidate_memory_percentage = memory_config.get( - "consolidate_memory_percentage", config.consolidate_memory_percentage - ) - - def remote(parent: dict): - remote_config = parent["remote"] - config.remote_enable = remote_config.get("enable", config.remote_enable) - - def mood(parent: dict): - mood_config = parent["mood"] - config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval) - config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate) - config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor) - - def keywords_reaction(parent: dict): - keywords_reaction_config = parent["keywords_reaction"] - if keywords_reaction_config.get("enable", False): - config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules) - for rule in config.keywords_reaction_rules: - if rule.get("enable", False) and "regex" in rule: - rule["regex"] = [re.compile(r) for r in rule.get("regex", [])] - - def chinese_typo(parent: dict): - chinese_typo_config = parent["chinese_typo"] - config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable) - config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate) - config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq) - config.chinese_typo_tone_error_rate = chinese_typo_config.get( - "tone_error_rate", config.chinese_typo_tone_error_rate - ) - config.chinese_typo_word_replace_rate = chinese_typo_config.get( - "word_replace_rate", config.chinese_typo_word_replace_rate - ) - - def response_splitter(parent: dict): - response_splitter_config = parent["response_splitter"] - config.enable_response_splitter = response_splitter_config.get( - "enable_response_splitter", config.enable_response_splitter - ) - config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length) - config.response_max_sentence_num = response_splitter_config.get( - "response_max_sentence_num", config.response_max_sentence_num - ) - if config.INNER_VERSION in SpecifierSet(">=1.4.2"): - config.enable_kaomoji_protection = response_splitter_config.get( - "enable_kaomoji_protection", config.enable_kaomoji_protection - ) - if config.INNER_VERSION in SpecifierSet(">=1.6.0"): - config.model_max_output_length = response_splitter_config.get( - "model_max_output_length", config.model_max_output_length - ) - - def groups(parent: dict): - groups_config = parent["groups"] - # config.talk_allowed_groups = set(groups_config.get("talk_allowed", [])) - config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", [])) - # config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", [])) - config.talk_frequency_down_groups = set( - str(group) for group in groups_config.get("talk_frequency_down", []) - ) - # config.ban_user_id = set(groups_config.get("ban_user_id", [])) - config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", [])) - - def experimental(parent: dict): - experimental_config = parent["experimental"] - config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat) - # config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow) - config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", [])) - if config.INNER_VERSION in SpecifierSet(">=1.1.0"): - config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting) - - # 版本表达式:>=1.0.0,<2.0.0 - # 允许字段:func: method, support: str, notice: str, necessary: bool - # 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示 - # 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以 - # 正常执行程序,但是会看到这条自定义提示 - - # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下: - # 主版本号:当你做了不兼容的 API 修改, - # 次版本号:当你做了向下兼容的功能性新增, - # 修订号:当你做了向下兼容的问题修正。 - # 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。 - - # 如果你做了break的修改,就应该改动主版本号 - # 如果做了一个兼容修改,就不应该要求这个选项是必须的! - include_configs = { - "bot": {"func": bot, "support": ">=0.0.0"}, - "groups": {"func": groups, "support": ">=0.0.0"}, - "personality": {"func": personality, "support": ">=0.0.0"}, - "identity": {"func": identity, "support": ">=1.2.4"}, - "emoji": {"func": emoji, "support": ">=0.0.0"}, - "model": {"func": model, "support": ">=0.0.0"}, - "memory": {"func": memory, "support": ">=0.0.0", "necessary": False}, - "mood": {"func": mood, "support": ">=0.0.0"}, - "remote": {"func": remote, "support": ">=0.0.10", "necessary": False}, - "keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False}, - "chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False}, - "response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False}, - "experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False}, - "chat": {"func": chat, "support": ">=1.6.0", "necessary": False}, - "normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False}, - "focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False}, - } - - # 原地修改,将 字符串版本表达式 转换成 版本对象 - for key in include_configs: - item_support = include_configs[key]["support"] - include_configs[key]["support"] = cls.convert_to_specifierset(item_support) - - if os.path.exists(config_path): - with open(config_path, "rb") as f: - try: - toml_dict = tomli.load(f) - except tomli.TOMLDecodeError as e: - logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") - exit(1) - - # 获取配置文件版本 - config.INNER_VERSION = cls.get_config_version(toml_dict) - - # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 - for key in include_configs: - if key in toml_dict: - group_specifierset: SpecifierSet = include_configs[key]["support"] - - # 检查配置文件版本是否在支持范围内 - if config.INNER_VERSION in group_specifierset: - # 如果版本在支持范围内,检查是否存在通知 - if "notice" in include_configs[key]: - logger.warning(include_configs[key]["notice"]) - - include_configs[key]["func"](toml_dict) - - else: - # 如果版本不在支持范围内,崩溃并提示用户 - logger.error( - f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" - f"当前程序仅支持以下版本范围: {group_specifierset}" - ) - raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}") - - # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 - elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False: - # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 - if key == "keywords_reaction": - pass - - else: - # 如果用户根本没有需要的配置项,提示缺少配置 - logger.error(f"配置文件中缺少必需的字段: '{key}'") - raise KeyError(f"配置文件中缺少必需的字段: '{key}'") - - # identity_detail字段非空检查 - if not config.identity_detail: - logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串") - raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串") - - logger.success(f"成功加载配置文件: {config_path}") - - return config +class Config(ConfigBase): + """总配置类""" + + MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息 + + bot: BotConfig + chat_target: ChatTargetConfig + personality: PersonalityConfig + identity: IdentityConfig + platforms: PlatformsConfig + chat: ChatConfig + normal_chat: NormalChatConfig + focus_chat: FocusChatConfig + emoji: EmojiConfig + memory: MemoryConfig + mood: MoodConfig + keyword_reaction: KeywordReactionConfig + chinese_typo: ChineseTypoConfig + response_splitter: ResponseSplitterConfig + telemetry: TelemetryConfig + experimental: ExperimentalConfig + model: ModelConfig + + +def load_config(config_path: str) -> Config: + """ + 加载配置文件 + :param config_path: 配置文件路径 + :return: Config对象 + """ + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建Config对象 + try: + return Config.from_dict(config_data) + except Exception as e: + logger.critical("配置文件解析失败") + raise e # 获取配置文件路径 -logger.info(f"MaiCore当前版本: {mai_version}") +logger.info(f"MaiCore当前版本: {MMC_VERSION}") update_config() -bot_config_floder_path = BotConfig.get_config_dir() -logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}") - -bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") - -if os.path.exists(bot_config_path): - # 如果开发环境配置文件不存在,则使用默认配置文件 - logger.info(f"异常的新鲜,异常的美味: {bot_config_path}") -else: - # 配置文件不存在 - logger.error("配置文件不存在,请检查路径: {bot_config_path}") - raise FileNotFoundError(f"配置文件不存在: {bot_config_path}") - -global_config = BotConfig.load_config(config_path=bot_config_path) +logger.info("正在品鉴配置文件...") +global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml") +logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/config_base.py b/src/config/config_base.py new file mode 100644 index 000000000..92f6cf9d4 --- /dev/null +++ b/src/config/config_base.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass, fields, MISSING +from typing import TypeVar, Type, Any, get_origin, get_args + +T = TypeVar("T", bound="ConfigBase") + +TOML_DICT_TYPE = { + int, + float, + str, + bool, + list, + dict, +} + + +@dataclass +class ConfigBase: + """配置类的基类""" + + @classmethod + def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + """从字典加载配置字段""" + if not isinstance(data, dict): + raise TypeError(f"Expected a dictionary, got {type(data).__name__}") + + init_args: dict[str, Any] = {} + + for f in fields(cls): + field_name = f.name + + if field_name.startswith("_"): + # 跳过以 _ 开头的字段 + continue + + if field_name not in data: + if f.default is not MISSING or f.default_factory is not MISSING: + # 跳过未提供且有默认值/默认构造方法的字段 + continue + else: + raise ValueError(f"Missing required field: '{field_name}'") + + value = data[field_name] + field_type = f.type + + try: + init_args[field_name] = cls._convert_field(value, field_type) + except TypeError as e: + raise TypeError(f"Field '{field_name}' has a type error: {e}") from e + except Exception as e: + raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e + + return cls(**init_args) + + @classmethod + def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + """ + 转换字段值为指定类型 + + 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法 + 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素 + 3. 对于基础类型(int, str, float, bool),直接转换 + 4. 对于其他类型,尝试直接转换,如果失败则抛出异常 + """ + + # 如果是嵌套的 dataclass,递归调用 from_dict 方法 + if isinstance(field_type, type) and issubclass(field_type, ConfigBase): + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + return field_type.from_dict(value) + + # 处理泛型集合类型(list, set, tuple) + field_origin_type = get_origin(field_type) + field_type_args = get_args(field_type) + + if field_origin_type in {list, set, tuple}: + # 检查提供的value是否为list + if not isinstance(value, list): + raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") + + if field_origin_type is list: + return [cls._convert_field(item, field_type_args[0]) for item in value] + elif field_origin_type is set: + return {cls._convert_field(item, field_type_args[0]) for item in value} + elif field_origin_type is tuple: + # 检查提供的value长度是否与类型参数一致 + if len(value) != len(field_type_args): + raise TypeError( + f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}" + ) + return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args)) + + if field_origin_type is dict: + # 检查提供的value是否为dict + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + + # 检查字典的键值类型 + if len(field_type_args) != 2: + raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") + key_type, value_type = field_type_args + + return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} + + # 处理基础类型,例如 int, str 等 + if field_type is Any or isinstance(value, field_type): + return value + + # 其他类型,尝试直接转换 + try: + return field_type(value) + except (ValueError, TypeError) as e: + raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e + + def __str__(self): + """返回配置类的字符串表示""" + return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})" diff --git a/src/config/official_configs.py b/src/config/official_configs.py new file mode 100644 index 000000000..d92d925d6 --- /dev/null +++ b/src/config/official_configs.py @@ -0,0 +1,399 @@ +from dataclasses import dataclass, field +from typing import Any + +from src.config.config_base import ConfigBase + +""" +须知: +1. 本文件中记录了所有的配置项 +2. 所有新增的class都需要继承自ConfigBase +3. 所有新增的class都应在config.py中的Config类中添加字段 +4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default +""" + + +@dataclass +class BotConfig(ConfigBase): + """QQ机器人配置类""" + + qq_account: str + """QQ账号""" + + nickname: str + """昵称""" + + alias_names: list[str] = field(default_factory=lambda: []) + """别名列表""" + + +@dataclass +class ChatTargetConfig(ConfigBase): + """ + 聊天目标配置类 + 此类中有聊天的群组和用户配置 + """ + + talk_allowed_groups: set[str] = field(default_factory=lambda: set()) + """允许聊天的群组列表""" + + talk_frequency_down_groups: set[str] = field(default_factory=lambda: set()) + """降低聊天频率的群组列表""" + + ban_user_id: set[str] = field(default_factory=lambda: set()) + """禁止聊天的用户列表""" + + +@dataclass +class PersonalityConfig(ConfigBase): + """人格配置类""" + + personality_core: str + """核心人格""" + + expression_style: str + """表达风格""" + + personality_sides: list[str] = field(default_factory=lambda: []) + """人格侧写""" + + +@dataclass +class IdentityConfig(ConfigBase): + """个体特征配置类""" + + height: int = 170 + """身高(单位:厘米)""" + + weight: float = 50 + """体重(单位:千克)""" + + age: int = 18 + """年龄(单位:岁)""" + + gender: str = "女" + """性别(男/女)""" + + appearance: str = "可爱" + """外貌描述""" + + identity_detail: list[str] = field(default_factory=lambda: []) + """身份特征""" + + +@dataclass +class PlatformsConfig(ConfigBase): + """平台配置类""" + + qq: str + """QQ适配器连接URL配置""" + + +@dataclass +class ChatConfig(ConfigBase): + """聊天配置类""" + + allow_focus_mode: bool = True + """是否允许专注聊天状态""" + + base_normal_chat_num: int = 3 + """最多允许多少个群进行普通聊天""" + + base_focused_chat_num: int = 2 + """最多允许多少个群进行专注聊天""" + + observation_context_size: int = 12 + """可观察到的最长上下文大小,超过这个值的上下文会被压缩""" + + message_buffer: bool = True + """消息缓冲器""" + + ban_words: set[str] = field(default_factory=lambda: set()) + """过滤词列表""" + + ban_msgs_regex: set[str] = field(default_factory=lambda: set()) + """过滤正则表达式列表""" + + +@dataclass +class NormalChatConfig(ConfigBase): + """普通聊天配置类""" + + reasoning_model_probability: float = 0.3 + """ + 发言时选择推理模型的概率(0-1之间) + 选择普通模型的概率为 1 - reasoning_normal_model_probability + """ + + emoji_chance: float = 0.2 + """发送表情包的基础概率""" + + thinking_timeout: int = 120 + """最长思考时间""" + + willing_mode: str = "classical" + """意愿模式""" + + response_willing_amplifier: float = 1.0 + """回复意愿放大系数""" + + response_interested_rate_amplifier: float = 1.0 + """回复兴趣度放大系数""" + + down_frequency_rate: float = 3.0 + """降低回复频率的群组回复意愿降低系数""" + + emoji_response_penalty: float = 0.0 + """表情包回复惩罚系数""" + + mentioned_bot_inevitable_reply: bool = False + """提及 bot 必然回复""" + + at_bot_inevitable_reply: bool = False + """@bot 必然回复""" + + +@dataclass +class FocusChatConfig(ConfigBase): + """专注聊天配置类""" + + reply_trigger_threshold: float = 3.0 + """心流聊天触发阈值,越低越容易触发""" + + default_decay_rate_per_second: float = 0.98 + """默认衰减率,越大衰减越快""" + + consecutive_no_reply_threshold: int = 3 + """连续不回复的次数阈值""" + + compressed_length: int = 5 + """心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5""" + + compress_length_limit: int = 5 + """最多压缩份数,超过该数值的压缩上下文会被删除""" + + +@dataclass +class EmojiConfig(ConfigBase): + """表情包配置类""" + + max_reg_num: int = 200 + """表情包最大注册数量""" + + do_replace: bool = True + """达到最大注册数量时替换旧表情包""" + + check_interval: int = 120 + """表情包检查间隔(分钟)""" + + save_pic: bool = False + """是否保存图片""" + + cache_emoji: bool = True + """是否缓存表情包""" + + steal_emoji: bool = True + """是否偷取表情包,让麦麦可以发送她保存的这些表情包""" + + content_filtration: bool = False + """是否开启表情包过滤""" + + filtration_prompt: str = "符合公序良俗" + """表情包过滤要求""" + + +@dataclass +class MemoryConfig(ConfigBase): + """记忆配置类""" + + memory_build_interval: int = 600 + """记忆构建间隔(秒)""" + + memory_build_distribution: tuple[ + float, + float, + float, + float, + float, + float, + ] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4)) + """记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重""" + + memory_build_sample_num: int = 8 + """记忆构建采样数量""" + + memory_build_sample_length: int = 40 + """记忆构建采样长度""" + + memory_compress_rate: float = 0.1 + """记忆压缩率""" + + forget_memory_interval: int = 1000 + """记忆遗忘间隔(秒)""" + + memory_forget_time: int = 24 + """记忆遗忘时间(小时)""" + + memory_forget_percentage: float = 0.01 + """记忆遗忘比例""" + + consolidate_memory_interval: int = 1000 + """记忆整合间隔(秒)""" + + consolidation_similarity_threshold: float = 0.7 + """整合相似度阈值""" + + consolidate_memory_percentage: float = 0.01 + """整合检查节点比例""" + + memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]) + """不允许记忆的词列表""" + + +@dataclass +class MoodConfig(ConfigBase): + """情绪配置类""" + + mood_update_interval: int = 1 + """情绪更新间隔(秒)""" + + mood_decay_rate: float = 0.95 + """情绪衰减率""" + + mood_intensity_factor: float = 0.7 + """情绪强度因子""" + + +@dataclass +class KeywordRuleConfig(ConfigBase): + """关键词规则配置类""" + + enable: bool = True + """是否启用关键词规则""" + + keywords: list[str] = field(default_factory=lambda: []) + """关键词列表""" + + regex: list[str] = field(default_factory=lambda: []) + """正则表达式列表""" + + reaction: str = "" + """关键词触发的反应""" + + +@dataclass +class KeywordReactionConfig(ConfigBase): + """关键词配置类""" + + enable: bool = True + """是否启用关键词反应""" + + rules: list[KeywordRuleConfig] = field(default_factory=lambda: []) + """关键词反应规则列表""" + + +@dataclass +class ChineseTypoConfig(ConfigBase): + """中文错别字配置类""" + + enable: bool = True + """是否启用中文错别字生成器""" + + error_rate: float = 0.01 + """单字替换概率""" + + min_freq: int = 9 + """最小字频阈值""" + + tone_error_rate: float = 0.1 + """声调错误概率""" + + word_replace_rate: float = 0.006 + """整词替换概率""" + + +@dataclass +class ResponseSplitterConfig(ConfigBase): + """回复分割器配置类""" + + enable: bool = True + """是否启用回复分割器""" + + max_length: int = 256 + """回复允许的最大长度""" + + max_sentence_num: int = 3 + """回复允许的最大句子数""" + + enable_kaomoji_protection: bool = False + """是否启用颜文字保护""" + + +@dataclass +class TelemetryConfig(ConfigBase): + """遥测配置类""" + + enable: bool = True + """是否启用遥测""" + + +@dataclass +class ExperimentalConfig(ConfigBase): + """实验功能配置类""" + + enable_friend_chat: bool = False + """是否启用好友聊天""" + + talk_allowed_private: set[str] = field(default_factory=lambda: set()) + """允许聊天的私聊列表""" + + pfc_chatting: bool = False + """是否启用PFC""" + + +@dataclass +class ModelConfig(ConfigBase): + """模型配置类""" + + model_max_output_length: int = 800 # 最大回复长度 + + reasoning: dict[str, Any] = field(default_factory=lambda: {}) + """推理模型配置""" + + normal: dict[str, Any] = field(default_factory=lambda: {}) + """普通模型配置""" + + topic_judge: dict[str, Any] = field(default_factory=lambda: {}) + """主题判断模型配置""" + + summary: dict[str, Any] = field(default_factory=lambda: {}) + """摘要模型配置""" + + vlm: dict[str, Any] = field(default_factory=lambda: {}) + """视觉语言模型配置""" + + heartflow: dict[str, Any] = field(default_factory=lambda: {}) + """心流模型配置""" + + observation: dict[str, Any] = field(default_factory=lambda: {}) + """观察模型配置""" + + sub_heartflow: dict[str, Any] = field(default_factory=lambda: {}) + """子心流模型配置""" + + plan: dict[str, Any] = field(default_factory=lambda: {}) + """计划模型配置""" + + embedding: dict[str, Any] = field(default_factory=lambda: {}) + """嵌入模型配置""" + + pfc_action_planner: dict[str, Any] = field(default_factory=lambda: {}) + """PFC动作规划模型配置""" + + pfc_chat: dict[str, Any] = field(default_factory=lambda: {}) + """PFC聊天模型配置""" + + pfc_reply_checker: dict[str, Any] = field(default_factory=lambda: {}) + """PFC回复检查模型配置""" + + tool_use: dict[str, Any] = field(default_factory=lambda: {}) + """工具使用模型配置""" diff --git a/src/experimental/PFC/action_planner.py b/src/experimental/PFC/action_planner.py index b4182c9aa..c0bff5887 100644 --- a/src/experimental/PFC/action_planner.py +++ b/src/experimental/PFC/action_planner.py @@ -114,7 +114,7 @@ class ActionPlanner: request_type="action_planning", ) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) # self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量 @@ -140,7 +140,7 @@ class ActionPlanner: # (这部分逻辑不变) time_since_last_bot_message_info = "" try: - bot_id = str(global_config.BOT_QQ) + bot_id = str(global_config.bot.qq_account) if hasattr(observation_info, "chat_history") and observation_info.chat_history: for i in range(len(observation_info.chat_history) - 1, -1, -1): msg = observation_info.chat_history[i] diff --git a/src/experimental/PFC/chat_observer.py b/src/experimental/PFC/chat_observer.py index 704eeb330..55914d800 100644 --- a/src/experimental/PFC/chat_observer.py +++ b/src/experimental/PFC/chat_observer.py @@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import ( create_new_message_notification, create_cold_chat_notification, ) -from src.experimental.PFC.message_storage import MongoDBMessageStorage +from src.experimental.PFC.message_storage import PeeweeMessageStorage from rich.traceback import install install(extra_lines=3) @@ -53,7 +53,7 @@ class ChatObserver: self.stream_id = stream_id self.private_name = private_name - self.message_storage = MongoDBMessageStorage() + self.message_storage = PeeweeMessageStorage() # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 @@ -323,7 +323,7 @@ class ChatObserver: for msg in messages: try: user_info = UserInfo.from_dict(msg.get("user_info", {})) - if user_info.user_id == global_config.BOT_QQ: + if user_info.user_id == global_config.bot.qq_account: self.update_bot_speak_time(msg["time"]) else: self.update_user_speak_time(msg["time"]) diff --git a/src/experimental/PFC/message_sender.py b/src/experimental/PFC/message_sender.py index 181bf171b..4b193a41d 100644 --- a/src/experimental/PFC/message_sender.py +++ b/src/experimental/PFC/message_sender.py @@ -42,8 +42,8 @@ class DirectMessageSender: # 获取麦麦的信息 bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=chat_stream.platform, ) diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index cd6a01e34..e2e1dd052 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any -from src.common.database import db + +# from src.common.database.database import db # Peewee db 导入 +from src.common.database.database_model import Messages # Peewee Messages 模型导入 +from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典 class MessageStorage(ABC): @@ -47,28 +50,35 @@ class MessageStorage(ABC): pass -class MongoDBMessageStorage(MessageStorage): - """MongoDB消息存储实现""" +class PeeweeMessageStorage(MessageStorage): + """Peewee消息存储实现""" async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$gt": message_time}} - # print(f"storage_check_message: {message_time}") + query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time > message_time)) + .order_by(Messages.time.asc()) + ) - return list(db.messages.find(query).sort("time", 1)) + # print(f"storage_check_message: {message_time}") + messages_models = list(query) + return [model_to_dict(msg) for msg in messages_models] async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$lt": time_point}} - - messages = list(db.messages.find(query).sort("time", -1).limit(limit)) + query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time < time_point)) + .order_by(Messages.time.desc()) + .limit(limit) + ) + messages_models = list(query) # 将消息按时间正序排列 - messages.reverse() - return messages + messages_models.reverse() + return [model_to_dict(msg) for msg in messages_models] async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - query = {"chat_id": chat_id, "time": {"$gt": after_time}} - - return db.messages.find_one(query) is not None + return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists() # # 创建一个内存消息存储实现,用于测试 diff --git a/src/experimental/PFC/pfc.py b/src/experimental/PFC/pfc.py index 84fb9f8dc..686d4af49 100644 --- a/src/experimental/PFC/pfc.py +++ b/src/experimental/PFC/pfc.py @@ -42,13 +42,14 @@ class GoalAnalyzer: """对话目标分析器""" def __init__(self, stream_id: str, private_name: str): + # TODO: API-Adapter修改标记 self.llm = LLMRequest( - model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal" + model=global_config.model.normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal" ) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) - self.name = global_config.BOT_NICKNAME - self.nick_name = global_config.BOT_ALIAS_NAMES + self.name = global_config.bot.nickname + self.nick_name = global_config.bot.alias_names self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) diff --git a/src/experimental/PFC/pfc_KnowledgeFetcher.py b/src/experimental/PFC/pfc_KnowledgeFetcher.py index 8ebc307e2..4c1d8c759 100644 --- a/src/experimental/PFC/pfc_KnowledgeFetcher.py +++ b/src/experimental/PFC/pfc_KnowledgeFetcher.py @@ -14,9 +14,10 @@ class KnowledgeFetcher: """知识调取器""" def __init__(self, private_name: str): + # TODO: API-Adapter修改标记 self.llm = LLMRequest( - model=global_config.llm_normal, - temperature=global_config.llm_normal["temp"], + model=global_config.model.normal, + temperature=global_config.model.normal["temp"], max_tokens=1000, request_type="knowledge_fetch", ) diff --git a/src/experimental/PFC/reply_checker.py b/src/experimental/PFC/reply_checker.py index a76e8a0da..5bca9d601 100644 --- a/src/experimental/PFC/reply_checker.py +++ b/src/experimental/PFC/reply_checker.py @@ -16,7 +16,7 @@ class ReplyChecker: self.llm = LLMRequest( model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check" ) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.max_retries = 3 # 最大重试次数 @@ -43,7 +43,7 @@ class ReplyChecker: bot_messages = [] for msg in reversed(chat_history): user_info = UserInfo.from_dict(msg.get("user_info", {})) - if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串 + if str(user_info.user_id) == str(global_config.bot.qq_account): # 确保比较的是字符串 bot_messages.append(msg.get("processed_plain_text", "")) if len(bot_messages) >= 2: # 只和最近的两条比较 break diff --git a/src/experimental/PFC/reply_generator.py b/src/experimental/PFC/reply_generator.py index 6dcda69af..bac8a769f 100644 --- a/src/experimental/PFC/reply_generator.py +++ b/src/experimental/PFC/reply_generator.py @@ -93,7 +93,7 @@ class ReplyGenerator: request_type="reply_generation", ) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.reply_checker = ReplyChecker(stream_id, private_name) diff --git a/src/experimental/PFC/waiter.py b/src/experimental/PFC/waiter.py index af5cf7ad0..452446589 100644 --- a/src/experimental/PFC/waiter.py +++ b/src/experimental/PFC/waiter.py @@ -19,7 +19,7 @@ class Waiter: def __init__(self, stream_id: str, private_name: str): self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name # self.wait_accumulated_time = 0 # 不再需要累加计时 diff --git a/src/experimental/only_message_process.py b/src/experimental/only_message_process.py index 3d1432703..62f73c700 100644 --- a/src/experimental/only_message_process.py +++ b/src/experimental/only_message_process.py @@ -16,7 +16,7 @@ class MessageProcessor: @staticmethod def _check_ban_words(text: str, chat, userinfo) -> bool: """检查消息中是否包含过滤词""" - for word in global_config.ban_words: + for word in global_config.chat.ban_words: if word in text: logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" @@ -28,7 +28,7 @@ class MessageProcessor: @staticmethod def _check_ban_regex(text: str, chat, userinfo) -> bool: """检查消息是否匹配过滤正则表达式""" - for pattern in global_config.ban_msgs_regex: + for pattern in global_config.chat.ban_msgs_regex: if pattern.search(text): logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" diff --git a/src/main.py b/src/main.py index 34b7eda3d..4f8af28ef 100644 --- a/src/main.py +++ b/src/main.py @@ -40,7 +40,7 @@ class MainSystem: async def initialize(self): """初始化系统组件""" - logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") + logger.debug(f"正在唤醒{global_config.bot.nickname}......") # 其他初始化任务 await asyncio.gather(self._init_components()) @@ -84,7 +84,7 @@ class MainSystem: asyncio.create_task(chat_manager._auto_save_task()) # 使用HippocampusManager初始化海马体 - self.hippocampus_manager.initialize(global_config=global_config) + self.hippocampus_manager.initialize() # await asyncio.sleep(0.5) #防止logger输出飞了 # 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中 @@ -92,15 +92,15 @@ class MainSystem: # 初始化个体特征 self.individuality.initialize( - bot_nickname=global_config.BOT_NICKNAME, - personality_core=global_config.personality_core, - personality_sides=global_config.personality_sides, - identity_detail=global_config.identity_detail, - height=global_config.height, - weight=global_config.weight, - age=global_config.age, - gender=global_config.gender, - appearance=global_config.appearance, + bot_nickname=global_config.bot.nickname, + personality_core=global_config.personality.personality_core, + personality_sides=global_config.personality.personality_sides, + identity_detail=global_config.identity.identity_detail, + height=global_config.identity.height, + weight=global_config.identity.weight, + age=global_config.identity.age, + gender=global_config.identity.gender, + appearance=global_config.identity.appearance, ) logger.success("个体特征初始化成功") @@ -141,7 +141,7 @@ class MainSystem: async def build_memory_task(): """记忆构建任务""" while True: - await asyncio.sleep(global_config.build_memory_interval) + await asyncio.sleep(global_config.memory.memory_build_interval) logger.info("正在进行记忆构建") await HippocampusManager.get_instance().build_memory() @@ -149,16 +149,18 @@ class MainSystem: async def forget_memory_task(): """记忆遗忘任务""" while True: - await asyncio.sleep(global_config.forget_memory_interval) + await asyncio.sleep(global_config.memory.forget_memory_interval) print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") - await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage) + await HippocampusManager.get_instance().forget_memory( + percentage=global_config.memory.memory_forget_percentage + ) print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") @staticmethod async def consolidate_memory_task(): """记忆整合任务""" while True: - await asyncio.sleep(global_config.consolidate_memory_interval) + await asyncio.sleep(global_config.memory.consolidate_memory_interval) print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...") await HippocampusManager.get_instance().consolidate_memory() print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") diff --git a/src/manager/mood_manager.py b/src/manager/mood_manager.py index 42677d4e1..c83fbeb7c 100644 --- a/src/manager/mood_manager.py +++ b/src/manager/mood_manager.py @@ -34,14 +34,14 @@ class MoodUpdateTask(AsyncTask): def __init__(self): super().__init__( task_name="Mood Update Task", - wait_before_start=global_config.mood_update_interval, - run_interval=global_config.mood_update_interval, + wait_before_start=global_config.mood.mood_update_interval, + run_interval=global_config.mood.mood_update_interval, ) # 从配置文件获取衰减率 - self.decay_rate_valence: float = 1 - global_config.mood_decay_rate + self.decay_rate_valence: float = 1 - global_config.mood.mood_decay_rate """愉悦度衰减率""" - self.decay_rate_arousal: float = 1 - global_config.mood_decay_rate + self.decay_rate_arousal: float = 1 - global_config.mood.mood_decay_rate """唤醒度衰减率""" self.last_update = time.time() diff --git a/src/tools/not_used/change_mood.py b/src/tools/not_used/change_mood.py index c34bebb93..69fc3bb78 100644 --- a/src/tools/not_used/change_mood.py +++ b/src/tools/not_used/change_mood.py @@ -44,7 +44,7 @@ class ChangeMoodTool(BaseTool): _ori_response = ",".join(response_set) # _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text) emotion = "平静" - mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor) + mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor) return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"} except Exception as e: logger.error(f"心情改变工具执行失败: {str(e)}") diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py index 65acd55c0..fd37f11e7 100644 --- a/src/tools/tool_can_use/get_knowledge.py +++ b/src/tools/tool_can_use/get_knowledge.py @@ -1,8 +1,10 @@ from src.tools.tool_can_use.base_tool import BaseTool from src.chat.utils.utils import get_embedding -from src.common.database import db +from src.common.database.database_model import Knowledges # Updated import from src.common.logger_manager import get_logger -from typing import Any, Union +from typing import Any, Union, List # Added List +import json # Added for parsing embedding +import math # Added for cosine similarity logger = get_logger("get_knowledge_tool") @@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool): Returns: dict: 工具执行结果 """ + query = "" # Initialize query to ensure it's defined in except block try: query = function_args.get("query") threshold = function_args.get("threshold", 0.4) @@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool): logger.error(f"知识库搜索工具执行失败: {str(e)}") return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} + @staticmethod + def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: + """计算两个向量之间的余弦相似度""" + dot_product = sum(p * q for p, q in zip(vec1, vec2)) + magnitude1 = math.sqrt(sum(p * p for p in vec1)) + magnitude2 = math.sqrt(sum(q * q for q in vec2)) + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + return dot_product / (magnitude1 * magnitude2) + @staticmethod def get_info_from_db( - query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False ) -> Union[str, list]: """从数据库中获取相关信息 @@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool): if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] + similar_items = [] + try: + all_knowledges = Knowledges.select() + for item in all_knowledges: + try: + item_embedding_str = item.embedding + if not item_embedding_str: + logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") + continue + item_embedding = json.loads(item_embedding_str) + if not isinstance(item_embedding, list) or not all( + isinstance(x, (int, float)) for x in item_embedding + ): + logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") + continue + except json.JSONDecodeError: + logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") + continue + except AttributeError: + logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") + continue - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) + + if similarity >= threshold: + similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) + + # 按相似度降序排序 + similar_items.sort(key=lambda x: x["similarity"], reverse=True) + + # 应用限制 + results = similar_items[:limit] + logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") + + except Exception as e: + logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") + return "" if not return_raw else [] if not results: return "" if not return_raw else [] if return_raw: - return results + # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 + # 这里返回包含内容和相似度的字典列表 + return [{"content": r["content"], "similarity": r["similarity"]} for r in results] else: # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results) diff --git a/src/tools/tool_use.py b/src/tools/tool_use.py index c55170b88..ff36085d6 100644 --- a/src/tools/tool_use.py +++ b/src/tools/tool_use.py @@ -15,7 +15,7 @@ logger = get_logger("tool_use") class ToolUser: def __init__(self): self.llm_model_tool = LLMRequest( - model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use" + model=global_config.model.tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use" ) @staticmethod @@ -37,7 +37,7 @@ class ToolUser: # print(f"intol111111111111111111111111111111111222222222222mid_memory_info:{mid_memory_info}") # 这些信息应该从调用者传入,而不是从self获取 - bot_name = global_config.BOT_NICKNAME + bot_name = global_config.bot.nickname prompt = "" prompt += mid_memory_info prompt += "你正在思考如何回复群里的消息。\n" diff --git a/template/bot_config_meta.toml b/template/bot_config_meta.toml deleted file mode 100644 index c3541baad..000000000 --- a/template/bot_config_meta.toml +++ /dev/null @@ -1,104 +0,0 @@ -[inner.version] -describe = "版本号" -important = true -can_edit = false - -[bot.qq] -describe = "机器人的QQ号" -important = true -can_edit = true - -[bot.nickname] -describe = "机器人的昵称" -important = true -can_edit = true - -[bot.alias_names] -describe = "机器人的别名列表,该选项还在调试中,暂时未生效" -important = false -can_edit = true - -[groups.talk_allowed] -describe = "可以回复消息的群号码列表" -important = true -can_edit = true - -[groups.talk_frequency_down] -describe = "降低回复频率的群号码列表" -important = false -can_edit = true - -[groups.ban_user_id] -describe = "禁止回复和读取消息的QQ号列表" -important = false -can_edit = true - -[personality.personality_core] -describe = "用一句话或几句话描述人格的核心特点,建议20字以内" -important = true -can_edit = true - -[personality.personality_sides] -describe = "用一句话或几句话描述人格的一些细节,条数任意,不能为0,该选项还在调试中" -important = false -can_edit = true - -[identity.identity_detail] -describe = "身份特点列表,条数任意,不能为0,该选项还在调试中" -important = false -can_edit = true - -[identity.age] -describe = "年龄,单位岁" -important = false -can_edit = true - -[identity.gender] -describe = "性别" -important = false -can_edit = true - -[identity.appearance] -describe = "外貌特征描述,该选项还在调试中,暂时未生效" -important = false -can_edit = true - -[platforms.nonebot-qq] -describe = "nonebot-qq适配器提供的链接" -important = true -can_edit = true - -[chat.allow_focus_mode] -describe = "是否允许专注聊天状态" -important = false -can_edit = true - -[chat.base_normal_chat_num] -describe = "最多允许多少个群进行普通聊天" -important = false -can_edit = true - -[chat.base_focused_chat_num] -describe = "最多允许多少个群进行专注聊天" -important = false -can_edit = true - -[chat.observation_context_size] -describe = "观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖" -important = false -can_edit = true - -[chat.message_buffer] -describe = "启用消息缓冲器,启用此项以解决消息的拆分问题,但会使麦麦的回复延迟" -important = false -can_edit = true - -[chat.ban_words] -describe = "需要过滤的消息列表" -important = false -can_edit = true - -[chat.ban_msgs_regex] -describe = "需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤(支持CQ码)" -important = false -can_edit = true \ No newline at end of file diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 931afe2ed..64e51da77 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,18 +1,10 @@ [inner] -version = "1.7.0" +version = "2.0.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 -#如果新增项目,请在BotConfig类下新增相应的变量 -#1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{ -#"func":memory, -#"support":">=0.0.0", #新的版本号 -#"necessary":False #是否必须 -#} -#2.如果你修改的是[]下的项目,例如你新增了[memory]下的 memory_ban_words ,那么请在config.py的 load_config函数中的 memory函数下新增版本判断: - # if config.INNER_VERSION in SpecifierSet(">=0.0.2"): - # config.memory_ban_words = set(memory_config.get("memory_ban_words", [])) - +#如果新增项目,请阅读src/config/official_configs.py中的说明 +# # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下: # 主版本号:当你做了不兼容的 API 修改, # 次版本号:当你做了向下兼容的功能性新增, @@ -21,11 +13,11 @@ version = "1.7.0" #----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- [bot] -qq = 1145141919810 +qq_account = 1145141919810 nickname = "麦麦" alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 -[groups] +[chat_target] talk_allowed = [ 123, 123, @@ -53,10 +45,13 @@ identity_detail = [ "身份特点", "身份特点", ]# 条数任意,不能为0, 该选项还在调试中 + #外貌特征 -age = 20 # 年龄 单位岁 -gender = "男" # 性别 -appearance = "用几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效 +age = 18 # 年龄 单位岁 +gender = "女" # 性别 +height = "170" # 身高(单位cm) +weight = "50" # 体重(单位kg) +appearance = "用一句或几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效 [platforms] # 必填项目,填写每个平台适配器提供的链接 qq="http://127.0.0.1:18002/api/message" @@ -85,11 +80,10 @@ ban_msgs_regex = [ [normal_chat] #普通聊天 #一般回复参数 -model_reasoning_probability = 0.7 # 麦麦回答时选择推理模型 模型的概率 -model_normal_probability = 0.3 # 麦麦回答时选择一般模型 模型的概率 +reasoning_model_probability = 0.3 # 麦麦回答时选择推理模型的概率(与之相对的,普通模型的概率为1 - reasoning_model_probability) emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发 -thinking_timeout = 100 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢) +thinking_timeout = 120 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢) willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现) response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1 @@ -100,8 +94,8 @@ mentioned_bot_inevitable_reply = false # 提及 bot 必然回复 at_bot_inevitable_reply = false # @bot 必然回复 [focus_chat] #专注聊天 -reply_trigger_threshold = 3.6 # 专注聊天触发阈值,越低越容易进入专注聊天 -default_decay_rate_per_second = 0.95 # 默认衰减率,越大衰减越快,越高越难进入专注聊天 +reply_trigger_threshold = 3.0 # 专注聊天触发阈值,越低越容易进入专注聊天 +default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,越高越难进入专注聊天 consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天 # 以下选项暂时无效 @@ -110,20 +104,20 @@ compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下 [emoji] -max_emoji_num = 40 # 表情包最大数量 -max_reach_deletion = true # 开启则在达到最大数量时删除表情包,关闭则达到最大数量时不删除,只是不会继续收集表情包 -check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟) +max_reg_num = 40 # 表情包最大注册数量 +do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包 +check_interval = 120 # 检查表情包(注册,破损,删除)的时间间隔(分钟) save_pic = false # 是否保存图片 -save_emoji = false # 是否保存表情包 +cache_emoji = true # 是否缓存表情包 steal_emoji = true # 是否偷取表情包,让麦麦可以发送她保存的这些表情包 -enable_check = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存 -check_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存 +content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存 +filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存 [memory] -build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多 -build_memory_distribution = [6.0,3.0,0.6,32.0,12.0,0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 -build_memory_sample_num = 8 # 采样数量,数值越高记忆采样次数越多 -build_memory_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富 +memory_build_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多 +memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 +memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多 +memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习 @@ -135,49 +129,45 @@ consolidation_similarity_threshold = 0.7 # 相似度阈值 consolidation_check_percentage = 0.01 # 检查节点比例 #不希望记忆的词,已经记忆的不会受到影响 -memory_ban_words = [ - # "403","张三" -] +memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ] [mood] mood_update_interval = 1.0 # 情绪更新间隔 单位秒 mood_decay_rate = 0.95 # 情绪衰减率 mood_intensity_factor = 1.0 # 情绪强度因子 -[keywords_reaction] # 针对某个关键词作出反应 +[keyword_reaction] # 针对某个关键词作出反应 enable = true # 关键词反应功能的总开关 -[[keywords_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可 +[[keyword_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可 enable = true # 是否启用此条(为了人类在未来AI战争能更好地识别AI(bushi),默认开启) keywords = ["人机", "bot", "机器", "入机", "robot", "机器人","ai","AI"] # 会触发反应的关键词 reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" # 触发之后添加的提示词 -[[keywords_reaction.rules]] # 就像这样复制 +[[keyword_reaction.rules]] # 就像这样复制 enable = false # 仅作示例,不会触发 keywords = ["测试关键词回复","test",""] reaction = "回答“测试成功”" # 修复错误的引号 -[[keywords_reaction.rules]] # 使用正则表达式匹配句式 +[[keyword_reaction.rules]] # 使用正则表达式匹配句式 enable = false # 仅作示例,不会触发 regex = ["^(?P\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n,反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写 reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)" [chinese_typo] enable = true # 是否启用中文错别字生成器 -error_rate=0.001 # 单字替换概率 +error_rate=0.01 # 单字替换概率 min_freq=9 # 最小字频阈值 tone_error_rate=0.1 # 声调错误概率 word_replace_rate=0.006 # 整词替换概率 [response_splitter] -enable_response_splitter = true # 是否启用回复分割器 -response_max_length = 256 # 回复允许的最大长度 -response_max_sentence_num = 4 # 回复允许的最大句子数 +enable = true # 是否启用回复分割器 +max_length = 256 # 回复允许的最大长度 +max_sentence_num = 4 # 回复允许的最大句子数 enable_kaomoji_protection = false # 是否启用颜文字保护 -model_max_output_length = 256 # 模型单次返回的最大token数 - -[remote] #发送统计信息,主要是看全球有多少只麦麦 +[telemetry] #发送统计信息,主要是看全球有多少只麦麦 enable = true [experimental] #实验性功能 @@ -194,14 +184,17 @@ pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与 # stream = : 用于指定模型是否是使用流式输出 # 如果不指定,则该项是 False +[model] +model_max_output_length = 800 # 模型单次返回的最大token数 + #这个模型必须是推理模型 -[model.llm_reasoning] # 一般聊天模式的推理回复模型 +[model.reasoning] # 一般聊天模式的推理回复模型 name = "Pro/deepseek-ai/DeepSeek-R1" provider = "SILICONFLOW" pri_in = 1.0 #模型的输入价格(非必填,可以记录消耗) pri_out = 4.0 #模型的输出价格(非必填,可以记录消耗) -[model.llm_normal] #V3 回复模型 专注和一般聊天模式共用的回复模型 +[model.normal] #V3 回复模型 专注和一般聊天模式共用的回复模型 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 #模型的输入价格(非必填,可以记录消耗) @@ -209,13 +202,13 @@ pri_out = 8 #模型的输出价格(非必填,可以记录消耗) #默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 temp = 0.2 #模型的温度,新V3建议0.1-0.3 -[model.llm_topic_judge] #主题判断模型:建议使用qwen2.5 7b +[model.topic_judge] #主题判断模型:建议使用qwen2.5 7b name = "Pro/Qwen/Qwen2.5-7B-Instruct" provider = "SILICONFLOW" pri_in = 0.35 pri_out = 0.35 -[model.llm_summary] #概括模型,建议使用qwen2.5 32b 及以上 +[model.summary] #概括模型,建议使用qwen2.5 32b 及以上 name = "Qwen/Qwen2.5-32B-Instruct" provider = "SILICONFLOW" pri_in = 1.26 @@ -227,27 +220,27 @@ provider = "SILICONFLOW" pri_in = 0.35 pri_out = 0.35 -[model.llm_heartflow] # 用于控制麦麦是否参与聊天的模型 +[model.heartflow] # 用于控制麦麦是否参与聊天的模型 name = "Qwen/Qwen2.5-32B-Instruct" provider = "SILICONFLOW" pri_in = 1.26 pri_out = 1.26 -[model.llm_observation] #观察模型,压缩聊天内容,建议用免费的 +[model.observation] #观察模型,压缩聊天内容,建议用免费的 # name = "Pro/Qwen/Qwen2.5-7B-Instruct" name = "Qwen/Qwen2.5-7B-Instruct" provider = "SILICONFLOW" pri_in = 0 pri_out = 0 -[model.llm_sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 +[model.sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 pri_out = 8 temp = 0.3 #模型的温度,新V3建议0.1-0.3 -[model.llm_plan] #决策:认真水群时,负责决定麦麦该做什么 +[model.plan] #决策:认真水群时,负责决定麦麦该做什么 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 @@ -265,7 +258,7 @@ pri_out = 0 #私聊PFC:需要开启PFC功能,默认三个模型均为硅基流动v3,如果需要支持多人同时私聊或频繁调用,建议把其中的一个或两个换成官方v3或其它模型,以免撞到429 #PFC决策模型 -[model.llm_PFC_action_planner] +[model.pfc_action_planner] name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" temp = 0.3 @@ -273,7 +266,7 @@ pri_in = 2 pri_out = 8 #PFC聊天模型 -[model.llm_PFC_chat] +[model.pfc_chat] name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" temp = 0.3 @@ -281,7 +274,7 @@ pri_in = 2 pri_out = 8 #PFC检查模型 -[model.llm_PFC_reply_checker] +[model.pfc_reply_checker] name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 @@ -294,7 +287,7 @@ pri_out = 8 #以下模型暂时没有使用!! #以下模型暂时没有使用!! -[model.llm_tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b +[model.tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b name = "Qwen/Qwen2.5-32B-Instruct" provider = "SILICONFLOW" pri_in = 1.26 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..1a1239601 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,7 @@ +from src.config.config import global_config + + +class TestConfig: + def test_load(self): + config = global_config + print(config)