重构数据库交互以使用 Peewee ORM

- 更新数据库连接和模型定义,以便通过 Peewee 使用 SQLite。
- 在消息存储和检索功能中,用 Peewee ORM 查询替换 MongoDB 查询。
- 为 Messages、ThinkingLog 和 OnlineTime 引入了新的模型,以方便结构化数据存储。
- 增强了数据库操作的错误处理和日志记录。
- 删除了过时的 MongoDB 集合管理代码。
- 利用 Peewee 内置的查询和数据操作方法来提升性能。
This commit is contained in:
墨梓柒
2025-05-14 22:53:21 +08:00
parent df897a0f42
commit b84cc9240a
15 changed files with 999 additions and 758 deletions

View File

@@ -10,7 +10,10 @@ from PIL import Image
import io import io
import re import re
from ...common.database.database import db # from gradio_client import file
from ...common.database.database_model import Emoji
from ...common.database.database import db as peewee_db
from ...config.config import global_config from ...config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager from ..utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
@@ -143,22 +146,21 @@ class MaiEmoji:
# --- 数据库操作 --- # --- 数据库操作 ---
try: try:
# 准备数据库记录 for emoji collection # 准备数据库记录 for emoji collection
emoji_record = { emotion_str = ",".join(self.emotion) if self.emotion else ""
"filename": self.filename,
"path": self.path, # 存储目录路径
"full_path": self.full_path, # 存储完整文件路径
"embedding": self.embedding,
"description": self.description,
"emotion": self.emotion,
"hash": self.hash,
"format": self.format,
"timestamp": int(self.register_time),
"usage_count": self.usage_count,
"last_used_time": self.last_used_time,
}
# 使用upsert确保记录存在或被更新 Emoji.create(hash=self.hash,
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -166,14 +168,6 @@ class MaiEmoji:
except Exception as db_error: except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
# 可以考虑在这里尝试删除已移动的文件,避免残留
try:
if os.path.exists(self.full_path): # full_path 此时是目标路径
os.remove(self.full_path)
logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
except Exception as remove_error:
logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
return False return False
except Exception as e: except Exception as e:
@@ -201,10 +195,14 @@ class MaiEmoji:
# 文件删除失败,但仍然尝试删除数据库记录 # 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录 # 2. 删除数据库记录
result = db.emoji.delete_one({"hash": self.hash}) try:
deleted_in_db = result.deleted_count > 0 will_delete_emoji = Emoji.get(Emoji.hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
if deleted_in_db: if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
# 3. 标记对象已被删除 # 3. 标记对象已被删除
self.is_deleted = True self.is_deleted = True
@@ -246,44 +244,43 @@ def _emoji_objects_to_readable_list(emoji_objects):
def _to_emoji_objects(data): def _to_emoji_objects(data):
emoji_objects = [] emoji_objects = []
load_errors = 0 load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data) emoji_data_list = list(data)
for emoji_data in emoji_data_list: for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
full_path = emoji_data.get("full_path") full_path = emoji_data.full_path
if not full_path: if not full_path:
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}")
load_errors += 1 load_errors += 1
continue # 跳过缺少 full_path 的记录 continue
try: try:
# 使用 full_path 初始化 MaiEmoji 对象
emoji = MaiEmoji(full_path=full_path) emoji = MaiEmoji(full_path=full_path)
# 设置从数据库加载的属性 emoji.hash = emoji_data.hash
emoji.hash = emoji_data.get("hash", "")
# 如果 hash 为空,也跳过?取决于业务逻辑
if not emoji.hash: if not emoji.hash:
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
load_errors += 1 load_errors += 1
continue continue
emoji.description = emoji_data.get("description", "") emoji.description = emoji_data.description
emoji.emotion = emoji_data.get("emotion", []) # Deserialize emotion string from DB to list
emoji.usage_count = emoji_data.get("usage_count", 0) emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else []
# 优先使用 last_used_time否则用 timestamp最后用当前时间 emoji.usage_count = emoji_data.usage_count
last_used = emoji_data.get("last_used_time")
timestamp = emoji_data.get("timestamp")
emoji.last_used_time = (
last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
)
emoji.register_time = timestamp if timestamp is not None else time.time()
emoji.format = emoji_data.get("format", "") # 加载格式
# 不需要再手动设置 path 和 filename__init__ 会自动处理 db_last_used_time = emoji_data.last_used_time
db_register_time = emoji_data.register_time
# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
emoji.format = emoji_data.format
emoji_objects.append(emoji) emoji_objects.append(emoji)
except ValueError as ve: # 捕获 __init__ 可能的错误 except ValueError as ve:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1 load_errors += 1
except Exception as e: except Exception as e:
@@ -385,12 +382,13 @@ class EmojiManager:
"""初始化数据库连接和表情目录""" """初始化数据库连接和表情目录"""
if not self._initialized: if not self._initialized:
try: try:
self._ensure_emoji_collection() # Ensure Peewee database connection is up and tables are created
if not peewee_db.is_closed():
peewee_db.connect(reuse_if_open=True)
Emoji.create_table(safe=True) # Ensures table exists
_ensure_emoji_dir() _ensure_emoji_dir()
self._initialized = True self._initialized = True
# 更新表情包数量
# 启动时执行一次完整性检查
# await self.check_emoji_file_integrity()
except Exception as e: except Exception as e:
logger.exception(f"初始化表情管理器失败: {e}") logger.exception(f"初始化表情管理器失败: {e}")
@@ -401,33 +399,15 @@ class EmojiManager:
if not self._initialized: if not self._initialized:
raise RuntimeError("EmojiManager not initialized") raise RuntimeError("EmojiManager not initialized")
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_hash: str): def record_usage(self, emoji_hash: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) emoji_update = Emoji.get(Emoji.hash == emoji_hash)
for emoji in self.emoji_objects: emoji_update.usage_count += 1
if emoji.hash == emoji_hash: emoji_update.last_used_time = time.time() # Update last used time
emoji.usage_count += 1 emoji_update.save() # Persist changes to DB
break except Emoji.DoesNotExist:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
@@ -657,9 +637,10 @@ class EmojiManager:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects""" """获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try: try:
self._ensure_db() self._ensure_db()
logger.info("[数据库] 开始加载所有表情包记录...") logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
# 更新内存中的列表和数量 # 更新内存中的列表和数量
self.emoji_objects = emoji_objects self.emoji_objects = emoji_objects
@@ -686,15 +667,16 @@ class EmojiManager:
try: try:
self._ensure_db() self._ensure_db()
query = {}
if emoji_hash: if emoji_hash:
query = {"hash": emoji_hash} query = Emoji.select().where(Emoji.hash == emoji_hash)
else: else:
logger.warning( logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。" "[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
) )
query = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
if load_errors > 0: if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -908,6 +890,44 @@ class EmojiManager:
logger.error(f"获取表情包描述失败: {str(e)}") logger.error(f"获取表情包描述失败: {str(e)}")
return "", [] return "", []
# async def register_emoji_by_filename(self, filename: str) -> bool:
# if global_config.EMOJI_CHECK:
# prompt = f'''
# 这是一个表情包,请对这个表情包进行审核,标准如下:
# 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
# 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
# 3. 不能是任何形式的截图,聊天记录或视频截图
# 4. 不要出现5个以上文字
# 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
# '''
# content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# if content == "否":
# return "", []
# # 分析情感含义
# emotion_prompt = f"""
# 请你识别这个表情包的含义和适用场景给我简短的描述每个描述不要超过15个字
# 这是一个基于这个表情包的描述:'{description}'
# 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
# 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
# """
# emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
# # 处理情感列表
# emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
# # 根据情感标签数量随机选择喵~超过5个选3个超过2个选2个
# if len(emotions) > 5:
# emotions = random.sample(emotions, 3)
# elif len(emotions) > 2:
# emotions = random.sample(emotions, 2)
# return f"[表情包:{description}]", emotions
# except Exception as e:
# logger.error(f"获取表情包描述失败: {str(e)}")
# return "", []
async def register_emoji_by_filename(self, filename: str) -> bool: async def register_emoji_by_filename(self, filename: str) -> bool:
"""读取指定文件名的表情包图片,分析并注册到数据库 """读取指定文件名的表情包图片,分析并注册到数据库

View File

@@ -7,7 +7,7 @@ from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
import time import time
from typing import Union, Optional from typing import Union, Optional
from common.database.database import db # from common.database.database import db
from src.chat.utils.utils import get_recent_group_speaker from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.memory_system.Hippocampus import HippocampusManager
@@ -15,6 +15,9 @@ from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
# import traceback # import traceback
import random import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt") logger = get_logger("prompt")
@@ -69,7 +72,7 @@ def init_prompt():
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1} 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger} 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)只输出回复内容。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情at或 @等 )。只输出回复内容。
{moderation_prompt} {moderation_prompt}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""", 不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""",
"reasoning_prompt_main", "reasoning_prompt_main",
@@ -439,30 +442,6 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法 # 1. 先从LLM获取主题类似于记忆系统的做法
topics = [] topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息 # 如果无法提取到主题,直接使用整个消息
if not topics: if not topics:
@@ -572,8 +551,6 @@ class PromptBuilder:
for _i, result in enumerate(results, 1): for _i, result in enumerate(results, 1):
_similarity = result["similarity"] _similarity = result["similarity"]
content = result["content"].strip() content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n" related_info += f"{content}\n"
related_info += "\n" related_info += "\n"
@@ -602,14 +579,14 @@ class PromptBuilder:
return related_info return related_info
else: else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索") logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info return related_info
except Exception as e: except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}") logger.error(f"获取知识库内容时发生异常: {str(e)}")
try: try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old related_info += knowledge_from_old
logger.debug( logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -625,69 +602,69 @@ class PromptBuilder:
) -> Union[str, list]: ) -> Union[str, list]:
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline)) results_with_similarity = []
logger.debug(f"知识库查询结果数量: {len(results)}") try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not results: if not all_knowledges:
return "" if not return_raw else []
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.")
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({
"content": knowledge_item.content,
"similarity": similarity
})
except json.JSONDecodeError:
logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}")
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else [] return "" if not return_raw else []
if return_raw: if return_raw:
return results return limited_results
else: else:
# 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in limited_results)
return "\n".join(str(result["content"]) for result in results)
except Exception as e:
logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else []
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:

View File

@@ -10,7 +10,7 @@ import jieba
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from ...common.database.database import db from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器

View File

@@ -6,6 +6,7 @@ from typing import Dict, Optional
from ...common.database.database import db from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
@@ -82,7 +83,13 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self._ensure_collection() try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True self._initialized = True
# 在事件循环中启动初始化 # 在事件循环中启动初始化
# asyncio.create_task(self._initialize()) # asyncio.create_task(self._initialize())
@@ -107,15 +114,6 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}") logger.error(f"聊天流自动保存失败: {str(e)}")
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
@staticmethod @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
@@ -151,16 +149,43 @@ class ChatManager:
stream = self.streams[stream_id] stream = self.streams[stream_id]
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time() stream.update_active_time()
stream = copy.deepcopy(stream) stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
data = db.chat_streams.find_one({"stream_id": stream_id}) def _db_find_stream_sync(s_id: str):
if data: return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
stream = ChatStream.from_dict(data)
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
@@ -175,7 +200,7 @@ class ChatManager:
group_info=group_info, group_info=group_info,
) )
except Exception as e: except Exception as e:
logger.error(f"创建聊天流失败: {e}") logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e raise e
# 保存到内存和数据库 # 保存到内存和数据库
@@ -205,15 +230,38 @@ class ChatManager:
elif stream.user_info and stream.user_info.user_nickname: elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊" return f"{stream.user_info.user_nickname}的私聊"
else: else:
# 如果没有群名或用户昵称,返回 None 或其他默认值
return None return None
@staticmethod @staticmethod
async def _save_stream(stream: ChatStream): async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self): async def _save_all_streams(self):
"""保存所有聊天流""" """保存所有聊天流"""
@@ -222,10 +270,44 @@ class ChatManager:
async def load_all_streams(self): async def load_all_streams(self):
"""从数据库加载所有聊天流""" """从数据库加载所有聊天流"""
all_streams = db.chat_streams.find({})
for data in all_streams: def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
# 创建全局单例 # 创建全局单例

View File

@@ -12,7 +12,8 @@ import base64
from PIL import Image from PIL import Image
import io import io
import os import os
from ...common.database.database import db from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config from ...config.config import global_config
from rich.traceback import install from rich.traceback import install
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}" f"{image_base64[:10]}...{image_base64[-10:]}"
) )
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload return payload
@@ -134,13 +133,11 @@ class LLMRequest:
def _init_database(): def _init_database():
"""初始化数据库集合""" """初始化数据库集合"""
try: try:
# 创建llm_usage集合的索引 # 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.llm_usage.create_index([("timestamp", 1)]) db.create_tables([LLMUsage], safe=True)
db.llm_usage.create_index([("model_name", 1)]) logger.info("LLMUsage 表已初始化/确保存在。")
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
except Exception as e: except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}") logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage( def _record_usage(
self, self,
@@ -165,19 +162,19 @@ class LLMRequest:
request_type = self.request_type request_type = self.request_type
try: try:
usage_data = { # 使用 Peewee 模型创建记录
"model_name": self.model_name, LLMUsage.create(
"user_id": user_id, model_name=self.model_name,
"request_type": request_type, user_id=user_id,
"endpoint": endpoint, request_type=request_type,
"prompt_tokens": prompt_tokens, endpoint=endpoint,
"completion_tokens": completion_tokens, prompt_tokens=prompt_tokens,
"total_tokens": total_tokens, completion_tokens=completion_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens), total_tokens=total_tokens,
"status": "success", cost=self._calculate_cost(prompt_tokens, completion_tokens),
"timestamp": datetime.now(), status="success",
} timestamp=datetime.now(), # Peewee 会处理 DateTimeField
db.llm_usage.insert_one(usage_data) )
logger.trace( logger.trace(
f"Token使用情况 - 模型: {self.model_name}, " f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, " f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from ...common.database.database import db from ...common.database.database import db
from ...common.database.database_model import PersonInfo # 新增导入
import copy import copy
import hashlib import hashlib
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
import json import json # 新增导入
import re import re
@@ -43,17 +44,13 @@ person_info_default = {
"platform": None, "platform": None,
"user_id": None, "user_id": None,
"nickname": None, "nickname": None,
# "age" : 0,
"relationship_value": 0, "relationship_value": 0,
# "saved" : True,
# "impression" : None,
# "gender" : Unkown,
"konw_time": 0, "konw_time": 0,
"msg_interval": 2000, "msg_interval": 2000,
"msg_interval_list": [], "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
"user_cardname": None, # 添加群名片 "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
"user_avatar": None, # 添加头像信息例如URL或标识符 "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 }
class PersonInfoManager: class PersonInfoManager:
@@ -64,21 +61,26 @@ class PersonInfoManager:
max_tokens=256, max_tokens=256,
request_type="qv_name", request_type="qv_name",
) )
if "person_info" not in db.list_collection_names(): try:
db.create_collection("person_info") db.connect(reuse_if_open=True)
db.person_info.create_index("person_id", unique=True) db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name # 初始化时读取所有person_name
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) try:
for doc in cursor: for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
if doc.get("person_name"): PersonInfo.person_name.is_null(False)
self.person_name_list[doc["person_id"]] = doc["person_name"] ):
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod @staticmethod
def get_person_id(platform: str, user_id: int): def get_person_id(platform: str, user_id: int):
"""获取唯一id""" """获取唯一id"""
# 如果platform中存在-,就截取-后面的部分
if "-" in platform: if "-" in platform:
platform = platform.split("-")[1] platform = platform.split("-")[1]
@@ -86,13 +88,17 @@ class PersonInfoManager:
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
async def is_person_known(self, platform: str, user_id: int):
    """Report whether a PersonInfo row exists for the given platform user.

    The person id is derived with ``self.get_person_id`` and looked up via
    the Peewee ``PersonInfo`` model. The blocking query runs in a worker
    thread through ``asyncio.to_thread``.

    Args:
        platform: Platform identifier passed to ``get_person_id``.
        user_id: User id on that platform.

    Returns:
        True when a matching row exists; False when it does not or when the
        lookup raises (the error is logged).
    """
    person_id = self.get_person_id(platform, user_id)

    def _row_exists(p_id: str) -> bool:
        return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None

    try:
        return await asyncio.to_thread(_row_exists, person_id)
    except Exception as e:
        # Best effort: a failed lookup is reported as "not known".
        logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
        return False
@staticmethod @staticmethod
@@ -103,73 +109,111 @@ class PersonInfoManager:
return return
_person_info_default = copy.deepcopy(person_info_default) _person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id model_fields = PersonInfo._meta.fields.keys()
final_data = {"person_id": person_id}
if data: if data:
for key in _person_info_default: for key, value in data.items():
if key != "person_id" and key in data: if key in model_fields:
_person_info_default[key] = data[key] final_data[key] = value
db.person_info.insert_one(_person_info_default) for key, default_value in _person_info_default.items():
if key in model_fields and key not in final_data:
final_data[key] = default_value
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
final_data["msg_interval_list"] = json.dumps([])
def _db_create_sync(p_data: dict):
try:
PersonInfo.create(**p_data)
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
    """Update one column of a PersonInfo row, creating the row when missing.

    Args:
        person_id: Id of the row to update.
        field_name: Column to write. Names that are not fields of the
            ``PersonInfo`` model are skipped with a debug log.
        value: New value. A list written to ``msg_interval_list`` is stored
            as a JSON string.
        data: Optional extra fields used only when the row has to be
            created. The caller's dict is not modified.

    Returns:
        None.
    """
    if field_name not in PersonInfo._meta.fields:
        if field_name in person_info_default:
            logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
            return
        logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
        return

    def _update_existing(p_id: str, f_name: str, val) -> bool:
        """Write the value to an existing row; return False when no row exists."""
        record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
        if record is None:
            return False
        if f_name == "msg_interval_list" and isinstance(val, list):
            val = json.dumps(val)
        setattr(record, f_name, val)
        record.save()
        return True

    updated = await asyncio.to_thread(_update_existing, person_id, field_name, value)
    if updated:
        return

    logger.debug(f"更新时 {person_id} 不存在,将新建。")
    # Work on a copy so the caller's dict is left untouched.
    creation_data = dict(data) if data is not None else {}
    creation_data[field_name] = value
    if "platform" not in creation_data or "user_id" not in creation_data:
        logger.warning(f"{person_id} 创建记录时platform/user_id 可能缺失。")
    await self.create_person_info(person_id, creation_data)
@staticmethod @staticmethod
async def has_one_field(person_id: str, field_name: str):
    """Report whether ``field_name`` is available for ``person_id``.

    The check is: ``field_name`` must be a field of the ``PersonInfo``
    model, and a row for ``person_id`` must exist. The stored value itself
    is not inspected.

    Args:
        person_id: Id of the row to look up.
        field_name: Column name to check.

    Returns:
        True when the column is defined and the row exists; False otherwise,
        including when the lookup raises (the error is logged).
    """
    if field_name not in PersonInfo._meta.fields:
        logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
        return False

    def _row_exists(p_id: str) -> bool:
        return bool(PersonInfo.get_or_none(PersonInfo.person_id == p_id))

    try:
        return await asyncio.to_thread(_row_exists, person_id)
    except Exception as e:
        logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
        return False
@staticmethod @staticmethod
def _extract_json_from_text(text: str) -> dict: def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
try: try:
# 尝试直接解析
parsed_json = json.loads(text) parsed_json = json.loads(text)
# 如果解析结果是列表,尝试取第一个元素
if isinstance(parsed_json, list): if isinstance(parsed_json, list):
if parsed_json: # 检查列表是否为空 if parsed_json:
parsed_json = parsed_json[0] parsed_json = parsed_json[0]
else: # 如果列表为空,重置为 None走后续逻辑 else:
parsed_json = None parsed_json = None
# 确保解析结果是字典
if isinstance(parsed_json, dict): if isinstance(parsed_json, dict):
return parsed_json return parsed_json
except json.JSONDecodeError: except json.JSONDecodeError:
# 解析失败,继续尝试其他方法
pass pass
except Exception as e: except Exception as e:
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
pass # 继续尝试其他方法 pass
# 如果直接解析失败或结果不是字典
try: try:
# 尝试找到JSON对象格式的部分
json_pattern = r"\{[^{}]*\}" json_pattern = r"\{[^{}]*\}"
matches = re.findall(json_pattern, text) matches = re.findall(json_pattern, text)
if matches: if matches:
parsed_obj = json.loads(matches[0]) parsed_obj = json.loads(matches[0])
if isinstance(parsed_obj, dict): # 确保是字典 if isinstance(parsed_obj, dict):
return parsed_obj return parsed_obj
# 如果上面都失败了,尝试提取键值对
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
reason_pattern = r'"reason"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"'
@@ -184,7 +228,6 @@ class PersonInfoManager:
except Exception as e: except Exception as e:
logger.error(f"后备JSON提取失败: {str(e)}") logger.error(f"后备JSON提取失败: {str(e)}")
# 如果所有方法都失败了,返回默认字典
logger.warning(f"无法从文本中提取有效的JSON字典: {text}") logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
return {"nickname": "", "reason": ""} return {"nickname": "", "reason": ""}
@@ -199,9 +242,11 @@ class PersonInfoManager:
old_name = await self.get_value(person_id, "person_name") old_name = await self.get_value(person_id, "person_name")
old_reason = await self.get_value(person_id, "name_reason") old_reason = await self.get_value(person_id, "name_reason")
max_retries = 5 # 最大重试次数 max_retries = 5
current_try = 0 current_try = 0
existing_names = "" existing_names_str = ""
current_name_set = set(self.person_name_list.values())
while current_try < max_retries: while current_try < max_retries:
individuality = Individuality.get_instance() individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1) prompt_personality = individuality.get_prompt(x_person=2, level=1)
@@ -216,45 +261,55 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}" qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
qv_name_prompt += "\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
if existing_names_str:
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}\n"
if len(current_name_set) < 50 and current_name_set:
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
qv_name_prompt += (
"\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
)
if existing_names:
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}\n"
qv_name_prompt += "请用json给出你的想法并给出理由示例如下" qv_name_prompt += "请用json给出你的想法并给出理由示例如下"
qv_name_prompt += """{ qv_name_prompt += """{
"nickname": "昵称", "nickname": "昵称",
"reason": "理由" "reason": "理由"
}""" }"""
# logger.debug(f"取名提示词:{qv_name_prompt}")
response = await self.qv_name_llm.generate_response(qv_name_prompt) response = await self.qv_name_llm.generate_response(qv_name_prompt)
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response[0]) result = self._extract_json_from_text(response[0])
if not result["nickname"]: if not result or not result.get("nickname"):
logger.error("生成的昵称为空,重试中...") logger.error("生成的昵称为空或结果格式不正确,重试中...")
current_try += 1 current_try += 1
continue continue
# 检查生成的昵称是否已存在 generated_nickname = result["nickname"]
if result["nickname"] not in self.person_name_list.values():
# 更新数据库和内存中的列表
await self.update_one_field(person_id, "person_name", result["nickname"])
# await self.update_one_field(person_id, "nickname", user_nickname)
# await self.update_one_field(person_id, "avatar", user_avatar)
await self.update_one_field(person_id, "name_reason", result["reason"])
self.person_name_list[person_id] = result["nickname"] is_duplicate = False
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") if generated_nickname in current_name_set:
is_duplicate = True
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
self.person_name_list[person_id] = generated_nickname
return result return result
else: else:
existing_names += f"{result['nickname']}" if existing_names_str:
existing_names_str += ""
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") existing_names_str += generated_nickname
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
current_try += 1 current_try += 1
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称") logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
return None return None
@staticmethod @staticmethod
@@ -264,30 +319,56 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空") logger.debug("删除失败person_id 不能为空")
return return
result = db.person_info.delete_one({"person_id": person_id}) def _db_delete_sync(p_id: str):
if result.deleted_count > 0: try:
logger.debug(f"删除成功person_id={person_id}") query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
else: else:
logger.debug(f"删除失败:未找到 person_id={person_id}") logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod @staticmethod
async def get_value(person_id: str, field_name: str):
    """Return one field of a person's record, falling back to the default.

    Args:
        person_id: Id of the row to read. When empty, the configured default
            for ``field_name`` is returned.
        field_name: Field to read.

    Returns:
        The stored value (``msg_interval_list`` is decoded from its JSON
        string). When the row is missing or the stored value is None, a deep
        copy of the entry in ``person_info_default``. None when the field is
        defined neither in the model nor in the defaults.
    """
    if not person_id:
        logger.debug("get_value获取失败person_id不能为空")
        # Deep-copy so callers cannot mutate the shared default (e.g. the list).
        return copy.deepcopy(person_info_default.get(field_name))

    if field_name not in PersonInfo._meta.fields:
        if field_name in person_info_default:
            logger.trace(f"字段'{field_name}'不在Peewee模型中但存在于默认配置中。返回配置默认值。")
            return copy.deepcopy(person_info_default[field_name])
        logger.debug(f"get_value获取失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
        return None

    def _db_get_value_sync(p_id: str, f_name: str):
        record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
        if not record:
            return None
        val = getattr(record, f_name)
        if f_name == "msg_interval_list" and isinstance(val, str):
            try:
                return json.loads(val)
            except json.JSONDecodeError:
                logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
                return copy.deepcopy(person_info_default.get(f_name, []))
        return val

    value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
    if value is not None:
        return value

    default_value = copy.deepcopy(person_info_default.get(field_name))
    logger.trace(f"获取{person_id}{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
    return default_value
@staticmethod @staticmethod
@@ -297,93 +378,82 @@ class PersonInfoManager:
logger.debug("get_values获取失败person_id不能为空") logger.debug("get_values获取失败person_id不能为空")
return {} return {}
# 检查所有字段是否有效
for field in field_names:
if field not in person_info_default:
logger.debug(f"get_values获取失败字段'{field}'未定义")
return {}
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
document = db.person_info.find_one({"person_id": person_id}, projection)
result = {} result = {}
for field in field_names:
result[field] = copy.deepcopy( def _db_get_record_sync(p_id: str):
document.get(field, person_info_default[field]) if document else person_info_default[field] return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
)
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.trace(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if field_name == "msg_interval_list" and isinstance(value, str):
try:
result[field_name] = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
elif value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result return result
@staticmethod @staticmethod
async def del_all_undefined_field():
    """Log that removing undefined fields does not apply to the Peewee model.

    This method performs no database work; it only emits an info log.

    Returns:
        None.
    """
    logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。")
@staticmethod @staticmethod
async def get_specific_value_list(
    field_name: str,
    way: Callable[[Any], bool],
) -> Dict[str, Any]:
    """Collect the values of one field that satisfy a predicate.

    Args:
        field_name: Field of the ``PersonInfo`` model to scan.
        way: Predicate called with each value (``msg_interval_list`` is
            decoded from JSON first); truthy results are kept.

    Returns:
        ``{person_id: value}`` for every matching row. An empty dict when
        the field is not defined on the model or the worker thread fails.
        Rows whose JSON cannot be parsed, or for which ``way`` raises
        KeyError/TypeError/ValueError, are skipped.
    """
    if field_name not in PersonInfo._meta.fields:
        logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
        return {}

    def _db_get_specific_sync(f_name: str):
        found_results = {}
        try:
            for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
                value = getattr(record, f_name)
                if f_name == "msg_interval_list" and isinstance(value, str):
                    try:
                        processed_value = json.loads(value)
                    except json.JSONDecodeError:
                        logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
                        continue
                else:
                    processed_value = value

                # One bad record must not abort the whole scan.
                try:
                    if way(processed_value):
                        found_results[record.person_id] = processed_value
                except (KeyError, TypeError, ValueError) as e_record:
                    logger.debug(f"记录{record.person_id}处理失败: {str(e_record)}")
                    continue
        except Exception as e_query:
            logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
        return found_results

    try:
        return await asyncio.to_thread(_db_get_specific_sync, field_name)
    except Exception as e:
        logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
        return {}
async def personal_habit_deduction(self): async def personal_habit_deduction(self):
@@ -391,35 +461,31 @@ class PersonInfoManager:
try: try:
while 1: while 1:
await asyncio.sleep(600) await asyncio.sleep(600)
current_time = datetime.datetime.now() current_time_dt = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
# "msg_interval"推断 msg_interval_map_generated = False
msg_interval_map = False msg_interval_lists_map = await self.get_specific_value_list(
msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
) )
for person_id, msg_interval_list_ in msg_interval_lists.items():
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
try: try:
time_interval = [] time_interval = []
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
delta = t2 - t1 delta = t2 - t1
if delta > 0: if delta > 0:
time_interval.append(delta) time_interval.append(delta)
time_interval = [t for t in time_interval if 200 <= t <= 8000] time_interval = [t for t in time_interval if 200 <= t <= 8000]
# --- 修改后的逻辑 ---
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
time_interval.sort()
# 画图(log) - 这部分保留 if len(time_interval) >= 30 + 10:
msg_interval_map = True time_interval.sort()
msg_interval_map_generated = True
log_dir = Path("logs/person_info") log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
# 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval) time_series_original = pd.Series(time_interval)
plt.hist( plt.hist(
time_series_original, time_series_original,
@@ -441,34 +507,27 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path) plt.savefig(img_path)
plt.close() plt.close()
# 画图结束
# 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5] trimmed_interval = time_interval[5:-5]
if trimmed_interval:
# 计算截断后数据的 37% 分位数 msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
if trimmed_interval: # 确保截断后列表不为空 await self.update_one_field(person_id, "msg_interval", msg_interval_val)
msg_interval = int(round(np.percentile(trimmed_interval, 37))) logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}")
# 更新数据库
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
else: else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval") logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else: else:
logger.trace( logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
) )
# --- 修改结束 --- except Exception as e_inner:
except Exception as e: logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
continue continue
# 其他... if msg_interval_map_generated:
if msg_interval_map:
logger.trace("已保存分布图到: logs/person_info") logger.trace("已保存分布图到: logs/person_info")
current_time = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") current_time_dt_end = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400) await asyncio.sleep(86400)
except Exception as e: except Exception as e:
@@ -481,41 +540,27 @@ class PersonInfoManager:
""" """
根据 platform 和 user_id 获取 person_id。 根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。 如果对应的用户不存在,则使用提供的可选信息创建新用户。
Args:
platform: 平台标识
user_id: 用户在该平台上的ID
nickname: 用户的昵称 (可选,用于创建新用户)
user_cardname: 用户的群名片 (可选,用于创建新用户)
user_avatar: 用户的头像信息 (可选,用于创建新用户)
Returns:
对应的 person_id。
""" """
person_id = self.get_person_id(platform, user_id) person_id = self.get_person_id(platform, user_id)
# 检查用户是否已存在 def _db_check_exists_sync(p_id: str):
# 使用静态方法 get_person_id因此可以直接调用 db return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
document = db.person_info.find_one({"person_id": person_id})
if document is None: record = await asyncio.to_thread(_db_check_exists_sync, person_id)
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
if record is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
initial_data = { initial_data = {
"platform": platform, "platform": platform,
"user_id": user_id, "user_id": str(user_id),
"nickname": nickname, "nickname": nickname,
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 "konw_time": int(datetime.datetime.now().timestamp()),
# 注意:这里没有添加 user_cardname 和 user_avatar因为它们不在 person_info_default 中
# 如果需要存储它们,需要先在 person_info_default 中定义
} }
# 过滤掉值为 None 的初始数据 model_fields = PersonInfo._meta.fields.keys()
initial_data = {k: v for k, v in initial_data.items() if v is not None} filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
# 注意create_person_info 是静态方法 await self.create_person_info(person_id, data=filtered_initial_data)
await PersonInfoManager.create_person_info(person_id, data=initial_data) logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
return person_id return person_id
@@ -525,34 +570,48 @@ class PersonInfoManager:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空") logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None return None
# 优先从内存缓存查找 person_id
found_person_id = None found_person_id = None
for pid, name in self.person_name_list.items(): for pid, name_in_cache in self.person_name_list.items():
if name == person_name: if name_in_cache == person_name:
found_person_id = pid found_person_id = pid
break # 找到第一个匹配就停止 break
if not found_person_id: if not found_person_id:
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载) def _db_find_by_name_sync(p_name_to_find: str):
document = db.person_info.find_one({"person_name": person_name}) return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
if document:
found_person_id = document.get("person_id")
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
return None # 数据库也找不到
# 根据找到的 person_id 获取所需信息 record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if found_person_id: if record:
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"] found_person_id = record.person_id
person_data = await self.get_values(found_person_id, required_fields) if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name:
if person_data: # 确保 get_values 成功返回 self.person_name_list[found_person_id] = person_name
return person_data
else: else:
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None return None
if found_person_id:
required_fields = [
"person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else: else:
# 这理论上不应该发生,因为上面已经处理了找不到的情况 logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None return None

View File

@@ -1,9 +1,10 @@
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from common.database.database import db from src.common.database.database_model import Messages, ThinkingLog
import time import time
import traceback import traceback
from typing import List from typing import List
import json
class InfoCatcher: class InfoCatcher:
@@ -60,8 +61,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float):
    """Record the observation duration in ``self.timing_results``.

    Args:
        obs_duration: Duration stored under ``sub_heartflow_observe_time``.
    """
    self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1: if len(past_mind) > 1:
@@ -72,25 +71,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
# self.reasoning_data["thinking_log"] = reasoning_content
# self.reasoning_data["prompt"] = prompt
# self.reasoning_data["response"] = response
# self.reasoning_data["model"] = model_name
# 直接记录信息喵~
self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name self.reasoning_data["model"] = model_name
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
self.response_text = response self.response_text = response
@@ -102,6 +86,7 @@ class InfoCatcher:
): ):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
self.response_time = time.time() self.response_time = time.time()
self.response_messages = []
for msg in response_message: for msg in response_message:
self.response_messages.append(msg) self.response_messages.append(msg)
@@ -112,107 +97,110 @@ class InfoCatcher:
@staticmethod @staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
    """Fetch stored messages strictly between two messages of one chat.

    Args:
        message_start: Message whose time is the exclusive lower bound; its
            chat stream id selects the chat.
        message_end: Message whose time is the exclusive upper bound.

    Returns:
        A list of ``Messages`` rows ordered by time, newest first. An empty
        list when the query raises (the error and traceback are printed).
    """
    try:
        time_start = message_start.message_info.time
        time_end = message_end.message_info.time
        chat_id = message_start.chat_stream.stream_id

        print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")

        messages_between_query = (
            Messages.select()
            .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
            .order_by(Messages.time.desc())
        )

        result = list(messages_between_query)
        print(f"查询结果数量: {len(result)}")
        if result:
            print(f"第一条消息时间: {result[0].time}")
            print(f"最后一条消息时间: {result[-1].time}")
        return result
    except Exception as e:
        print(f"获取消息时出错: {str(e)}")
        print(traceback.format_exc())
        return []
def get_message_from_db_before_msg(self, message: MessageRecv):
    """Fetch stored messages of the same chat that precede ``message``.

    Rows are selected where ``message_id`` compares less than the given
    message's id, ordered by time (newest first) and capped at
    ``self.context_length * 3`` rows.

    Args:
        message: Message whose id and chat stream id bound the query.

    Returns:
        A list of ``Messages`` rows.
    """
    message_id_val = message.message_info.message_id
    chat_id_val = message.chat_stream.stream_id

    messages_before_query = (
        Messages.select()
        .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
        .order_by(Messages.time.desc())
        .limit(self.context_length * 3)
    )

    return list(messages_before_query)
def message_list_to_dict(self, message_list):
    """Reduce a list of messages to lightweight dicts.

    Items that are not dicts are first converted with
    ``self.message_to_dict``; items that end up empty or None are skipped.

    Args:
        message_list: Iterable of message dicts or message objects.

    Returns:
        A list of dicts with the keys ``time``, ``user_nickname`` and
        ``processed_plain_text`` (missing values become None).
    """
    result = []
    for msg_item in message_list:
        if isinstance(msg_item, dict):
            processed_msg_item = msg_item
        else:
            processed_msg_item = self.message_to_dict(msg_item)

        if not processed_msg_item:
            continue

        user_nickname = processed_msg_item.get("user_nickname")
        if user_nickname is None:
            # Dicts in the nested shape keep the nickname under "user_info".
            user_info = processed_msg_item.get("user_info")
            if isinstance(user_info, dict):
                user_nickname = user_info.get("user_nickname")

        result.append(
            {
                "time": processed_msg_item.get("time"),
                "user_nickname": user_nickname,
                "processed_plain_text": processed_msg_item.get("processed_plain_text"),
            }
        )
    return result
@staticmethod @staticmethod
def message_to_dict(message): def message_to_dict(msg_obj):
if not message: if not msg_obj:
return None return None
if isinstance(message, dict): if isinstance(msg_obj, dict):
return message return msg_obj
if isinstance(msg_obj, Messages):
return { return {
# "message_id": message.message_info.message_id, "time": msg_obj.time,
"time": message.message_info.time, "user_id": msg_obj.user_id,
"user_id": message.message_info.user_info.user_id, "user_nickname": msg_obj.user_nickname,
"user_nickname": message.message_info.user_info.user_nickname, "processed_plain_text": msg_obj.processed_plain_text,
"processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text
} }
if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'):
return {
"time": msg_obj.message_info.time,
"user_id": msg_obj.message_info.user_info.user_id,
"user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
return {}
def done_catch(self): def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~""" """将收集到的信息存储到数据库的 thinking_log 中喵~"""
try: try:
# 将消息对象转换为可序列化的字典喵~ trigger_info_dict = self.message_to_dict(self.trigger_response_message)
response_info_dict = {
thinking_log_data = {
"chat_id": self.chat_id,
"trigger_text": self.trigger_response_text,
"response_text": self.response_text,
"trigger_info": {
"time": self.trigger_response_time,
"message": self.message_to_dict(self.trigger_response_message),
},
"response_info": {
"time": self.response_time, "time": self.response_time,
"message": self.response_messages, "message": self.response_messages,
},
"timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
"heartflow_data": self.heartflow_data,
"reasoning_data": self.reasoning_data,
} }
chat_history_list = self.message_list_to_dict(self.chat_history)
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ log_entry = ThinkingLog(
# if self.response_mode == "heart_flow": chat_id=self.chat_id,
# thinking_log_data["mode_specific_data"] = self.heartflow_data trigger_text=self.trigger_response_text,
# elif self.response_mode == "reasoning": response_text=self.response_text,
# thinking_log_data["mode_specific_data"] = self.reasoning_data trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
response_info_json=json.dumps(response_info_dict),
# 将数据插入到 thinking_log 集合中喵~ timing_results_json=json.dumps(self.timing_results),
db.thinking_log.insert_one(thinking_log_data) chat_history_json=json.dumps(chat_history_list),
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data)
)
log_entry.save()
return True return True
except Exception as e: except Exception as e:

View File

@@ -5,7 +5,8 @@ from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask from src.manager.async_task_manager import AsyncTask
from ...common.database.database import db from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime # Import the Peewee model
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic") logger = get_module_logger("maibot_statistic")
@@ -39,7 +40,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self): def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60) super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: str | None = None self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID""" """记录ID"""
self._init_database() # 初始化数据库 self._init_database() # 初始化数据库
@@ -47,53 +48,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod @staticmethod
def _init_database(): def _init_database():
"""初始化数据库""" """初始化数据库"""
if "online_time" not in db.list_collection_names(): with db.atomic(): # Use atomic operations for schema changes
# 初始化数据库(在线时长) OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
db.create_collection("online_time")
# 创建索引
if ("end_timestamp", 1) not in db.online_time.list_indexes():
db.online_time.create_index([("end_timestamp", 1)])
async def run(self): async def run(self):
try: try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id: if self.record_id:
# 如果有记录,则更新结束时间 # 如果有记录,则更新结束时间
db.online_time.update_one( query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
{"_id": self.record_id}, updated_rows = query.execute()
{ if updated_rows == 0:
"$set": { # Record might have been deleted or ID is stale, try to find/create
"end_timestamp": datetime.now() + timedelta(minutes=1), self.record_id = None # Reset record_id to trigger find/create logic below
}
}, if not self.record_id: # Check again if record_id was reset or initially None
)
else:
# 如果没有记录,检查一分钟以内是否已有记录 # 如果没有记录,检查一分钟以内是否已有记录
current_time = datetime.now() # Look for a record whose end_timestamp is recent enough to be considered ongoing
if recent_record := db.online_time.find_one( recent_record = OnlineTime.select().where(
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))
): ).order_by(OnlineTime.end_timestamp.desc()).first()
if recent_record:
# 如果有记录,则更新结束时间 # 如果有记录,则更新结束时间
self.record_id = recent_record["_id"] self.record_id = recent_record.id
db.online_time.update_one( recent_record.end_timestamp = extended_end_time
{"_id": self.record_id}, recent_record.save()
{
"$set": {
"end_timestamp": current_time + timedelta(minutes=1),
}
},
)
else: else:
# 若没有记录,则插入新的在线时间记录 # 若没有记录,则插入新的在线时间记录
self.record_id = db.online_time.insert_one( new_record = OnlineTime.create(
{ start_timestamp=current_time,
"start_timestamp": current_time, end_timestamp=extended_end_time,
"end_timestamp": current_time + timedelta(minutes=1), )
} self.record_id = new_record.id
).inserted_id
except Exception as e: except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}") logger.error(f"在线时间记录失败,错误信息:{e}")
def _format_online_time(online_seconds: int) -> str: def _format_online_time(online_seconds: int) -> str:
""" """
格式化在线时间 格式化在线时间

View File

@@ -9,6 +9,7 @@ import numpy as np
from ...common.database.database import db from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config from ...config.config import global_config
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
@@ -32,40 +33,21 @@ class ImageManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image") self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True
def _ensure_image_dir(self): def _ensure_image_dir(self):
"""确保图像存储目录存在""" """确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True) os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 删除旧索引
db.image_descriptions.drop_indexes()
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
@staticmethod @staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
@@ -77,8 +59,15 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) try:
return result["description"] if result else None record = ImageDescriptions.get_or_none(
(ImageDescriptions.hash == image_hash) &
(ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@staticmethod @staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +79,22 @@ class ImageManager:
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
try: try:
db.image_descriptions.update_one( current_timestamp = time.time()
{"hash": image_hash, "type": description_type}, defaults = {
{ 'description': description,
"$set": { 'timestamp': current_timestamp
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
} }
}, desc_obj, created = ImageDescriptions.get_or_create(
upsert=True, hash=image_hash,
type=description_type,
defaults=defaults
) )
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
except Exception as e: except Exception as e:
logger.error(f"保存描述到数据库失败: {str(e)}") logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str: async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能""" """获取表情包描述,带查重和保存功能"""
@@ -116,18 +107,25 @@ class ImageManager:
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
# logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]" return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
if image_format == "gif" or image_format == "GIF": if image_format == "gif" or image_format == "GIF":
image_base64 = self.transform_gif(image_base64) image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些" prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else: else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
@@ -136,31 +134,37 @@ class ImageManager:
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.save_emoji: if global_config.save_emoji:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) current_timestamp = time.time()
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) file_path = os.path.join(emoji_dir, filename)
try: try:
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库 (Images表)
image_doc = { try:
"hash": image_hash, img_obj = Images.get((Images.hash == image_hash) & (Images.type == "emoji"))
"path": file_path, img_obj.path = file_path
"type": "emoji", img_obj.description = description
"description": description, img_obj.timestamp = current_timestamp
"timestamp": timestamp, img_obj.save()
} except Images.DoesNotExist:
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) Images.create(
logger.trace(f"保存表情包: {file_path}") hash=image_hash,
path=file_path,
type="emoji",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}") logger.error(f"保存表情包文件或元数据失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "emoji") self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]" return f"[表情包:{description}]"
@@ -188,6 +192,11 @@ class ImageManager:
) )
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 再次检查缓存
cached_description = self._get_description_from_db(image_hash, "image") cached_description = self._get_description_from_db(image_hash, "image")
if cached_description: if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +204,40 @@ class ImageManager:
logger.debug(f"描述是{description}") logger.debug(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.save_pic: if global_config.save_pic:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) current_timestamp = time.time()
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(os.path.join(self.IMAGE_DIR, "image")) os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(self.IMAGE_DIR, "image", filename) file_path = os.path.join(image_dir, filename)
try: try:
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库 (Images表)
image_doc = { try:
"hash": image_hash, img_obj = Images.get((Images.hash == image_hash) & (Images.type == "image"))
"path": file_path, img_obj.path = file_path
"type": "image", img_obj.description = description
"description": description, img_obj.timestamp = current_timestamp
"timestamp": timestamp, img_obj.save()
} except Images.DoesNotExist:
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) Images.create(
logger.trace(f"保存图片: {file_path}") hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存图片元数据: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}") logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "image") self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]" return f"[图片:{description}]"

View File

@@ -1,5 +1,6 @@
import os import os
from pymongo import MongoClient from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database from pymongo.database import Database
from rich.traceback import install from rich.traceback import install
@@ -57,4 +58,15 @@ class DBWrapper:
# 全局数据库访问点 # 全局数据库访问点
db: Database = DBWrapper() memory_db: Database = DBWrapper()
# Define the database file path: ROOT_PATH is three directory levels above
# this module's directory, and the database file lives in its "data" folder.
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# Make sure the database directory exists (no error if it is already there).
os.makedirs(_DB_DIR, exist_ok=True)
# Global Peewee SQLite database access point.
db = SqliteDatabase(_DB_FILE)

View File

@@ -1,9 +1,10 @@
from peewee import Model, DoubleField, IntegerField, SqliteDatabase, BooleanField, TextField, FloatField from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
from .database import db
import datetime
# 请在此处定义您的数据库实例。 # 请在此处定义您的数据库实例。
# 您需要取消注释并配置适合您的数据库的部分。 # 您需要取消注释并配置适合您的数据库的部分。
# 例如,对于 SQLite: # 例如,对于 SQLite:
db = SqliteDatabase('my_application.db') # db = SqliteDatabase('MaiBot.db')
# #
# 对于 PostgreSQL: # 对于 PostgreSQL:
# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password', # db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
@@ -69,17 +70,16 @@ class LLMUsage(BaseModel):
""" """
用于存储 API 使用日志数据的模型。 用于存储 API 使用日志数据的模型。
""" """
model_name = TextField() model_name = TextField(index=True) # 添加索引
user_id = TextField() user_id = TextField(index=True) # 添加索引
request_type = TextField() request_type = TextField(index=True) # 添加索引
endpoint = TextField() endpoint = TextField()
prompt_tokens = IntegerField() prompt_tokens = IntegerField()
completion_tokens = IntegerField() completion_tokens = IntegerField()
total_tokens = IntegerField() total_tokens = IntegerField()
cost = DoubleField() cost = DoubleField()
status = TextField() status = TextField()
# timestamp: "$date": "2025-05-01T18:52:50.870Z" (存储为字符串) timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
timestamp = TextField()
class Meta: class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
@@ -177,6 +177,8 @@ class OnlineTime(BaseModel):
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
timestamp = TextField() timestamp = TextField()
duration = IntegerField() # 时长,单位分钟 duration = IntegerField() # 时长,单位分钟
start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True)
class Meta: class Meta:
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
@@ -202,3 +204,39 @@ class PersonInfo(BaseModel):
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = 'person_info' table_name = 'person_info'
class Knowledges(BaseModel):
    """
    Model for storing knowledge-base entries.
    """
    content = TextField()  # Text of the knowledge entry
    embedding = TextField()  # Embedding vector of the content, stored as a JSON string holding a list of floats
    # Other metadata fields such as source, create_time, etc. can be added here

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = 'knowledges'
class ThinkingLog(BaseModel):
    """
    Model for one thinking-log record, stored in the 'thinking_logs' table.

    Apart from chat_id and created_at, every column is a nullable text column.
    The *_json columns are meant to hold complex dicts/lists serialized as
    JSON strings rather than structured values.
    """
    chat_id = TextField(index=True)  # indexed text column
    trigger_text = TextField(null=True)
    response_text = TextField(null=True)

    # Store complex dicts/lists as JSON strings
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)

    # Timestamp for the log entry itself. The default is the callable
    # datetime.datetime.now (naive local time), not a fixed value.
    created_at = DateTimeField(default=datetime.datetime.now)

    class Meta:
        table_name = 'thinking_logs'

View File

@@ -1,11 +1,19 @@
from common.database.database import db from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
import traceback import traceback
from typing import List, Any, Optional from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
    """
    Convert a Peewee model instance into a plain dictionary.

    Args:
        model_instance: The model instance whose field data should be exported.

    Returns:
        A new dict built from the instance's ``__data__`` mapping. The dict is
        a shallow copy, so adding, removing or reassigning keys in the result
        does not affect the model instance.
    """
    # __data__ is the instance's own mapping; returning it directly would let
    # any caller that edits the result silently change the model as well.
    return dict(model_instance.__data__)
def find_messages( def find_messages(
message_filter: dict[str, Any], message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None, sort: Optional[List[tuple[str, int]]] = None,
@@ -16,39 +24,72 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。 根据提供的过滤器、排序和限制条件查找消息。
Args: Args:
message_filter: MongoDB 查询过滤器。 message_filter: 查询过滤器字典,键为模型字段名,值为期望值
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。 limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest' limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
Returns: Returns:
消息文档列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
query = db.messages.find(message_filter) query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
conditions.append(getattr(Messages, key) == value)
else:
logger.warning(
f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
)
if conditions:
# 使用 *conditions 将所有条件以 AND 连接
query = query.where(*conditions)
if limit > 0: if limit > 0:
if limit_mode == "earliest": if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.sort([("time", 1)]).limit(limit) query = query.order_by(Messages.time.asc()).limit(limit)
results = list(query) peewee_results = list(query)
else: # 默认为 'latest' else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录 # 获取时间最晚的 limit 条记录
query = query.sort([("time", -1)]).limit(limit) query = query.order_by(Messages.time.desc()).limit(limit)
latest_results = list(query) latest_results_peewee = list(query)
# 将结果按时间正序排列 # 将结果按时间正序排列
# 假设消息文档中总是有 'time' 字段且可排序 peewee_results = sorted(
results = sorted(latest_results, key=lambda msg: msg.get("time")) latest_results_peewee, key=lambda msg: msg.time
)
else: else:
# limit 为 0 时,应用传入的 sort 参数 # limit 为 0 时,应用传入的 sort 参数
if sort: if sort:
query = query.sort(sort) peewee_sort_terms = []
results = list(query) for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(
f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。"
)
else:
logger.warning(
f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。"
)
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
results = [_model_to_dict(msg) for msg in peewee_results]
return results return results
except Exception as e: except Exception as e:
log_message = ( log_message = (
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc() + traceback.format_exc()
) )
logger.error(log_message) logger.error(log_message)
@@ -60,18 +101,35 @@ def count_messages(message_filter: dict[str, Any]) -> int:
根据提供的过滤器计算消息数量。 根据提供的过滤器计算消息数量。
Args: Args:
message_filter: MongoDB 查询过滤器。 message_filter: 查询过滤器字典,键为模型字段名,值为期望值
Returns: Returns:
符合条件的消息数量,如果出错则返回 0。 符合条件的消息数量,如果出错则返回 0。
""" """
try: try:
count = db.messages.count_documents(message_filter) query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
conditions.append(getattr(Messages, key) == value)
else:
logger.warning(
f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
)
if conditions:
query = query.where(*conditions)
count = query.count()
return count return count
except Exception as e: except Exception as e:
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
logger.error(log_message) logger.error(log_message)
return 0 return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。

View File

@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
create_new_message_notification, create_new_message_notification,
create_cold_chat_notification, create_cold_chat_notification,
) )
from src.experimental.PFC.message_storage import MongoDBMessageStorage from src.experimental.PFC.message_storage import PeeweeMessageStorage
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -53,7 +53,7 @@ class ChatObserver:
self.stream_id = stream_id self.stream_id = stream_id
self.private_name = private_name self.private_name = private_name
self.message_storage = MongoDBMessageStorage() self.message_storage = PeeweeMessageStorage()
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间

View File

@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any from typing import List, Dict, Any
from common.database.database import db # from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
class MessageStorage(ABC): class MessageStorage(ABC):
@@ -47,28 +49,35 @@ class MessageStorage(ABC):
pass pass
class MongoDBMessageStorage(MessageStorage): class PeeweeMessageStorage(MessageStorage):
"""MongoDB消息存储实现""" """Peewee消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$gt": message_time}} query = Messages.select().where(
# print(f"storage_check_message: {message_time}") (Messages.chat_id == chat_id) &
(Messages.time > message_time)
).order_by(Messages.time.asc())
return list(db.messages.find(query).sort("time", 1)) # print(f"storage_check_message: {message_time}")
messages_models = list(query)
return [model_to_dict(msg) for msg in messages_models]
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$lt": time_point}} query = Messages.select().where(
(Messages.chat_id == chat_id) &
messages = list(db.messages.find(query).sort("time", -1).limit(limit)) (Messages.time < time_point)
).order_by(Messages.time.desc()).limit(limit)
messages_models = list(query)
# 将消息按时间正序排列 # 将消息按时间正序排列
messages.reverse() messages_models.reverse()
return messages return [model_to_dict(msg) for msg in messages_models]
async def has_new_messages(self, chat_id: str, after_time: float) -> bool: async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {"chat_id": chat_id, "time": {"$gt": after_time}} return Messages.select().where(
(Messages.chat_id == chat_id) &
return db.messages.find_one(query) is not None (Messages.time > after_time)
).exists()
# # 创建一个内存消息存储实现,用于测试 # # 创建一个内存消息存储实现,用于测试

View File

@@ -1,8 +1,10 @@
from src.tools.tool_can_use.base_tool import BaseTool from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from common.database.database import db from src.common.database.database_model import Knowledges # Updated import
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from typing import Any, Union from typing import Any, Union, List # Added List
import json # Added for parsing embedding
import math # Added for cosine similarity
logger = get_logger("get_knowledge_tool") logger = get_logger("get_knowledge_tool")
@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
Returns: Returns:
dict: 工具执行结果 dict: 工具执行结果
""" """
query = "" # Initialize query to ensure it's defined in except block
try: try:
query = function_args.get("query") query = function_args.get("query")
threshold = function_args.get("threshold", 0.4) threshold = function_args.get("threshold", 0.4)
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
logger.error(f"知识库搜索工具执行失败: {str(e)}") logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Compute the cosine similarity between two vectors.

    Args:
        vec1: First vector.
        vec2: Second vector.

    Returns:
        The dot product of the vectors divided by the product of their
        magnitudes. Returns 0.0 when either vector has zero magnitude
        (which includes empty vectors) or when the vectors differ in length.
    """
    # zip() stops at the shorter vector while the magnitudes below still use
    # every element, so a length mismatch would produce a meaningless score.
    # Treat vectors of different dimensionality as not similar instead.
    if len(vec1) != len(vec2):
        return 0.0
    dot_product = sum(p * q for p, q in zip(vec1, vec2))
    magnitude1 = math.sqrt(sum(p * p for p in vec1))
    magnitude2 = math.sqrt(sum(q * q for q in vec2))
    # Guard against division by zero for all-zero (or empty) vectors.
    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0
    return dot_product / (magnitude1 * magnitude2)
@staticmethod @staticmethod
def get_info_from_db( def get_info_from_db(
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
"""从数据库中获取相关信息 """从数据库中获取相关信息
@@ -66,66 +79,49 @@ class SearchKnowledgeTool(BaseTool):
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 similar_items = []
pipeline = [ try:
{ all_knowledges = Knowledges.select()
"$addFields": { for item in all_knowledges:
"dotProduct": { try:
"$reduce": { item_embedding_str = item.embedding
"input": {"$range": [0, {"$size": "$embedding"}]}, if not item_embedding_str:
"initialValue": 0, logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
"in": { continue
"$add": [ item_embedding = json.loads(item_embedding_str)
"$$value", if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding):
{ logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
"$multiply": [ continue
{"$arrayElemAt": ["$embedding", "$$this"]}, except json.JSONDecodeError:
{"$arrayElemAt": [query_embedding, "$$this"]}, logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
] continue
}, except AttributeError:
] logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
}, continue
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline)) similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
logger.debug(f"知识库查询结果数量: {len(results)}")
if similarity >= threshold:
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
# 按相似度降序排序
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
# 应用限制
results = similar_items[:limit]
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
except Exception as e:
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
return "" if not return_raw else []
if not results: if not results:
return "" if not return_raw else [] return "" if not return_raw else []
if return_raw: if return_raw:
return results # Peewee 模型实例不能直接序列化为 JSON如果需要原始模型调用者需要处理
# 这里返回包含内容和相似度的字典列表
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
else: else:
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results) return "\n".join(str(result["content"]) for result in results)