From b84cc9240aab39e9ff5862f3ac8e773dd6b7b703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 22:53:21 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E4=BA=A4=E4=BA=92=E4=BB=A5=E4=BD=BF=E7=94=A8=20Peewee=20ORM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新数据库连接和模型定义,以便使用 Peewee for SQLite。 - 在消息存储和检索功能中,用 Peewee ORM 查询替换 MongoDB 查询。 - 为 Messages、ThinkingLog 和 OnlineTime 引入了新的模型,以方便结构化数据存储。 - 增强了数据库操作的错误处理和日志记录。 - 删除了过时的 MongoDB 集合管理代码。 - 通过利用 Peewee 内置的查询和数据操作方法来提升性能。 --- src/chat/emoji_system/emoji_manager.py | 188 ++++--- .../focus_chat/heartflow_prompt_builder.py | 157 +++--- src/chat/memory_system/Hippocampus.py | 2 +- src/chat/message_receive/chat_stream.py | 126 ++++- src/chat/models/utils_model.py | 41 +- src/chat/person_info/person_info.py | 531 ++++++++++-------- src/chat/utils/info_catcher.py | 160 +++--- src/chat/utils/statistic.py | 66 +-- src/chat/utils/utils_image.py | 175 +++--- src/common/database/database.py | 14 +- src/common/database/database_model.py | 54 +- src/common/message_repository.py | 92 ++- src/experimental/PFC/chat_observer.py | 4 +- src/experimental/PFC/message_storage.py | 37 +- src/tools/tool_can_use/get_knowledge.py | 110 ++-- 15 files changed, 999 insertions(+), 758 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 076dbf5a4..68fa5de44 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -10,7 +10,10 @@ from PIL import Image import io import re -from ...common.database.database import db +# from gradio_client import file + +from ...common.database.database_model import Emoji +from ...common.database.database import db as peewee_db from ...config.config import global_config from ..utils.utils_image import image_path_to_base64, image_manager from ..models.utils_model import LLMRequest @@ -143,37 +146,28 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - emoji_record = { - "filename": self.filename, - "path": self.path, # 存储目录路径 - "full_path": self.full_path, # 存储完整文件路径 - "embedding": self.embedding, - "description": self.description, - "emotion": self.emotion, - "hash": self.hash, - "format": self.format, - "timestamp": int(self.register_time), - "usage_count": self.usage_count, - "last_used_time": self.last_used_time, - } - - # 使用upsert确保记录存在或被更新 - db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) + emotion_str = ",".join(self.emotion) if self.emotion else "" + Emoji.create(hash=self.hash, + full_path=self.full_path, + format=self.format, + description=self.description, + emotion=emotion_str, # Store as comma-separated string + query_count=0, # Default value + is_registered=True, + is_banned=False, # Default value + record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time + register_time=self.register_time, + usage_count=self.usage_count, + last_used_time=self.last_used_time, + ) + logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") return True except Exception as db_error: logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") - # 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误 - # 可以考虑在这里尝试删除已移动的文件,避免残留 - try: - if os.path.exists(self.full_path): # full_path 此时是目标路径 - os.remove(self.full_path) - logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}") - except Exception as remove_error: - logger.error(f"[错误] 回滚删除文件失败: {remove_error}") return False except Exception as e: @@ -201,10 +195,14 @@ class MaiEmoji: # 文件删除失败,但仍然尝试删除数据库记录 # 2. 删除数据库记录 - result = db.emoji.delete_one({"hash": self.hash}) - deleted_in_db = result.deleted_count > 0 + try: + will_delete_emoji = Emoji.get(Emoji.hash == self.hash) + result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. + except Emoji.DoesNotExist: + logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") + result = 0 # Indicate no DB record was deleted - if deleted_in_db: + if result > 0: logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") # 3. 标记对象已被删除 self.is_deleted = True @@ -246,44 +244,43 @@ def _emoji_objects_to_readable_list(emoji_objects): def _to_emoji_objects(data): emoji_objects = [] load_errors = 0 - emoji_data_list = list(data) + # data is now an iterable of Peewee Emoji model instances + emoji_data_list = list(data) - for emoji_data in emoji_data_list: - full_path = emoji_data.get("full_path") + for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance + full_path = emoji_data.full_path if not full_path: - logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") + logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}") load_errors += 1 - continue # 跳过缺少 full_path 的记录 + continue try: - # 使用 full_path 初始化 MaiEmoji 对象 emoji = MaiEmoji(full_path=full_path) - # 设置从数据库加载的属性 - emoji.hash = emoji_data.get("hash", "") - # 如果 hash 为空,也跳过?取决于业务逻辑 + emoji.hash = emoji_data.hash if not emoji.hash: logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") load_errors += 1 continue - emoji.description = emoji_data.get("description", "") - emoji.emotion = emoji_data.get("emotion", []) - emoji.usage_count = emoji_data.get("usage_count", 0) - # 优先使用 last_used_time,否则用 timestamp,最后用当前时间 - last_used = emoji_data.get("last_used_time") - timestamp = emoji_data.get("timestamp") - emoji.last_used_time = ( - last_used if last_used is not None else (timestamp if timestamp is not None else time.time()) - ) - emoji.register_time = timestamp if timestamp is not None else time.time() - emoji.format = emoji_data.get("format", "") # 加载格式 + emoji.description = emoji_data.description + # Deserialize emotion string from DB to list + emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else [] + emoji.usage_count = emoji_data.usage_count + + db_last_used_time = emoji_data.last_used_time + db_register_time = emoji_data.register_time - # 不需要再手动设置 path 和 filename,__init__ 会自动处理 + # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time + emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time + # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) + emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time + + emoji.format = emoji_data.format emoji_objects.append(emoji) - except ValueError as ve: # 捕获 __init__ 可能的错误 + except ValueError as ve: logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") load_errors += 1 except Exception as e: @@ -385,12 +382,13 @@ class EmojiManager: """初始化数据库连接和表情目录""" if not self._initialized: try: - self._ensure_emoji_collection() + # Ensure Peewee database connection is up and tables are created + if not peewee_db.is_closed(): + peewee_db.connect(reuse_if_open=True) + Emoji.create_table(safe=True) # Ensures table exists + _ensure_emoji_dir() self._initialized = True - # 更新表情包数量 - # 启动时执行一次完整性检查 - # await self.check_emoji_file_integrity() except Exception as e: logger.exception(f"初始化表情管理器失败: {e}") @@ -401,33 +399,15 @@ class EmojiManager: if not self._initialized: raise RuntimeError("EmojiManager not initialized") - @staticmethod - def _ensure_emoji_collection(): - """确保emoji集合存在并创建索引 - - 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 - - 索引的作用是加快数据库查询速度: - - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 - - tags字段的普通索引: 加快按标签搜索表情包的速度 - - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 - - 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 - """ - if "emoji" not in db.list_collection_names(): - db.create_collection("emoji") - db.emoji.create_index([("embedding", "2dsphere")]) - db.emoji.create_index([("filename", 1)], unique=True) - def record_usage(self, emoji_hash: str): """记录表情使用次数""" try: - db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) - for emoji in self.emoji_objects: - if emoji.hash == emoji_hash: - emoji.usage_count += 1 - break - + emoji_update = Emoji.get(Emoji.hash == emoji_hash) + emoji_update.usage_count += 1 + emoji_update.last_used_time = time.time() # Update last used time + emoji_update.save() # Persist changes to DB + except Emoji.DoesNotExist: + logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -657,9 +637,10 @@ class EmojiManager: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: self._ensure_db() - logger.info("[数据库] 开始加载所有表情包记录...") + logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...") - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) + emoji_peewee_instances = Emoji.select() + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) # 更新内存中的列表和数量 self.emoji_objects = emoji_objects @@ -686,15 +667,16 @@ class EmojiManager: try: self._ensure_db() - query = {} if emoji_hash: - query = {"hash": emoji_hash} + query = Emoji.select().where(Emoji.hash == emoji_hash) else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) - - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) + query = Emoji.select() + + emoji_peewee_instances = query + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) if load_errors > 0: logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") @@ -908,6 +890,44 @@ class EmojiManager: logger.error(f"获取表情包描述失败: {str(e)}") return "", [] + # async def register_emoji_by_filename(self, filename: str) -> bool: + # if global_config.EMOJI_CHECK: + # prompt = f''' + # 这是一个表情包,请对这个表情包进行审核,标准如下: + # 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求 + # 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 + # 3. 不能是任何形式的截图,聊天记录或视频截图 + # 4. 不要出现5个以上文字 + # 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 + # ''' + # content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + # if content == "否": + # return "", [] + + # # 分析情感含义 + # emotion_prompt = f""" + # 请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字 + # 这是一个基于这个表情包的描述:'{description}' + # 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析 + # 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 + # """ + # emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7) + + # # 处理情感列表 + # emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] + + # # 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个 + # if len(emotions) > 5: + # emotions = random.sample(emotions, 3) + # elif len(emotions) > 2: + # emotions = random.sample(emotions, 2) + + # return f"[表情包:{description}]", emotions + + # except Exception as e: + # logger.error(f"获取表情包描述失败: {str(e)}") + # return "", [] + async def register_emoji_by_filename(self, filename: str) -> bool: """读取指定文件名的表情包图片,分析并注册到数据库 diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 1acef540e..141d850ab 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -7,7 +7,7 @@ from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time from typing import Union, Optional -from common.database.database import db +# from common.database.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager @@ -15,6 +15,9 @@ from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner # import traceback import random +import json +import math +from src.common.database.database_model import Knowledges logger = get_logger("prompt") @@ -69,7 +72,7 @@ def init_prompt(): 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1}, 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 +请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。 {moderation_prompt} 不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""", "reasoning_prompt_main", @@ -439,30 +442,6 @@ class PromptBuilder: logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") # 1. 先从LLM获取主题,类似于记忆系统的做法 topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") # 如果无法提取到主题,直接使用整个消息 if not topics: @@ -572,8 +551,6 @@ class PromptBuilder: for _i, result in enumerate(results, 1): _similarity = result["similarity"] content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" related_info += f"{content}\n" related_info += "\n" @@ -602,14 +579,14 @@ class PromptBuilder: return related_info else: logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索") - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") return related_info except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") try: - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug( f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" @@ -625,70 +602,70 @@ class PromptBuilder: ) -> Union[str, list]: if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + results_with_similarity = [] + try: + # Fetch all knowledge entries + # This might be inefficient for very large databases. + # Consider strategies like FAISS or other vector search libraries if performance becomes an issue. + all_knowledges = Knowledges.select() - if not results: + if not all_knowledges: + return "" if not return_raw else [] + + query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) + if query_embedding_magnitude == 0: # Avoid division by zero + return "" if not return_raw else [] + + for knowledge_item in all_knowledges: + try: + db_embedding_str = knowledge_item.embedding + db_embedding = json.loads(db_embedding_str) + + if len(db_embedding) != len(query_embedding): + logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.") + continue + + # Calculate Cosine Similarity + dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding)) + db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding)) + + if db_embedding_magnitude == 0: # Avoid division by zero + similarity = 0.0 + else: + similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude) + + if similarity >= threshold: + results_with_similarity.append({ + "content": knowledge_item.content, + "similarity": similarity + }) + except json.JSONDecodeError: + logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}") + except Exception as e: + logger.error(f"Error processing knowledge item: {e}") + + + # Sort by similarity in descending order + results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True) + + # Limit results + limited_results = results_with_similarity[:limit] + + logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}") + + if not limited_results: + return "" if not return_raw else [] + + if return_raw: + return limited_results + else: + return "\n".join(str(result["content"]) for result in limited_results) + + except Exception as e: + logger.error(f"Error querying Knowledges with Peewee: {e}") return "" if not return_raw else [] - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index e64475126..78616d824 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -10,7 +10,7 @@ import jieba import networkx as nx import numpy as np from collections import Counter -from ...common.database.database import db +from ...common.database.database import memory_db as db from ...chat.models.utils_model import LLMRequest from src.common.logger_manager import get_logger from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 7f41ac96b..723d6da47 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -6,6 +6,7 @@ from typing import Dict, Optional from ...common.database.database import db +from ...common.database.database_model import ChatStreams # 新增导入 from maim_message import GroupInfo, UserInfo from src.common.logger_manager import get_logger @@ -82,7 +83,13 @@ class ChatManager: def __init__(self): if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self._ensure_collection() + try: + db.connect(reuse_if_open=True) + # 确保 ChatStreams 表存在 + db.create_tables([ChatStreams], safe=True) + except Exception as e: + logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") + self._initialized = True # 在事件循环中启动初始化 # asyncio.create_task(self._initialize()) @@ -107,15 +114,6 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {str(e)}") - @staticmethod - def _ensure_collection(): - """确保数据库集合存在并创建索引""" - if "chat_streams" not in db.list_collection_names(): - db.create_collection("chat_streams") - # 创建索引 - db.chat_streams.create_index([("stream_id", 1)], unique=True) - db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]) - @staticmethod def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: """生成聊天流唯一ID""" @@ -151,16 +149,43 @@ class ChatManager: stream = self.streams[stream_id] # 更新用户信息和群组信息 stream.update_active_time() - stream = copy.deepcopy(stream) + stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream.user_info = user_info if group_info: stream.group_info = group_info return stream # 检查数据库中是否存在 - data = db.chat_streams.find_one({"stream_id": stream_id}) - if data: - stream = ChatStream.from_dict(data) + def _db_find_stream_sync(s_id: str): + return ChatStreams.get_or_none(ChatStreams.stream_id == s_id) + + model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) + + if model_instance: + # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息 + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + stream = ChatStream.from_dict(data_for_from_dict) # 更新用户信息和群组信息 stream.user_info = user_info if group_info: @@ -175,7 +200,7 @@ class ChatManager: group_info=group_info, ) except Exception as e: - logger.error(f"创建聊天流失败: {e}") + logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e # 保存到内存和数据库 @@ -205,15 +230,38 @@ class ChatManager: elif stream.user_info and stream.user_info.user_nickname: return f"{stream.user_info.user_nickname}的私聊" else: - # 如果没有群名或用户昵称,返回 None 或其他默认值 return None @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) - stream.saved = True + stream_data_dict = stream.to_dict() + + def _db_save_stream_sync(s_data_dict: dict): + user_info_d = s_data_dict.get("user_info") + group_info_d = s_data_dict.get("group_info") + + fields_to_save = { + "platform": s_data_dict["platform"], + "create_time": s_data_dict["create_time"], + "last_active_time": s_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + } + + ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() + + try: + await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) + stream.saved = True + except Exception as e: + logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" @@ -222,10 +270,44 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = db.chat_streams.find({}) - for data in all_streams: - stream = ChatStream.from_dict(data) - self.streams[stream.stream_id] = stream + + def _db_load_all_streams_sync(): + loaded_streams_data = [] + for model_instance in ChatStreams.select(): + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + loaded_streams_data.append(data_for_from_dict) + return loaded_streams_data + + try: + all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) + self.streams.clear() + for data in all_streams_data_list: + stream = ChatStream.from_dict(data) + stream.saved = True + self.streams[stream.stream_id] = stream + except Exception as e: + logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) # 创建全局单例 diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index 9ca4e56d0..986036e86 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -12,7 +12,8 @@ import base64 from PIL import Image import io import os -from ...common.database.database import db +from src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from ...config.config import global_config from rich.traceback import install @@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"{image_base64[:10]}...{image_base64[-10:]}" ) - # if isinstance(content, str) and len(content) > 100: - # payload["messages"][0]["content"] = content[:100] return payload @@ -134,13 +133,11 @@ class LLMRequest: def _init_database(): """初始化数据库集合""" try: - # 创建llm_usage集合的索引 - db.llm_usage.create_index([("timestamp", 1)]) - db.llm_usage.create_index([("model_name", 1)]) - db.llm_usage.create_index([("user_id", 1)]) - db.llm_usage.create_index([("request_type", 1)]) + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + logger.info("LLMUsage 表已初始化/确保存在。") except Exception as e: - logger.error(f"创建数据库索引失败: {str(e)}") + logger.error(f"创建 LLMUsage 表失败: {str(e)}") def _record_usage( self, @@ -165,19 +162,19 @@ class LLMRequest: request_type = self.request_type try: - usage_data = { - "model_name": self.model_name, - "user_id": user_id, - "request_type": request_type, - "endpoint": endpoint, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "cost": self._calculate_cost(prompt_tokens, completion_tokens), - "status": "success", - "timestamp": datetime.now(), - } - db.llm_usage.insert_one(usage_data) + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=self.model_name, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost(prompt_tokens, completion_tokens), + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) logger.trace( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 00cbe86f1..cd9034d6f 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -1,5 +1,6 @@ from src.common.logger_manager import get_logger from ...common.database.database import db +from ...common.database.database_model import PersonInfo # 新增导入 import copy import hashlib from typing import Any, Callable, Dict @@ -16,7 +17,7 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt from pathlib import Path import pandas as pd -import json +import json # 新增导入 import re @@ -43,17 +44,13 @@ person_info_default = { "platform": None, "user_id": None, "nickname": None, - # "age" : 0, "relationship_value": 0, - # "saved" : True, - # "impression" : None, - # "gender" : Unkown, "konw_time": 0, "msg_interval": 2000, - "msg_interval_list": [], - "user_cardname": None, # 添加群名片 - "user_avatar": None, # 添加头像信息(例如URL或标识符) -} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 + "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField + "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中 + "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中 +} class PersonInfoManager: @@ -64,21 +61,26 @@ class PersonInfoManager: max_tokens=256, request_type="qv_name", ) - if "person_info" not in db.list_collection_names(): - db.create_collection("person_info") - db.person_info.create_index("person_id", unique=True) + try: + db.connect(reuse_if_open=True) + db.create_tables([PersonInfo], safe=True) + except Exception as e: + logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") # 初始化时读取所有person_name - cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) - for doc in cursor: - if doc.get("person_name"): - self.person_name_list[doc["person_id"]] = doc["person_name"] - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") + try: + for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where( + PersonInfo.person_name.is_null(False) + ): + if record.person_name: + self.person_name_list[record.person_id] = record.person_name + logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") + except Exception as e: + logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") @staticmethod def get_person_id(platform: str, user_id: int): """获取唯一id""" - # 如果platform中存在-,就截取-后面的部分 if "-" in platform: platform = platform.split("-")[1] @@ -86,13 +88,17 @@ class PersonInfoManager: key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() - def is_person_known(self, platform: str, user_id: int): + async def is_person_known(self, platform: str, user_id: int): """判断是否认识某人""" person_id = self.get_person_id(platform, user_id) - document = db.person_info.find_one({"person_id": person_id}) - if document: - return True - else: + + def _db_check_known_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None + + try: + return await asyncio.to_thread(_db_check_known_sync, person_id) + except Exception as e: + logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") return False @staticmethod @@ -103,73 +109,111 @@ class PersonInfoManager: return _person_info_default = copy.deepcopy(person_info_default) - _person_info_default["person_id"] = person_id + model_fields = PersonInfo._meta.fields.keys() + + final_data = {"person_id": person_id} if data: - for key in _person_info_default: - if key != "person_id" and key in data: - _person_info_default[key] = data[key] + for key, value in data.items(): + if key in model_fields: + final_data[key] = value - db.person_info.insert_one(_person_info_default) + for key, default_value in _person_info_default.items(): + if key in model_fields and key not in final_data: + final_data[key] = default_value + + if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list): + final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"]) + elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields: + final_data["msg_interval_list"] = json.dumps([]) + + def _db_create_sync(p_data: dict): + try: + PersonInfo.create(**p_data) + return True + except Exception as e: + logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}") + return False + + await asyncio.to_thread(_db_create_sync, final_data) async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): """更新某一个字段,会补全""" - if field_name not in person_info_default.keys(): - logger.debug(f"更新'{field_name}'失败,未定义的字段") + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。") + return + logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") return - document = db.person_info.find_one({"person_id": person_id}) + def _db_update_sync(p_id: str, f_name: str, val): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + if f_name == "msg_interval_list" and isinstance(val, list): + setattr(record, f_name, json.dumps(val)) + else: + setattr(record, f_name, val) + record.save() + return True, False + return False, True - if document: - db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}}) - else: - data[field_name] = value - logger.debug(f"更新时{person_id}不存在,已新建") - await self.create_person_info(person_id, data) + found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value) + + if needs_creation: + logger.debug(f"更新时 {person_id} 不存在,将新建。") + creation_data = data if data is not None else {} + creation_data[field_name] = value + if "platform" not in creation_data or "user_id" not in creation_data: + logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。") + + await self.create_person_info(person_id, creation_data) @staticmethod async def has_one_field(person_id: str, field_name: str): """判断是否存在某一个字段""" - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) - if document: - return True - else: + if field_name not in PersonInfo._meta.fields: + logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") + return False + + def _db_has_field_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + return True + return False + + try: + return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) + except Exception as e: + logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}") return False @staticmethod def _extract_json_from_text(text: str) -> dict: """从文本中提取JSON数据的高容错方法""" try: - # 尝试直接解析 parsed_json = json.loads(text) - # 如果解析结果是列表,尝试取第一个元素 if isinstance(parsed_json, list): - if parsed_json: # 检查列表是否为空 + if parsed_json: parsed_json = parsed_json[0] - else: # 如果列表为空,重置为 None,走后续逻辑 + else: parsed_json = None - # 确保解析结果是字典 if isinstance(parsed_json, dict): return parsed_json except json.JSONDecodeError: - # 解析失败,继续尝试其他方法 pass except Exception as e: logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") - pass # 继续尝试其他方法 + pass - # 如果直接解析失败或结果不是字典 try: - # 尝试找到JSON对象格式的部分 json_pattern = r"\{[^{}]*\}" matches = re.findall(json_pattern, text) if matches: parsed_obj = json.loads(matches[0]) - if isinstance(parsed_obj, dict): # 确保是字典 + if isinstance(parsed_obj, dict): return parsed_obj - # 如果上面都失败了,尝试提取键值对 nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"' @@ -184,7 +228,6 @@ class PersonInfoManager: except Exception as e: logger.error(f"后备JSON提取失败: {str(e)}") - # 如果所有方法都失败了,返回默认字典 logger.warning(f"无法从文本中提取有效的JSON字典: {text}") return {"nickname": "", "reason": ""} @@ -199,9 +242,11 @@ class PersonInfoManager: old_name = await self.get_value(person_id, "person_name") old_reason = await self.get_value(person_id, "name_reason") - max_retries = 5 # 最大重试次数 + max_retries = 5 current_try = 0 - existing_names = "" + existing_names_str = "" + current_name_set = set(self.person_name_list.values()) + while current_try < max_retries: individuality = Individuality.get_instance() prompt_personality = individuality.get_prompt(x_person=2, level=1) @@ -216,45 +261,55 @@ class PersonInfoManager: qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}," qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" + qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" + + if existing_names_str: + qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n" + + if len(current_name_set) < 50 and current_name_set: + qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n" - qv_name_prompt += ( - "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" - ) - if existing_names: - qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n" qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:" qv_name_prompt += """{ "nickname": "昵称", "reason": "理由" }""" - # logger.debug(f"取名提示词:{qv_name_prompt}") response = await self.qv_name_llm.generate_response(qv_name_prompt) logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response[0]) - if not result["nickname"]: - logger.error("生成的昵称为空,重试中...") + if not result or not result.get("nickname"): + logger.error("生成的昵称为空或结果格式不正确,重试中...") current_try += 1 continue - # 检查生成的昵称是否已存在 - if result["nickname"] not in self.person_name_list.values(): - # 更新数据库和内存中的列表 - await self.update_one_field(person_id, "person_name", result["nickname"]) - # await self.update_one_field(person_id, "nickname", user_nickname) - # await self.update_one_field(person_id, "avatar", user_avatar) - await self.update_one_field(person_id, "name_reason", result["reason"]) + generated_nickname = result["nickname"] - self.person_name_list[person_id] = result["nickname"] - # logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") + is_duplicate = False + if generated_nickname in current_name_set: + is_duplicate = True + else: + def _db_check_name_exists_sync(name_to_check): + return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists() + + if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): + is_duplicate = True + current_name_set.add(generated_nickname) + + if not is_duplicate: + await self.update_one_field(person_id, "person_name", generated_nickname) + await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由")) + + self.person_name_list[person_id] = generated_nickname return result else: - existing_names += f"{result['nickname']}、" + if existing_names_str: + existing_names_str += "、" + existing_names_str += generated_nickname + logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...") + current_try += 1 - logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") - current_try += 1 - - logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称") + logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}") return None @staticmethod @@ -264,30 +319,56 @@ class PersonInfoManager: logger.debug("删除失败:person_id 不能为空") return - result = db.person_info.delete_one({"person_id": person_id}) - if result.deleted_count > 0: - logger.debug(f"删除成功:person_id={person_id}") + def _db_delete_sync(p_id: str): + try: + query = PersonInfo.delete().where(PersonInfo.person_id == p_id) + deleted_count = query.execute() + return deleted_count + except Exception as e: + logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}") + return 0 + + deleted_count = await asyncio.to_thread(_db_delete_sync, person_id) + + if deleted_count > 0: + logger.debug(f"删除成功:person_id={person_id} (Peewee)") else: - logger.debug(f"删除失败:未找到 person_id={person_id}") + logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)") @staticmethod async def get_value(person_id: str, field_name: str): """获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值""" if not person_id: logger.debug("get_value获取失败:person_id不能为空") + return person_info_default.get(field_name) + + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。") + return copy.deepcopy(person_info_default[field_name]) + logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") return None - if field_name not in person_info_default: - logger.debug(f"get_value获取失败:字段'{field_name}'未定义") + def _db_get_value_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + val = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(val, str): + try: + return json.loads(val) + except json.JSONDecodeError: + logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}") + return copy.deepcopy(person_info_default.get(f_name, [])) + return val return None - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) + value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) - if document and field_name in document: - return document[field_name] + if value is not None: + return value else: - default_value = copy.deepcopy(person_info_default[field_name]) - logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}") + default_value = copy.deepcopy(person_info_default.get(field_name)) + logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)") return default_value @staticmethod @@ -297,93 +378,82 @@ class PersonInfoManager: logger.debug("get_values获取失败:person_id不能为空") return {} - # 检查所有字段是否有效 - for field in field_names: - if field not in person_info_default: - logger.debug(f"get_values获取失败:字段'{field}'未定义") - return {} - - # 构建查询投影(所有字段都有效才会执行到这里) - projection = {field: 1 for field in field_names} - - document = db.person_info.find_one({"person_id": person_id}, projection) - result = {} - for field in field_names: - result[field] = copy.deepcopy( - document.get(field, person_info_default[field]) if document else person_info_default[field] - ) + + def _db_get_record_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) + + record = await asyncio.to_thread(_db_get_record_sync, person_id) + + for field_name in field_names: + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + result[field_name] = copy.deepcopy(person_info_default[field_name]) + logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。") + else: + logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") + result[field_name] = None + continue + + if record: + value = getattr(record, field_name) + if field_name == "msg_interval_list" and isinstance(value, str): + try: + result[field_name] = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}") + result[field_name] = copy.deepcopy(person_info_default.get(field_name, [])) + elif value is not None: + result[field_name] = value + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result @staticmethod async def del_all_undefined_field(): - """删除所有项里的未定义字段""" - # 获取所有已定义的字段名 - defined_fields = set(person_info_default.keys()) - - try: - # 遍历集合中的所有文档 - for document in db.person_info.find({}): - # 找出文档中未定义的字段 - undefined_fields = set(document.keys()) - defined_fields - {"_id"} - - if undefined_fields: - # 构建更新操作,使用$unset删除未定义字段 - update_result = db.person_info.update_one( - {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}} - ) - - if update_result.modified_count > 0: - logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}") - - return - - except Exception as e: - logger.error(f"清理未定义字段时出错: {e}") - return + """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。""" + logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。") + return @staticmethod async def get_specific_value_list( field_name: str, - way: Callable[[Any], bool], # 接受任意类型值 + way: Callable[[Any], bool], ) -> Dict[str, Any]: """ 获取满足条件的字段值字典 - - Args: - field_name: 目标字段名 - way: 判断函数 (value: Any) -> bool - - Returns: - {person_id: value} | {} - - Example: - # 查找所有nickname包含"admin"的用户 - result = manager.specific_value_list( - "nickname", - lambda x: "admin" in x.lower() - ) """ - if field_name not in person_info_default: - logger.error(f"字段检查失败:'{field_name}'未定义") + if field_name not in PersonInfo._meta.fields: + logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") return {} + def _db_get_specific_sync(f_name: str): + found_results = {} + try: + for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)): + value = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(value, str): + try: + processed_value = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}") + continue + else: + processed_value = value + + if way(processed_value): + found_results[record.person_id] = processed_value + except Exception as e_query: + logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True) + return found_results + try: - result = {} - for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}): - try: - value = doc[field_name] - if way(value): - result[doc["person_id"]] = value - except (KeyError, TypeError, ValueError) as e: - logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}") - continue - - return result - + return await asyncio.to_thread(_db_get_specific_sync, field_name) except Exception as e: - logger.error(f"数据库查询失败: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True) return {} async def personal_habit_deduction(self): @@ -391,35 +461,31 @@ class PersonInfoManager: try: while 1: await asyncio.sleep(600) - current_time = datetime.datetime.now() - logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + current_time_dt = datetime.datetime.now() + logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}") - # "msg_interval"推断 - msg_interval_map = False - msg_interval_lists = await self.get_specific_value_list( + msg_interval_map_generated = False + msg_interval_lists_map = await self.get_specific_value_list( "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 ) - for person_id, msg_interval_list_ in msg_interval_lists.items(): + + for person_id, actual_msg_interval_list in msg_interval_lists_map.items(): await asyncio.sleep(0.3) try: time_interval = [] - for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): + for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]): delta = t2 - t1 if delta > 0: time_interval.append(delta) time_interval = [t for t in time_interval if 200 <= t <= 8000] - # --- 修改后的逻辑 --- - # 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断) - if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条 - time_interval.sort() - # 画图(log) - 这部分保留 - msg_interval_map = True + if len(time_interval) >= 30 + 10: + time_interval.sort() + msg_interval_map_generated = True log_dir = Path("logs/person_info") log_dir.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(10, 6)) - # 使用截断前的数据画图,更能反映原始分布 time_series_original = pd.Series(time_interval) plt.hist( time_series_original, @@ -441,34 +507,27 @@ class PersonInfoManager: img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" plt.savefig(img_path) plt.close() - # 画图结束 - # 去掉头尾各 5 个数据点 trimmed_interval = time_interval[5:-5] - - # 计算截断后数据的 37% 分位数 - if trimmed_interval: # 确保截断后列表不为空 - msg_interval = int(round(np.percentile(trimmed_interval, 37))) - # 更新数据库 - await self.update_one_field(person_id, "msg_interval", msg_interval) - logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}") + if trimmed_interval: + msg_interval_val = int(round(np.percentile(trimmed_interval, 37))) + await self.update_one_field(person_id, "msg_interval", msg_interval_val) + logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}") else: logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval") else: logger.trace( f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" ) - # --- 修改结束 --- - except Exception as e: - logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}") + except Exception as e_inner: + logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}") continue - # 其他... - - if msg_interval_map: + if msg_interval_map_generated: logger.trace("已保存分布图到: logs/person_info") - current_time = datetime.datetime.now() - logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + + current_time_dt_end = datetime.datetime.now() + logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}") await asyncio.sleep(86400) except Exception as e: @@ -481,41 +540,27 @@ class PersonInfoManager: """ 根据 platform 和 user_id 获取 person_id。 如果对应的用户不存在,则使用提供的可选信息创建新用户。 - - Args: - platform: 平台标识 - user_id: 用户在该平台上的ID - nickname: 用户的昵称 (可选,用于创建新用户) - user_cardname: 用户的群名片 (可选,用于创建新用户) - user_avatar: 用户的头像信息 (可选,用于创建新用户) - - Returns: - 对应的 person_id。 """ person_id = self.get_person_id(platform, user_id) - # 检查用户是否已存在 - # 使用静态方法 get_person_id,因此可以直接调用 db - document = db.person_info.find_one({"person_id": person_id}) + def _db_check_exists_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) - if document is None: - logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") + record = await asyncio.to_thread(_db_check_exists_sync, person_id) + + if record is None: + logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") initial_data = { "platform": platform, - "user_id": user_id, + "user_id": str(user_id), "nickname": nickname, - "konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 - # 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中 - # 如果需要存储它们,需要先在 person_info_default 中定义 + "konw_time": int(datetime.datetime.now().timestamp()), } - # 过滤掉值为 None 的初始数据 - initial_data = {k: v for k, v in initial_data.items() if v is not None} + model_fields = PersonInfo._meta.fields.keys() + filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - # 注意:create_person_info 是静态方法 - await PersonInfoManager.create_person_info(person_id, data=initial_data) - # 创建后,可以考虑立即为其取名,但这可能会增加延迟 - # await self.qv_person_name(person_id, nickname, user_cardname, user_avatar) - logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}") + await self.create_person_info(person_id, data=filtered_initial_data) + logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") return person_id @@ -525,35 +570,49 @@ class PersonInfoManager: logger.debug("get_person_info_by_name 获取失败:person_name 不能为空") return None - # 优先从内存缓存查找 person_id found_person_id = None - for pid, name in self.person_name_list.items(): - if name == person_name: + for pid, name_in_cache in self.person_name_list.items(): + if name_in_cache == person_name: found_person_id = pid - break # 找到第一个匹配就停止 + break if not found_person_id: - # 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载) - document = db.person_info.find_one({"person_name": person_name}) - if document: - found_person_id = document.get("person_id") - else: - logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户") - return None # 数据库也找不到 + def _db_find_by_name_sync(p_name_to_find: str): + return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find) - # 根据找到的 person_id 获取所需信息 - if found_person_id: - required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"] - person_data = await self.get_values(found_person_id, required_fields) - if person_data: # 确保 get_values 成功返回 - return person_data + record = await asyncio.to_thread(_db_find_by_name_sync, person_name) + if record: + found_person_id = record.person_id + if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name: + self.person_name_list[found_person_id] = person_name else: - logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") + logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)") return None - else: - # 这理论上不应该发生,因为上面已经处理了找不到的情况 - logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") - return None + + if found_person_id: + required_fields = [ + "person_id", + "platform", + "user_id", + "nickname", + "user_cardname", + "user_avatar", + "person_name", + "name_reason", + ] + valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default] + + person_data = await self.get_values(found_person_id, valid_fields_to_get) + + if person_data: + final_result = {key: person_data.get(key) for key in required_fields} + return final_result + else: + logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)") + return None + + logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)") + return None person_info_manager = PersonInfoManager() diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index b7f59c661..fb8182973 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -1,9 +1,10 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecv, MessageSending, Message -from common.database.database import db +from src.common.database.database_model import Messages, ThinkingLog import time import traceback from typing import List +import json class InfoCatcher: @@ -60,8 +61,6 @@ class InfoCatcher: def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 self.timing_results["sub_heartflow_observe_time"] = obs_duration - # def catch_shf - def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): self.timing_results["sub_heartflow_step_time"] = step_duration if len(past_mind) > 1: @@ -72,25 +71,10 @@ class InfoCatcher: self.heartflow_data["sub_heartflow_now"] = current_mind def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): - # if self.response_mode == "heart_flow": # 条件判断不需要了喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name - # elif self.response_mode == "reasoning": # 条件判断不需要了喵~ - # self.reasoning_data["thinking_log"] = reasoning_content - # self.reasoning_data["prompt"] = prompt - # self.reasoning_data["response"] = response - # self.reasoning_data["model"] = model_name - - # 直接记录信息喵~ self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["prompt"] = prompt self.reasoning_data["response"] = response self.reasoning_data["model"] = model_name - # 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name self.response_text = response @@ -102,6 +86,7 @@ class InfoCatcher: ): self.timing_results["make_response_time"] = response_duration self.response_time = time.time() + self.response_messages = [] for msg in response_message: self.response_messages.append(msg) @@ -112,107 +97,110 @@ class InfoCatcher: @staticmethod def get_message_from_db_between_msgs(message_start: Message, message_end: Message): try: - # 从数据库中获取消息的时间戳 time_start = message_start.message_info.time time_end = message_end.message_info.time chat_id = message_start.chat_stream.stream_id print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") - # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 - messages_between = db.messages.find( - {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} - ).sort("time", -1) + messages_between_query = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time > time_start) & + (Messages.time < time_end) + ).order_by(Messages.time.desc()) - result = list(messages_between) + result = list(messages_between_query) print(f"查询结果数量: {len(result)}") if result: - print(f"第一条消息时间: {result[0]['time']}") - print(f"最后一条消息时间: {result[-1]['time']}") + print(f"第一条消息时间: {result[0].time}") + print(f"最后一条消息时间: {result[-1].time}") return result except Exception as e: print(f"获取消息时出错: {str(e)}") + print(traceback.format_exc()) return [] def get_message_from_db_before_msg(self, message: MessageRecv): - # 从数据库中获取消息 - message_id = message.message_info.message_id - chat_id = message.chat_stream.stream_id + message_id_val = message.message_info.message_id + chat_id_val = message.chat_stream.stream_id - # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 - messages_before = ( - db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) - .sort("time", -1) - .limit(self.context_length * 3) - ) # 获取更多历史信息 + messages_before_query = Messages.select().where( + (Messages.chat_id == chat_id_val) & + (Messages.message_id < message_id_val) + ).order_by(Messages.time.desc()).limit(self.context_length * 3) - return list(messages_before) + return list(messages_before_query) def message_list_to_dict(self, message_list): - # 存储简化的聊天记录 result = [] - for message in message_list: - if not isinstance(message, dict): - message = self.message_to_dict(message) - # print(message) + for msg_item in message_list: + processed_msg_item = msg_item + if not isinstance(msg_item, dict): + processed_msg_item = self.message_to_dict(msg_item) + + if not processed_msg_item: + continue lite_message = { - "time": message["time"], - "user_nickname": message["user_info"]["user_nickname"], - "processed_plain_text": message["processed_plain_text"], + "time": processed_msg_item.get("time"), + "user_nickname": processed_msg_item.get("user_nickname"), + "processed_plain_text": processed_msg_item.get("processed_plain_text"), } result.append(lite_message) - return result @staticmethod - def message_to_dict(message): - if not message: + def message_to_dict(msg_obj): + if not msg_obj: return None - if isinstance(message, dict): - return message - return { - # "message_id": message.message_info.message_id, - "time": message.message_info.time, - "user_id": message.message_info.user_info.user_id, - "user_nickname": message.message_info.user_info.user_nickname, - "processed_plain_text": message.processed_plain_text, - # "detailed_plain_text": message.detailed_plain_text - } + if isinstance(msg_obj, dict): + return msg_obj + + if isinstance(msg_obj, Messages): + return { + "time": msg_obj.time, + "user_id": msg_obj.user_id, + "user_nickname": msg_obj.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, + } + + if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'): + return { + "time": msg_obj.message_info.time, + "user_id": msg_obj.message_info.user_info.user_id, + "user_nickname": msg_obj.message_info.user_info.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, + } + + print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}") + return {} def done_catch(self): - """将收集到的信息存储到数据库的 thinking_log 集合中喵~""" + """将收集到的信息存储到数据库的 thinking_log 表中喵~""" try: - # 将消息对象转换为可序列化的字典喵~ - - thinking_log_data = { - "chat_id": self.chat_id, - "trigger_text": self.trigger_response_text, - "response_text": self.response_text, - "trigger_info": { - "time": self.trigger_response_time, - "message": self.message_to_dict(self.trigger_response_message), - }, - "response_info": { - "time": self.response_time, - "message": self.response_messages, - }, - "timing_results": self.timing_results, - "chat_history": self.message_list_to_dict(self.chat_history), - "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), - "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response), - "heartflow_data": self.heartflow_data, - "reasoning_data": self.reasoning_data, + trigger_info_dict = self.message_to_dict(self.trigger_response_message) + response_info_dict = { + "time": self.response_time, + "message": self.response_messages, } + chat_history_list = self.message_list_to_dict(self.chat_history) + chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking) + chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response) - # 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ - # if self.response_mode == "heart_flow": - # thinking_log_data["mode_specific_data"] = self.heartflow_data - # elif self.response_mode == "reasoning": - # thinking_log_data["mode_specific_data"] = self.reasoning_data - - # 将数据插入到 thinking_log 集合中喵~ - db.thinking_log.insert_one(thinking_log_data) + log_entry = ThinkingLog( + chat_id=self.chat_id, + trigger_text=self.trigger_response_text, + response_text=self.response_text, + trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None, + response_info_json=json.dumps(response_info_dict), + timing_results_json=json.dumps(self.timing_results), + chat_history_json=json.dumps(chat_history_list), + chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list), + chat_history_after_response_json=json.dumps(chat_history_after_response_list), + heartflow_data_json=json.dumps(self.heartflow_data), + reasoning_data_json=json.dumps(self.reasoning_data) + ) + log_entry.save() return True except Exception as e: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 4bcf6fea0..9a0131f74 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -5,7 +5,8 @@ from typing import Any, Dict, Tuple, List from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask -from ...common.database.database import db +from ...common.database.database import db # This db is the Peewee database instance +from ...common.database.database_model import OnlineTime # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -39,7 +40,7 @@ class OnlineTimeRecordTask(AsyncTask): def __init__(self): super().__init__(task_name="Online Time Record Task", run_interval=60) - self.record_id: str | None = None + self.record_id: int | None = None # Changed to int for Peewee's default ID """记录ID""" self._init_database() # 初始化数据库 @@ -47,53 +48,46 @@ class OnlineTimeRecordTask(AsyncTask): @staticmethod def _init_database(): """初始化数据库""" - if "online_time" not in db.list_collection_names(): - # 初始化数据库(在线时长) - db.create_collection("online_time") - # 创建索引 - if ("end_timestamp", 1) not in db.online_time.list_indexes(): - db.online_time.create_index([("end_timestamp", 1)]) + with db.atomic(): # Use atomic operations for schema changes + OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model async def run(self): try: + current_time = datetime.now() + extended_end_time = current_time + timedelta(minutes=1) + if self.record_id: # 如果有记录,则更新结束时间 - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": datetime.now() + timedelta(minutes=1), - } - }, - ) - else: + query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) + updated_rows = query.execute() + if updated_rows == 0: + # Record might have been deleted or ID is stale, try to find/create + self.record_id = None # Reset record_id to trigger find/create logic below + + if not self.record_id: # Check again if record_id was reset or initially None # 如果没有记录,检查一分钟以内是否已有记录 - current_time = datetime.now() - if recent_record := db.online_time.find_one( - {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} - ): + # Look for a record whose end_timestamp is recent enough to be considered ongoing + recent_record = OnlineTime.select().where( + OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)) + ).order_by(OnlineTime.end_timestamp.desc()).first() + + if recent_record: # 如果有记录,则更新结束时间 - self.record_id = recent_record["_id"] - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": current_time + timedelta(minutes=1), - } - }, - ) + self.record_id = recent_record.id + recent_record.end_timestamp = extended_end_time + recent_record.save() else: # 若没有记录,则插入新的在线时间记录 - self.record_id = db.online_time.insert_one( - { - "start_timestamp": current_time, - "end_timestamp": current_time + timedelta(minutes=1), - } - ).inserted_id + new_record = OnlineTime.create( + start_timestamp=current_time, + end_timestamp=extended_end_time, + ) + self.record_id = new_record.id except Exception as e: logger.error(f"在线时间记录失败,错误信息:{e}") + def _format_online_time(online_seconds: int) -> str: """ 格式化在线时间 diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 6fbafc905..11e7bed06 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -9,6 +9,7 @@ import numpy as np from ...common.database.database import db +from ...common.database.database_model import Images, ImageDescriptions from ...config.config import global_config from ..models.utils_model import LLMRequest @@ -32,40 +33,21 @@ class ImageManager: def __init__(self): if not self._initialized: - self._ensure_image_collection() - self._ensure_description_collection() self._ensure_image_dir() - self._initialized = True self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image") + + try: + db.connect(reuse_if_open=True) + db.create_tables([Images, ImageDescriptions], safe=True) + except Exception as e: + logger.error(f"数据库连接或表创建失败: {e}") + + self._initialized = True def _ensure_image_dir(self): """确保图像存储目录存在""" os.makedirs(self.IMAGE_DIR, exist_ok=True) - @staticmethod - def _ensure_image_collection(): - """确保images集合存在并创建索引""" - if "images" not in db.list_collection_names(): - db.create_collection("images") - - # 删除旧索引 - db.images.drop_indexes() - # 创建新的复合索引 - db.images.create_index([("hash", 1), ("type", 1)], unique=True) - db.images.create_index([("url", 1)]) - db.images.create_index([("path", 1)]) - - @staticmethod - def _ensure_description_collection(): - """确保image_descriptions集合存在并创建索引""" - if "image_descriptions" not in db.list_collection_names(): - db.create_collection("image_descriptions") - - # 删除旧索引 - db.image_descriptions.drop_indexes() - # 创建新的复合索引 - db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True) - @staticmethod def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -77,8 +59,15 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) - return result["description"] if result else None + try: + record = ImageDescriptions.get_or_none( + (ImageDescriptions.hash == image_hash) & + (ImageDescriptions.type == description_type) + ) + return record.description if record else None + except Exception as e: + logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}") + return None @staticmethod def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: @@ -90,20 +79,22 @@ class ImageManager: description_type: 描述类型 ('emoji' 或 'image') """ try: - db.image_descriptions.update_one( - {"hash": image_hash, "type": description_type}, - { - "$set": { - "description": description, - "timestamp": int(time.time()), - "hash": image_hash, # 确保hash字段存在 - "type": description_type, # 确保type字段存在 - } - }, - upsert=True, + current_timestamp = time.time() + defaults = { + 'description': description, + 'timestamp': current_timestamp + } + desc_obj, created = ImageDescriptions.get_or_create( + hash=image_hash, + type=description_type, + defaults=defaults ) + if not created: # 如果记录已存在,则更新 + desc_obj.description = description + desc_obj.timestamp = current_timestamp + desc_obj.save() except Exception as e: - logger.error(f"保存描述到数据库失败: {str(e)}") + logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") async def get_emoji_description(self, image_base64: str) -> str: """获取表情包描述,带查重和保存功能""" @@ -116,18 +107,25 @@ class ImageManager: # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: - # logger.debug(f"缓存表情包描述: {cached_description}") return f"[表情包,含义看起来是:{cached_description}]" # 调用AI获取描述 if image_format == "gif" or image_format == "GIF": - image_base64 = self.transform_gif(image_base64) + image_base64_processed = self.transform_gif(image_base64) + if image_base64_processed is None: + logger.warning("GIF转换失败,无法获取描述") + return "[表情包(GIF处理失败)]" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") + description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg") else: prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + + if description is None: + logger.warning("AI未能生成表情包描述") + return "[表情包(描述生成失败)]" + # 再次检查缓存,防止并发写入时重复生成 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") @@ -136,31 +134,37 @@ class ImageManager: # 根据配置决定是否保存图片 if global_config.save_emoji: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): - os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) - file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + emoji_dir = os.path.join(self.IMAGE_DIR, "emoji") + os.makedirs(emoji_dir, exist_ok=True) + file_path = os.path.join(emoji_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "emoji", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存表情包: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.hash == image_hash) & (Images.type == "emoji")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="emoji", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存表情包元数据: {file_path}") except Exception as e: - logger.error(f"保存表情包文件失败: {str(e)}") + logger.error(f"保存表情包文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "emoji") return f"[表情包:{description}]" @@ -187,7 +191,12 @@ class ImageManager: "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。" ) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + + if description is None: + logger.warning("AI未能生成图片描述") + return "[图片(描述生成失败)]" + # 再次检查缓存 cached_description = self._get_description_from_db(image_hash, "image") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") @@ -195,38 +204,40 @@ class ImageManager: logger.debug(f"描述是{description}") - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片]" - # 根据配置决定是否保存图片 if global_config.save_pic: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): - os.makedirs(os.path.join(self.IMAGE_DIR, "image")) - file_path = os.path.join(self.IMAGE_DIR, "image", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + image_dir = os.path.join(self.IMAGE_DIR, "image") + os.makedirs(image_dir, exist_ok=True) + file_path = os.path.join(image_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "image", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存图片: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.hash == image_hash) & (Images.type == "image")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="image", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存图片元数据: {file_path}") except Exception as e: - logger.error(f"保存图片文件失败: {str(e)}") + logger.error(f"保存图片文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "image") return f"[图片:{description}]" diff --git a/src/common/database/database.py b/src/common/database/database.py index 752f746db..a2dab739d 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,5 +1,6 @@ import os from pymongo import MongoClient +from peewee import SqliteDatabase from pymongo.database import Database from rich.traceback import install @@ -57,4 +58,15 @@ class DBWrapper: # 全局数据库访问点 -db: Database = DBWrapper() +memory_db: Database = DBWrapper() + +# 定义数据库文件路径 +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +_DB_DIR = os.path.join(ROOT_PATH, "data") +_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db") + +# 确保数据库目录存在 +os.makedirs(_DB_DIR, exist_ok=True) + +# 全局 Peewee SQLite 数据库访问点 +db = SqliteDatabase(_DB_FILE) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index c1135a33d..b46cace9f 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,9 +1,10 @@ -from peewee import Model, DoubleField, IntegerField, SqliteDatabase, BooleanField, TextField, FloatField - +from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField +from .database import db +import datetime # 请在此处定义您的数据库实例。 # 您需要取消注释并配置适合您的数据库的部分。 # 例如,对于 SQLite: -db = SqliteDatabase('my_application.db') +# db = SqliteDatabase('MaiBot.db') # # 对于 PostgreSQL: # db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password', @@ -69,17 +70,16 @@ class LLMUsage(BaseModel): """ 用于存储 API 使用日志数据的模型。 """ - model_name = TextField() - user_id = TextField() - request_type = TextField() + model_name = TextField(index=True) # 添加索引 + user_id = TextField(index=True) # 添加索引 + request_type = TextField(index=True) # 添加索引 endpoint = TextField() prompt_tokens = IntegerField() completion_tokens = IntegerField() total_tokens = IntegerField() cost = DoubleField() status = TextField() - # timestamp: "$date": "2025-05-01T18:52:50.870Z" (存储为字符串) - timestamp = TextField() + timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 class Meta: # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 @@ -177,6 +177,8 @@ class OnlineTime(BaseModel): # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) timestamp = TextField() duration = IntegerField() # 时长,单位分钟 + start_timestamp = DateTimeField(default=datetime.datetime.now) + end_timestamp = DateTimeField(index=True) class Meta: # database = db # 继承自 BaseModel @@ -202,3 +204,39 @@ class PersonInfo(BaseModel): # database = db # 继承自 BaseModel table_name = 'person_info' +class Knowledges(BaseModel): + """ + 用于存储知识库条目的模型。 + """ + content = TextField() # 知识内容的文本 + embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 + # 可以添加其他元数据字段,如 source, create_time 等 + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'knowledges' + + +class ThinkingLog(BaseModel): + chat_id = TextField(index=True) + trigger_text = TextField(null=True) + response_text = TextField(null=True) + + # Store complex dicts/lists as JSON strings + trigger_info_json = TextField(null=True) + response_info_json = TextField(null=True) + timing_results_json = TextField(null=True) + chat_history_json = TextField(null=True) + chat_history_in_thinking_json = TextField(null=True) + chat_history_after_response_json = TextField(null=True) + heartflow_data_json = TextField(null=True) + reasoning_data_json = TextField(null=True) + + # Add a timestamp for the log entry itself + # Ensure you have: from peewee import DateTimeField + # And: import datetime + created_at = DateTimeField(default=datetime.datetime.now) + + class Meta: + table_name = 'thinking_logs' + diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 03eaba332..7d987ace9 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,11 +1,19 @@ -from common.database.database import db +from src.common.database.database_model import Messages # 更改导入 from src.common.logger import get_module_logger import traceback from typing import List, Any, Optional +from peewee import Model # 添加 Peewee Model 导入 logger = get_module_logger(__name__) +def _model_to_dict(model_instance: Model) -> dict[str, Any]: + """ + 将 Peewee 模型实例转换为字典。 + """ + return model_instance.__data__ + + def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, @@ -16,39 +24,72 @@ def find_messages( 根据提供的过滤器、排序和限制条件查找消息。 Args: - message_filter: MongoDB 查询过滤器。 - sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值。 + sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 Returns: - 消息文档列表,如果出错则返回空列表。 + 消息字典列表,如果出错则返回空列表。 """ try: - query = db.messages.find(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + conditions.append(getattr(Messages, key) == value) + else: + logger.warning( + f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。" + ) + if conditions: + # 使用 *conditions 将所有条件以 AND 连接 + query = query.where(*conditions) if limit > 0: if limit_mode == "earliest": # 获取时间最早的 limit 条记录,已经是正序 - query = query.sort([("time", 1)]).limit(limit) - results = list(query) + query = query.order_by(Messages.time.asc()).limit(limit) + peewee_results = list(query) else: # 默认为 'latest' # 获取时间最晚的 limit 条记录 - query = query.sort([("time", -1)]).limit(limit) - latest_results = list(query) + query = query.order_by(Messages.time.desc()).limit(limit) + latest_results_peewee = list(query) # 将结果按时间正序排列 - # 假设消息文档中总是有 'time' 字段且可排序 - results = sorted(latest_results, key=lambda msg: msg.get("time")) + peewee_results = sorted( + latest_results_peewee, key=lambda msg: msg.time + ) else: # limit 为 0 时,应用传入的 sort 参数 if sort: - query = query.sort(sort) - results = list(query) + peewee_sort_terms = [] + for field_name, direction in sort: + if hasattr(Messages, field_name): + field = getattr(Messages, field_name) + if direction == 1: # ASC + peewee_sort_terms.append(field.asc()) + elif direction == -1: # DESC + peewee_sort_terms.append(field.desc()) + else: + logger.warning( + f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。" + ) + else: + logger.warning( + f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。" + ) + if peewee_sort_terms: + query = query.order_by(*peewee_sort_terms) + peewee_results = list(query) + results = [_model_to_dict(msg) for msg in peewee_results] return results except Exception as e: log_message = ( - f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) @@ -60,18 +101,35 @@ def count_messages(message_filter: dict[str, Any]) -> int: 根据提供的过滤器计算消息数量。 Args: - message_filter: MongoDB 查询过滤器。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值。 Returns: 符合条件的消息数量,如果出错则返回 0。 """ try: - count = db.messages.count_documents(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + conditions.append(getattr(Messages, key) == value) + else: + logger.warning( + f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。" + ) + if conditions: + query = query.where(*conditions) + + count = query.count() return count except Exception as e: - log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() + log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() logger.error(log_message) return 0 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 +# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。 +# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。 diff --git a/src/experimental/PFC/chat_observer.py b/src/experimental/PFC/chat_observer.py index 704eeb330..e9e64053f 100644 --- a/src/experimental/PFC/chat_observer.py +++ b/src/experimental/PFC/chat_observer.py @@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import ( create_new_message_notification, create_cold_chat_notification, ) -from src.experimental.PFC.message_storage import MongoDBMessageStorage +from src.experimental.PFC.message_storage import PeeweeMessageStorage from rich.traceback import install install(extra_lines=3) @@ -53,7 +53,7 @@ class ChatObserver: self.stream_id = stream_id self.private_name = private_name - self.message_storage = MongoDBMessageStorage() + self.message_storage = PeeweeMessageStorage() # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index 24866e38c..6e109fac3 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any -from common.database.database import db +# from src.common.database.database import db # Peewee db 导入 +from src.common.database.database_model import Messages # Peewee Messages 模型导入 +from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典 class MessageStorage(ABC): @@ -47,28 +49,35 @@ class MessageStorage(ABC): pass -class MongoDBMessageStorage(MessageStorage): - """MongoDB消息存储实现""" +class PeeweeMessageStorage(MessageStorage): + """Peewee消息存储实现""" async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$gt": message_time}} - # print(f"storage_check_message: {message_time}") + query = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time > message_time) + ).order_by(Messages.time.asc()) - return list(db.messages.find(query).sort("time", 1)) + # print(f"storage_check_message: {message_time}") + messages_models = list(query) + return [model_to_dict(msg) for msg in messages_models] async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$lt": time_point}} - - messages = list(db.messages.find(query).sort("time", -1).limit(limit)) + query = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time < time_point) + ).order_by(Messages.time.desc()).limit(limit) + messages_models = list(query) # 将消息按时间正序排列 - messages.reverse() - return messages + messages_models.reverse() + return [model_to_dict(msg) for msg in messages_models] async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - query = {"chat_id": chat_id, "time": {"$gt": after_time}} - - return db.messages.find_one(query) is not None + return Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time > after_time) + ).exists() # # 创建一个内存消息存储实现,用于测试 diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py index 2a4922f9f..4ff62b7c2 100644 --- a/src/tools/tool_can_use/get_knowledge.py +++ b/src/tools/tool_can_use/get_knowledge.py @@ -1,8 +1,10 @@ from src.tools.tool_can_use.base_tool import BaseTool from src.chat.utils.utils import get_embedding -from common.database.database import db +from src.common.database.database_model import Knowledges # Updated import from src.common.logger_manager import get_logger -from typing import Any, Union +from typing import Any, Union, List # Added List +import json # Added for parsing embedding +import math # Added for cosine similarity logger = get_logger("get_knowledge_tool") @@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool): Returns: dict: 工具执行结果 """ + query = "" # Initialize query to ensure it's defined in except block try: query = function_args.get("query") threshold = function_args.get("threshold", 0.4) @@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool): logger.error(f"知识库搜索工具执行失败: {str(e)}") return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} + @staticmethod + def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: + """计算两个向量之间的余弦相似度""" + dot_product = sum(p * q for p, q in zip(vec1, vec2)) + magnitude1 = math.sqrt(sum(p * p for p in vec1)) + magnitude2 = math.sqrt(sum(q * q for q in vec2)) + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + return dot_product / (magnitude1 * magnitude2) + @staticmethod def get_info_from_db( - query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False ) -> Union[str, list]: """从数据库中获取相关信息 @@ -66,66 +79,49 @@ class SearchKnowledgeTool(BaseTool): if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] + similar_items = [] + try: + all_knowledges = Knowledges.select() + for item in all_knowledges: + try: + item_embedding_str = item.embedding + if not item_embedding_str: + logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") + continue + item_embedding = json.loads(item_embedding_str) + if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding): + logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") + continue + except json.JSONDecodeError: + logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") + continue + except AttributeError: + logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") + continue - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) + + if similarity >= threshold: + similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) + + # 按相似度降序排序 + similar_items.sort(key=lambda x: x["similarity"], reverse=True) + + # 应用限制 + results = similar_items[:limit] + logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") + + except Exception as e: + logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") + return "" if not return_raw else [] if not results: return "" if not return_raw else [] if return_raw: - return results + # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 + # 这里返回包含内容和相似度的字典列表 + return [{"content": r["content"], "similarity": r["similarity"]} for r in results] else: # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results)