重构数据库交互以使用 Peewee ORM
- 更新数据库连接和模型定义,以便使用 Peewee for SQLite。 - 在消息存储和检索功能中,用 Peewee ORM 查询替换 MongoDB 查询。 - 为 Messages、ThinkingLog 和 OnlineTime 引入了新的模型,以方便结构化数据存储。 - 增强了数据库操作的错误处理和日志记录。 - 删除了过时的 MongoDB 集合管理代码。 - 通过利用 Peewee 内置的查询和数据操作方法来提升性能。
This commit is contained in:
@@ -10,7 +10,10 @@ from PIL import Image
|
||||
import io
|
||||
import re
|
||||
|
||||
from ...common.database.database import db
|
||||
# from gradio_client import file
|
||||
|
||||
from ...common.database.database_model import Emoji
|
||||
from ...common.database.database import db as peewee_db
|
||||
from ...config.config import global_config
|
||||
from ..utils.utils_image import image_path_to_base64, image_manager
|
||||
from ..models.utils_model import LLMRequest
|
||||
@@ -143,22 +146,21 @@ class MaiEmoji:
|
||||
# --- 数据库操作 ---
|
||||
try:
|
||||
# 准备数据库记录 for emoji collection
|
||||
emoji_record = {
|
||||
"filename": self.filename,
|
||||
"path": self.path, # 存储目录路径
|
||||
"full_path": self.full_path, # 存储完整文件路径
|
||||
"embedding": self.embedding,
|
||||
"description": self.description,
|
||||
"emotion": self.emotion,
|
||||
"hash": self.hash,
|
||||
"format": self.format,
|
||||
"timestamp": int(self.register_time),
|
||||
"usage_count": self.usage_count,
|
||||
"last_used_time": self.last_used_time,
|
||||
}
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
# 使用upsert确保记录存在或被更新
|
||||
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
|
||||
Emoji.create(hash=self.hash,
|
||||
full_path=self.full_path,
|
||||
format=self.format,
|
||||
description=self.description,
|
||||
emotion=emotion_str, # Store as comma-separated string
|
||||
query_count=0, # Default value
|
||||
is_registered=True,
|
||||
is_banned=False, # Default value
|
||||
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
|
||||
register_time=self.register_time,
|
||||
usage_count=self.usage_count,
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
|
||||
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
@@ -166,14 +168,6 @@ class MaiEmoji:
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
|
||||
# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
|
||||
# 可以考虑在这里尝试删除已移动的文件,避免残留
|
||||
try:
|
||||
if os.path.exists(self.full_path): # full_path 此时是目标路径
|
||||
os.remove(self.full_path)
|
||||
logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
|
||||
except Exception as remove_error:
|
||||
logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
@@ -201,10 +195,14 @@ class MaiEmoji:
|
||||
# 文件删除失败,但仍然尝试删除数据库记录
|
||||
|
||||
# 2. 删除数据库记录
|
||||
result = db.emoji.delete_one({"hash": self.hash})
|
||||
deleted_in_db = result.deleted_count > 0
|
||||
try:
|
||||
will_delete_emoji = Emoji.get(Emoji.hash == self.hash)
|
||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||
except Emoji.DoesNotExist:
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
result = 0 # Indicate no DB record was deleted
|
||||
|
||||
if deleted_in_db:
|
||||
if result > 0:
|
||||
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
|
||||
# 3. 标记对象已被删除
|
||||
self.is_deleted = True
|
||||
@@ -246,44 +244,43 @@ def _emoji_objects_to_readable_list(emoji_objects):
|
||||
def _to_emoji_objects(data):
|
||||
emoji_objects = []
|
||||
load_errors = 0
|
||||
# data is now an iterable of Peewee Emoji model instances
|
||||
emoji_data_list = list(data)
|
||||
|
||||
for emoji_data in emoji_data_list:
|
||||
full_path = emoji_data.get("full_path")
|
||||
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
|
||||
full_path = emoji_data.full_path
|
||||
if not full_path:
|
||||
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}")
|
||||
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}")
|
||||
load_errors += 1
|
||||
continue # 跳过缺少 full_path 的记录
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用 full_path 初始化 MaiEmoji 对象
|
||||
emoji = MaiEmoji(full_path=full_path)
|
||||
|
||||
# 设置从数据库加载的属性
|
||||
emoji.hash = emoji_data.get("hash", "")
|
||||
# 如果 hash 为空,也跳过?取决于业务逻辑
|
||||
emoji.hash = emoji_data.hash
|
||||
if not emoji.hash:
|
||||
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
|
||||
load_errors += 1
|
||||
continue
|
||||
|
||||
emoji.description = emoji_data.get("description", "")
|
||||
emoji.emotion = emoji_data.get("emotion", [])
|
||||
emoji.usage_count = emoji_data.get("usage_count", 0)
|
||||
# 优先使用 last_used_time,否则用 timestamp,最后用当前时间
|
||||
last_used = emoji_data.get("last_used_time")
|
||||
timestamp = emoji_data.get("timestamp")
|
||||
emoji.last_used_time = (
|
||||
last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
|
||||
)
|
||||
emoji.register_time = timestamp if timestamp is not None else time.time()
|
||||
emoji.format = emoji_data.get("format", "") # 加载格式
|
||||
emoji.description = emoji_data.description
|
||||
# Deserialize emotion string from DB to list
|
||||
emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else []
|
||||
emoji.usage_count = emoji_data.usage_count
|
||||
|
||||
# 不需要再手动设置 path 和 filename,__init__ 会自动处理
|
||||
db_last_used_time = emoji_data.last_used_time
|
||||
db_register_time = emoji_data.register_time
|
||||
|
||||
# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
|
||||
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
|
||||
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
|
||||
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
|
||||
|
||||
emoji.format = emoji_data.format
|
||||
|
||||
emoji_objects.append(emoji)
|
||||
|
||||
except ValueError as ve: # 捕获 __init__ 可能的错误
|
||||
except ValueError as ve:
|
||||
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
|
||||
load_errors += 1
|
||||
except Exception as e:
|
||||
@@ -385,12 +382,13 @@ class EmojiManager:
|
||||
"""初始化数据库连接和表情目录"""
|
||||
if not self._initialized:
|
||||
try:
|
||||
self._ensure_emoji_collection()
|
||||
# Ensure Peewee database connection is up and tables are created
|
||||
if not peewee_db.is_closed():
|
||||
peewee_db.connect(reuse_if_open=True)
|
||||
Emoji.create_table(safe=True) # Ensures table exists
|
||||
|
||||
_ensure_emoji_dir()
|
||||
self._initialized = True
|
||||
# 更新表情包数量
|
||||
# 启动时执行一次完整性检查
|
||||
# await self.check_emoji_file_integrity()
|
||||
except Exception as e:
|
||||
logger.exception(f"初始化表情管理器失败: {e}")
|
||||
|
||||
@@ -401,33 +399,15 @@ class EmojiManager:
|
||||
if not self._initialized:
|
||||
raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_emoji_collection():
|
||||
"""确保emoji集合存在并创建索引
|
||||
|
||||
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
|
||||
|
||||
索引的作用是加快数据库查询速度:
|
||||
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
|
||||
- tags字段的普通索引: 加快按标签搜索表情包的速度
|
||||
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
|
||||
|
||||
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
|
||||
"""
|
||||
if "emoji" not in db.list_collection_names():
|
||||
db.create_collection("emoji")
|
||||
db.emoji.create_index([("embedding", "2dsphere")])
|
||||
db.emoji.create_index([("filename", 1)], unique=True)
|
||||
|
||||
def record_usage(self, emoji_hash: str):
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}})
|
||||
for emoji in self.emoji_objects:
|
||||
if emoji.hash == emoji_hash:
|
||||
emoji.usage_count += 1
|
||||
break
|
||||
|
||||
emoji_update = Emoji.get(Emoji.hash == emoji_hash)
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
emoji_update.save() # Persist changes to DB
|
||||
except Emoji.DoesNotExist:
|
||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
@@ -657,9 +637,10 @@ class EmojiManager:
|
||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
logger.info("[数据库] 开始加载所有表情包记录...")
|
||||
logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
|
||||
|
||||
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find())
|
||||
emoji_peewee_instances = Emoji.select()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
|
||||
# 更新内存中的列表和数量
|
||||
self.emoji_objects = emoji_objects
|
||||
@@ -686,15 +667,16 @@ class EmojiManager:
|
||||
try:
|
||||
self._ensure_db()
|
||||
|
||||
query = {}
|
||||
if emoji_hash:
|
||||
query = {"hash": emoji_hash}
|
||||
query = Emoji.select().where(Emoji.hash == emoji_hash)
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
query = Emoji.select()
|
||||
|
||||
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query))
|
||||
emoji_peewee_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
|
||||
@@ -908,6 +890,44 @@ class EmojiManager:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
return "", []
|
||||
|
||||
# async def register_emoji_by_filename(self, filename: str) -> bool:
|
||||
# if global_config.EMOJI_CHECK:
|
||||
# prompt = f'''
|
||||
# 这是一个表情包,请对这个表情包进行审核,标准如下:
|
||||
# 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
|
||||
# 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
|
||||
# 3. 不能是任何形式的截图,聊天记录或视频截图
|
||||
# 4. 不要出现5个以上文字
|
||||
# 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
|
||||
# '''
|
||||
# content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
# if content == "否":
|
||||
# return "", []
|
||||
|
||||
# # 分析情感含义
|
||||
# emotion_prompt = f"""
|
||||
# 请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
||||
# 这是一个基于这个表情包的描述:'{description}'
|
||||
# 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
|
||||
# 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
|
||||
# """
|
||||
# emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
|
||||
|
||||
# # 处理情感列表
|
||||
# emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||
|
||||
# # 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个
|
||||
# if len(emotions) > 5:
|
||||
# emotions = random.sample(emotions, 3)
|
||||
# elif len(emotions) > 2:
|
||||
# emotions = random.sample(emotions, 2)
|
||||
|
||||
# return f"[表情包:{description}]", emotions
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
# return "", []
|
||||
|
||||
async def register_emoji_by_filename(self, filename: str) -> bool:
|
||||
"""读取指定文件名的表情包图片,分析并注册到数据库
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.chat.person_info.relationship_manager import relationship_manager
|
||||
from src.chat.utils.utils import get_embedding
|
||||
import time
|
||||
from typing import Union, Optional
|
||||
from common.database.database import db
|
||||
# from common.database.database import db
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||
@@ -15,6 +15,9 @@ from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
# import traceback
|
||||
import random
|
||||
import json
|
||||
import math
|
||||
from src.common.database.database_model import Knowledges
|
||||
|
||||
|
||||
logger = get_logger("prompt")
|
||||
@@ -69,7 +72,7 @@ def init_prompt():
|
||||
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1},
|
||||
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger}
|
||||
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
|
||||
"reasoning_prompt_main",
|
||||
@@ -439,30 +442,6 @@ class PromptBuilder:
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||
topics = []
|
||||
# try:
|
||||
# # 先尝试使用记忆系统的方法获取主题
|
||||
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
||||
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
||||
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
||||
|
||||
# # 提取关键词
|
||||
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||
# if not topics:
|
||||
# topics = []
|
||||
# else:
|
||||
# topics = [
|
||||
# topic.strip()
|
||||
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
# if topic.strip()
|
||||
# ]
|
||||
|
||||
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
||||
# # 如果LLM提取失败,使用jieba分词提取关键词作为备选
|
||||
# words = jieba.cut(message)
|
||||
# topics = [word for word in words if len(word) > 1][:5]
|
||||
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
||||
|
||||
# 如果无法提取到主题,直接使用整个消息
|
||||
if not topics:
|
||||
@@ -572,8 +551,6 @@ class PromptBuilder:
|
||||
for _i, result in enumerate(results, 1):
|
||||
_similarity = result["similarity"]
|
||||
content = result["content"].strip()
|
||||
# 调试:为内容添加序号和相似度信息
|
||||
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
||||
related_info += f"{content}\n"
|
||||
related_info += "\n"
|
||||
|
||||
@@ -602,14 +579,14 @@ class PromptBuilder:
|
||||
return related_info
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索")
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
try:
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(
|
||||
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
|
||||
@@ -625,70 +602,70 @@ class PromptBuilder:
|
||||
) -> Union[str, list]:
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
"$addFields": {
|
||||
"dotProduct": {
|
||||
"$reduce": {
|
||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
"initialValue": 0,
|
||||
"in": {
|
||||
"$add": [
|
||||
"$$value",
|
||||
{
|
||||
"$multiply": [
|
||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
{"$arrayElemAt": [query_embedding, "$$this"]},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
},
|
||||
"magnitude1": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": "$embedding",
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude2": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": query_embedding,
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
||||
{
|
||||
"$match": {
|
||||
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||
}
|
||||
},
|
||||
{"$sort": {"similarity": -1}},
|
||||
{"$limit": limit},
|
||||
{"$project": {"content": 1, "similarity": 1}},
|
||||
]
|
||||
|
||||
results = list(db.knowledges.aggregate(pipeline))
|
||||
logger.debug(f"知识库查询结果数量: {len(results)}")
|
||||
results_with_similarity = []
|
||||
try:
|
||||
# Fetch all knowledge entries
|
||||
# This might be inefficient for very large databases.
|
||||
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
|
||||
all_knowledges = Knowledges.select()
|
||||
|
||||
if not results:
|
||||
if not all_knowledges:
|
||||
return "" if not return_raw else []
|
||||
|
||||
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
|
||||
if query_embedding_magnitude == 0: # Avoid division by zero
|
||||
return "" if not return_raw else []
|
||||
|
||||
for knowledge_item in all_knowledges:
|
||||
try:
|
||||
db_embedding_str = knowledge_item.embedding
|
||||
db_embedding = json.loads(db_embedding_str)
|
||||
|
||||
if len(db_embedding) != len(query_embedding):
|
||||
logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.")
|
||||
continue
|
||||
|
||||
# Calculate Cosine Similarity
|
||||
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
|
||||
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
|
||||
|
||||
if db_embedding_magnitude == 0: # Avoid division by zero
|
||||
similarity = 0.0
|
||||
else:
|
||||
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
|
||||
|
||||
if similarity >= threshold:
|
||||
results_with_similarity.append({
|
||||
"content": knowledge_item.content,
|
||||
"similarity": similarity
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing knowledge item: {e}")
|
||||
|
||||
|
||||
# Sort by similarity in descending order
|
||||
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
# Limit results
|
||||
limited_results = results_with_similarity[:limit]
|
||||
|
||||
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
|
||||
|
||||
if not limited_results:
|
||||
return "" if not return_raw else []
|
||||
|
||||
if return_raw:
|
||||
return limited_results
|
||||
else:
|
||||
return "\n".join(str(result["content"]) for result in limited_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying Knowledges with Peewee: {e}")
|
||||
return "" if not return_raw else []
|
||||
|
||||
if return_raw:
|
||||
return results
|
||||
else:
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
|
||||
@@ -10,7 +10,7 @@ import jieba
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database import memory_db as db
|
||||
from ...chat.models.utils_model import LLMRequest
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Dict, Optional
|
||||
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import ChatStreams # 新增导入
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
@@ -82,7 +83,13 @@ class ChatManager:
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self._ensure_collection()
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 确保 ChatStreams 表存在
|
||||
db.create_tables([ChatStreams], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
# asyncio.create_task(self._initialize())
|
||||
@@ -107,15 +114,6 @@ class ChatManager:
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_collection():
|
||||
"""确保数据库集合存在并创建索引"""
|
||||
if "chat_streams" not in db.list_collection_names():
|
||||
db.create_collection("chat_streams")
|
||||
# 创建索引
|
||||
db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
@@ -151,16 +149,43 @@ class ChatManager:
|
||||
stream = self.streams[stream_id]
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
stream = copy.deepcopy(stream)
|
||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
data = db.chat_streams.find_one({"stream_id": stream_id})
|
||||
if data:
|
||||
stream = ChatStream.from_dict(data)
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
if model_instance:
|
||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
@@ -175,7 +200,7 @@ class ChatManager:
|
||||
group_info=group_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建聊天流失败: {e}")
|
||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
# 保存到内存和数据库
|
||||
@@ -205,15 +230,38 @@ class ChatManager:
|
||||
elif stream.user_info and stream.user_info.user_nickname:
|
||||
return f"{stream.user_info.user_nickname}的私聊"
|
||||
else:
|
||||
# 如果没有群名或用户昵称,返回 None 或其他默认值
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _save_stream(stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if not stream.saved:
|
||||
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
|
||||
stream.saved = True
|
||||
stream_data_dict = stream.to_dict()
|
||||
|
||||
def _db_save_stream_sync(s_data_dict: dict):
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
|
||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
stream.saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
async def _save_all_streams(self):
|
||||
"""保存所有聊天流"""
|
||||
@@ -222,10 +270,44 @@ class ChatManager:
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
all_streams = db.chat_streams.find({})
|
||||
for data in all_streams:
|
||||
stream = ChatStream.from_dict(data)
|
||||
self.streams[stream.stream_id] = stream
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in ChatStreams.select():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
|
||||
self.streams.clear()
|
||||
for data in all_streams_data_list:
|
||||
stream = ChatStream.from_dict(data)
|
||||
stream.saved = True
|
||||
self.streams[stream.stream_id] = stream
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
|
||||
@@ -12,7 +12,8 @@ import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
import os
|
||||
from ...common.database.database import db
|
||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||
from ...config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
# if isinstance(content, str) and len(content) > 100:
|
||||
# payload["messages"][0]["content"] = content[:100]
|
||||
return payload
|
||||
|
||||
|
||||
@@ -134,13 +133,11 @@ class LLMRequest:
|
||||
def _init_database():
|
||||
"""初始化数据库集合"""
|
||||
try:
|
||||
# 创建llm_usage集合的索引
|
||||
db.llm_usage.create_index([("timestamp", 1)])
|
||||
db.llm_usage.create_index([("model_name", 1)])
|
||||
db.llm_usage.create_index([("user_id", 1)])
|
||||
db.llm_usage.create_index([("request_type", 1)])
|
||||
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
logger.info("LLMUsage 表已初始化/确保存在。")
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库索引失败: {str(e)}")
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def _record_usage(
|
||||
self,
|
||||
@@ -165,19 +162,19 @@ class LLMRequest:
|
||||
request_type = self.request_type
|
||||
|
||||
try:
|
||||
usage_data = {
|
||||
"model_name": self.model_name,
|
||||
"user_id": user_id,
|
||||
"request_type": request_type,
|
||||
"endpoint": endpoint,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
|
||||
"status": "success",
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
db.llm_usage.insert_one(usage_data)
|
||||
# 使用 Peewee 模型创建记录
|
||||
LLMUsage.create(
|
||||
model_name=self.model_name,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
endpoint=endpoint,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=self._calculate_cost(prompt_tokens, completion_tokens),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
)
|
||||
logger.trace(
|
||||
f"Token使用情况 - 模型: {self.model_name}, "
|
||||
f"用户: {user_id}, 类型: {request_type}, "
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from src.common.logger_manager import get_logger
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import PersonInfo # 新增导入
|
||||
import copy
|
||||
import hashlib
|
||||
from typing import Any, Callable, Dict
|
||||
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import json
|
||||
import json # 新增导入
|
||||
import re
|
||||
|
||||
|
||||
@@ -43,17 +44,13 @@ person_info_default = {
|
||||
"platform": None,
|
||||
"user_id": None,
|
||||
"nickname": None,
|
||||
# "age" : 0,
|
||||
"relationship_value": 0,
|
||||
# "saved" : True,
|
||||
# "impression" : None,
|
||||
# "gender" : Unkown,
|
||||
"konw_time": 0,
|
||||
"msg_interval": 2000,
|
||||
"msg_interval_list": [],
|
||||
"user_cardname": None, # 添加群名片
|
||||
"user_avatar": None, # 添加头像信息(例如URL或标识符)
|
||||
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
|
||||
"msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
|
||||
"user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
|
||||
"user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
|
||||
}
|
||||
|
||||
|
||||
class PersonInfoManager:
|
||||
@@ -64,21 +61,26 @@ class PersonInfoManager:
|
||||
max_tokens=256,
|
||||
request_type="qv_name",
|
||||
)
|
||||
if "person_info" not in db.list_collection_names():
|
||||
db.create_collection("person_info")
|
||||
db.person_info.create_index("person_id", unique=True)
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([PersonInfo], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||||
|
||||
# 初始化时读取所有person_name
|
||||
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0})
|
||||
for doc in cursor:
|
||||
if doc.get("person_name"):
|
||||
self.person_name_list[doc["person_id"]] = doc["person_name"]
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
|
||||
try:
|
||||
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_null(False)
|
||||
):
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: int):
|
||||
"""获取唯一id"""
|
||||
# 如果platform中存在-,就截取-后面的部分
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
@@ -86,13 +88,17 @@ class PersonInfoManager:
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
def is_person_known(self, platform: str, user_id: int):
|
||||
async def is_person_known(self, platform: str, user_id: int):
|
||||
"""判断是否认识某人"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
document = db.person_info.find_one({"person_id": person_id})
|
||||
if document:
|
||||
return True
|
||||
else:
|
||||
|
||||
def _db_check_known_sync(p_id: str):
|
||||
return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_check_known_sync, person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@@ -103,73 +109,111 @@ class PersonInfoManager:
|
||||
return
|
||||
|
||||
_person_info_default = copy.deepcopy(person_info_default)
|
||||
_person_info_default["person_id"] = person_id
|
||||
model_fields = PersonInfo._meta.fields.keys()
|
||||
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
if data:
|
||||
for key in _person_info_default:
|
||||
if key != "person_id" and key in data:
|
||||
_person_info_default[key] = data[key]
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
db.person_info.insert_one(_person_info_default)
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields and key not in final_data:
|
||||
final_data[key] = default_value
|
||||
|
||||
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
|
||||
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
|
||||
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
|
||||
final_data["msg_interval_list"] = json.dumps([])
|
||||
|
||||
def _db_create_sync(p_data: dict):
|
||||
try:
|
||||
PersonInfo.create(**p_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_create_sync, final_data)
|
||||
|
||||
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
if field_name not in person_info_default.keys():
|
||||
logger.debug(f"更新'{field_name}'失败,未定义的字段")
|
||||
if field_name not in PersonInfo._meta.fields:
|
||||
if field_name in person_info_default:
|
||||
logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
|
||||
return
|
||||
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
|
||||
return
|
||||
|
||||
document = db.person_info.find_one({"person_id": person_id})
|
||||
def _db_update_sync(p_id: str, f_name: str, val):
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
if record:
|
||||
if f_name == "msg_interval_list" and isinstance(val, list):
|
||||
setattr(record, f_name, json.dumps(val))
|
||||
else:
|
||||
setattr(record, f_name, val)
|
||||
record.save()
|
||||
return True, False
|
||||
return False, True
|
||||
|
||||
if document:
|
||||
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
|
||||
else:
|
||||
data[field_name] = value
|
||||
logger.debug(f"更新时{person_id}不存在,已新建")
|
||||
await self.create_person_info(person_id, data)
|
||||
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
|
||||
|
||||
if needs_creation:
|
||||
logger.debug(f"更新时 {person_id} 不存在,将新建。")
|
||||
creation_data = data if data is not None else {}
|
||||
creation_data[field_name] = value
|
||||
if "platform" not in creation_data or "user_id" not in creation_data:
|
||||
logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。")
|
||||
|
||||
await self.create_person_info(person_id, creation_data)
|
||||
|
||||
@staticmethod
|
||||
async def has_one_field(person_id: str, field_name: str):
|
||||
"""判断是否存在某一个字段"""
|
||||
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
||||
if document:
|
||||
return True
|
||||
else:
|
||||
if field_name not in PersonInfo._meta.fields:
|
||||
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
|
||||
return False
|
||||
|
||||
def _db_has_field_sync(p_id: str, f_name: str):
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
if record:
|
||||
return True
|
||||
return False
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
try:
|
||||
# 尝试直接解析
|
||||
parsed_json = json.loads(text)
|
||||
# 如果解析结果是列表,尝试取第一个元素
|
||||
if isinstance(parsed_json, list):
|
||||
if parsed_json: # 检查列表是否为空
|
||||
if parsed_json:
|
||||
parsed_json = parsed_json[0]
|
||||
else: # 如果列表为空,重置为 None,走后续逻辑
|
||||
else:
|
||||
parsed_json = None
|
||||
# 确保解析结果是字典
|
||||
if isinstance(parsed_json, dict):
|
||||
return parsed_json
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 解析失败,继续尝试其他方法
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
|
||||
pass # 继续尝试其他方法
|
||||
pass
|
||||
|
||||
# 如果直接解析失败或结果不是字典
|
||||
try:
|
||||
# 尝试找到JSON对象格式的部分
|
||||
json_pattern = r"\{[^{}]*\}"
|
||||
matches = re.findall(json_pattern, text)
|
||||
if matches:
|
||||
parsed_obj = json.loads(matches[0])
|
||||
if isinstance(parsed_obj, dict): # 确保是字典
|
||||
if isinstance(parsed_obj, dict):
|
||||
return parsed_obj
|
||||
|
||||
# 如果上面都失败了,尝试提取键值对
|
||||
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
|
||||
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
|
||||
|
||||
@@ -184,7 +228,6 @@ class PersonInfoManager:
|
||||
except Exception as e:
|
||||
logger.error(f"后备JSON提取失败: {str(e)}")
|
||||
|
||||
# 如果所有方法都失败了,返回默认字典
|
||||
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
|
||||
return {"nickname": "", "reason": ""}
|
||||
|
||||
@@ -199,9 +242,11 @@ class PersonInfoManager:
|
||||
old_name = await self.get_value(person_id, "person_name")
|
||||
old_reason = await self.get_value(person_id, "name_reason")
|
||||
|
||||
max_retries = 5 # 最大重试次数
|
||||
max_retries = 5
|
||||
current_try = 0
|
||||
existing_names = ""
|
||||
existing_names_str = ""
|
||||
current_name_set = set(self.person_name_list.values())
|
||||
|
||||
while current_try < max_retries:
|
||||
individuality = Individuality.get_instance()
|
||||
prompt_personality = individuality.get_prompt(x_person=2, level=1)
|
||||
@@ -216,45 +261,55 @@ class PersonInfoManager:
|
||||
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
|
||||
|
||||
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
|
||||
qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
|
||||
|
||||
if existing_names_str:
|
||||
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"
|
||||
|
||||
if len(current_name_set) < 50 and current_name_set:
|
||||
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
|
||||
|
||||
qv_name_prompt += (
|
||||
"\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
|
||||
)
|
||||
if existing_names:
|
||||
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n"
|
||||
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
|
||||
qv_name_prompt += """{
|
||||
"nickname": "昵称",
|
||||
"reason": "理由"
|
||||
}"""
|
||||
# logger.debug(f"取名提示词:{qv_name_prompt}")
|
||||
response = await self.qv_name_llm.generate_response(qv_name_prompt)
|
||||
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
|
||||
result = self._extract_json_from_text(response[0])
|
||||
|
||||
if not result["nickname"]:
|
||||
logger.error("生成的昵称为空,重试中...")
|
||||
if not result or not result.get("nickname"):
|
||||
logger.error("生成的昵称为空或结果格式不正确,重试中...")
|
||||
current_try += 1
|
||||
continue
|
||||
|
||||
# 检查生成的昵称是否已存在
|
||||
if result["nickname"] not in self.person_name_list.values():
|
||||
# 更新数据库和内存中的列表
|
||||
await self.update_one_field(person_id, "person_name", result["nickname"])
|
||||
# await self.update_one_field(person_id, "nickname", user_nickname)
|
||||
# await self.update_one_field(person_id, "avatar", user_avatar)
|
||||
await self.update_one_field(person_id, "name_reason", result["reason"])
|
||||
generated_nickname = result["nickname"]
|
||||
|
||||
self.person_name_list[person_id] = result["nickname"]
|
||||
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
|
||||
is_duplicate = False
|
||||
if generated_nickname in current_name_set:
|
||||
is_duplicate = True
|
||||
else:
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
||||
|
||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||
is_duplicate = True
|
||||
current_name_set.add(generated_nickname)
|
||||
|
||||
if not is_duplicate:
|
||||
await self.update_one_field(person_id, "person_name", generated_nickname)
|
||||
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
|
||||
|
||||
self.person_name_list[person_id] = generated_nickname
|
||||
return result
|
||||
else:
|
||||
existing_names += f"{result['nickname']}、"
|
||||
if existing_names_str:
|
||||
existing_names_str += "、"
|
||||
existing_names_str += generated_nickname
|
||||
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
|
||||
current_try += 1
|
||||
|
||||
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
|
||||
current_try += 1
|
||||
|
||||
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称")
|
||||
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -264,30 +319,56 @@ class PersonInfoManager:
|
||||
logger.debug("删除失败:person_id 不能为空")
|
||||
return
|
||||
|
||||
result = db.person_info.delete_one({"person_id": person_id})
|
||||
if result.deleted_count > 0:
|
||||
logger.debug(f"删除成功:person_id={person_id}")
|
||||
def _db_delete_sync(p_id: str):
|
||||
try:
|
||||
query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
|
||||
deleted_count = query.execute()
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
|
||||
return 0
|
||||
|
||||
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"删除成功:person_id={person_id} (Peewee)")
|
||||
else:
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id}")
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
|
||||
|
||||
@staticmethod
|
||||
async def get_value(person_id: str, field_name: str):
|
||||
"""获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
if not person_id:
|
||||
logger.debug("get_value获取失败:person_id不能为空")
|
||||
return person_info_default.get(field_name)
|
||||
|
||||
if field_name not in PersonInfo._meta.fields:
|
||||
if field_name in person_info_default:
|
||||
logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。")
|
||||
return copy.deepcopy(person_info_default[field_name])
|
||||
logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
|
||||
return None
|
||||
|
||||
if field_name not in person_info_default:
|
||||
logger.debug(f"get_value获取失败:字段'{field_name}'未定义")
|
||||
def _db_get_value_sync(p_id: str, f_name: str):
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
if record:
|
||||
val = getattr(record, f_name)
|
||||
if f_name == "msg_interval_list" and isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
|
||||
return copy.deepcopy(person_info_default.get(f_name, []))
|
||||
return val
|
||||
return None
|
||||
|
||||
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
||||
value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
|
||||
|
||||
if document and field_name in document:
|
||||
return document[field_name]
|
||||
if value is not None:
|
||||
return value
|
||||
else:
|
||||
default_value = copy.deepcopy(person_info_default[field_name])
|
||||
logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
|
||||
default_value = copy.deepcopy(person_info_default.get(field_name))
|
||||
logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
|
||||
return default_value
|
||||
|
||||
@staticmethod
|
||||
@@ -297,93 +378,82 @@ class PersonInfoManager:
|
||||
logger.debug("get_values获取失败:person_id不能为空")
|
||||
return {}
|
||||
|
||||
# 检查所有字段是否有效
|
||||
for field in field_names:
|
||||
if field not in person_info_default:
|
||||
logger.debug(f"get_values获取失败:字段'{field}'未定义")
|
||||
return {}
|
||||
|
||||
# 构建查询投影(所有字段都有效才会执行到这里)
|
||||
projection = {field: 1 for field in field_names}
|
||||
|
||||
document = db.person_info.find_one({"person_id": person_id}, projection)
|
||||
|
||||
result = {}
|
||||
for field in field_names:
|
||||
result[field] = copy.deepcopy(
|
||||
document.get(field, person_info_default[field]) if document else person_info_default[field]
|
||||
)
|
||||
|
||||
def _db_get_record_sync(p_id: str):
|
||||
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
|
||||
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
||||
|
||||
for field_name in field_names:
|
||||
if field_name not in PersonInfo._meta.fields:
|
||||
if field_name in person_info_default:
|
||||
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
||||
logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
|
||||
else:
|
||||
logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
|
||||
result[field_name] = None
|
||||
continue
|
||||
|
||||
if record:
|
||||
value = getattr(record, field_name)
|
||||
if field_name == "msg_interval_list" and isinstance(value, str):
|
||||
try:
|
||||
result[field_name] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
|
||||
elif value is not None:
|
||||
result[field_name] = value
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def del_all_undefined_field():
|
||||
"""删除所有项里的未定义字段"""
|
||||
# 获取所有已定义的字段名
|
||||
defined_fields = set(person_info_default.keys())
|
||||
|
||||
try:
|
||||
# 遍历集合中的所有文档
|
||||
for document in db.person_info.find({}):
|
||||
# 找出文档中未定义的字段
|
||||
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
|
||||
|
||||
if undefined_fields:
|
||||
# 构建更新操作,使用$unset删除未定义字段
|
||||
update_result = db.person_info.update_one(
|
||||
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
|
||||
)
|
||||
|
||||
if update_result.modified_count > 0:
|
||||
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
|
||||
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理未定义字段时出错: {e}")
|
||||
return
|
||||
"""删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
|
||||
logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def get_specific_value_list(
|
||||
field_name: str,
|
||||
way: Callable[[Any], bool], # 接受任意类型值
|
||||
way: Callable[[Any], bool],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取满足条件的字段值字典
|
||||
|
||||
Args:
|
||||
field_name: 目标字段名
|
||||
way: 判断函数 (value: Any) -> bool
|
||||
|
||||
Returns:
|
||||
{person_id: value} | {}
|
||||
|
||||
Example:
|
||||
# 查找所有nickname包含"admin"的用户
|
||||
result = manager.specific_value_list(
|
||||
"nickname",
|
||||
lambda x: "admin" in x.lower()
|
||||
)
|
||||
"""
|
||||
if field_name not in person_info_default:
|
||||
logger.error(f"字段检查失败:'{field_name}'未定义")
|
||||
if field_name not in PersonInfo._meta.fields:
|
||||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
|
||||
return {}
|
||||
|
||||
def _db_get_specific_sync(f_name: str):
|
||||
found_results = {}
|
||||
try:
|
||||
for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
|
||||
value = getattr(record, f_name)
|
||||
if f_name == "msg_interval_list" and isinstance(value, str):
|
||||
try:
|
||||
processed_value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
|
||||
continue
|
||||
else:
|
||||
processed_value = value
|
||||
|
||||
if way(processed_value):
|
||||
found_results[record.person_id] = processed_value
|
||||
except Exception as e_query:
|
||||
logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
|
||||
return found_results
|
||||
|
||||
try:
|
||||
result = {}
|
||||
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
|
||||
try:
|
||||
value = doc[field_name]
|
||||
if way(value):
|
||||
result[doc["person_id"]] = value
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
return await asyncio.to_thread(_db_get_specific_sync, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
async def personal_habit_deduction(self):
|
||||
@@ -391,35 +461,31 @@ class PersonInfoManager:
|
||||
try:
|
||||
while 1:
|
||||
await asyncio.sleep(600)
|
||||
current_time = datetime.datetime.now()
|
||||
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
current_time_dt = datetime.datetime.now()
|
||||
logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# "msg_interval"推断
|
||||
msg_interval_map = False
|
||||
msg_interval_lists = await self.get_specific_value_list(
|
||||
msg_interval_map_generated = False
|
||||
msg_interval_lists_map = await self.get_specific_value_list(
|
||||
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
|
||||
)
|
||||
for person_id, msg_interval_list_ in msg_interval_lists.items():
|
||||
|
||||
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
time_interval = []
|
||||
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]):
|
||||
for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
|
||||
delta = t2 - t1
|
||||
if delta > 0:
|
||||
time_interval.append(delta)
|
||||
|
||||
time_interval = [t for t in time_interval if 200 <= t <= 8000]
|
||||
# --- 修改后的逻辑 ---
|
||||
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
|
||||
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
|
||||
time_interval.sort()
|
||||
|
||||
# 画图(log) - 这部分保留
|
||||
msg_interval_map = True
|
||||
if len(time_interval) >= 30 + 10:
|
||||
time_interval.sort()
|
||||
msg_interval_map_generated = True
|
||||
log_dir = Path("logs/person_info")
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
plt.figure(figsize=(10, 6))
|
||||
# 使用截断前的数据画图,更能反映原始分布
|
||||
time_series_original = pd.Series(time_interval)
|
||||
plt.hist(
|
||||
time_series_original,
|
||||
@@ -441,34 +507,27 @@ class PersonInfoManager:
|
||||
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
|
||||
plt.savefig(img_path)
|
||||
plt.close()
|
||||
# 画图结束
|
||||
|
||||
# 去掉头尾各 5 个数据点
|
||||
trimmed_interval = time_interval[5:-5]
|
||||
|
||||
# 计算截断后数据的 37% 分位数
|
||||
if trimmed_interval: # 确保截断后列表不为空
|
||||
msg_interval = int(round(np.percentile(trimmed_interval, 37)))
|
||||
# 更新数据库
|
||||
await self.update_one_field(person_id, "msg_interval", msg_interval)
|
||||
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
|
||||
if trimmed_interval:
|
||||
msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
|
||||
await self.update_one_field(person_id, "msg_interval", msg_interval_val)
|
||||
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}")
|
||||
else:
|
||||
logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval")
|
||||
else:
|
||||
logger.trace(
|
||||
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
|
||||
)
|
||||
# --- 修改结束 ---
|
||||
except Exception as e:
|
||||
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
|
||||
except Exception as e_inner:
|
||||
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
|
||||
continue
|
||||
|
||||
# 其他...
|
||||
|
||||
if msg_interval_map:
|
||||
if msg_interval_map_generated:
|
||||
logger.trace("已保存分布图到: logs/person_info")
|
||||
current_time = datetime.datetime.now()
|
||||
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
current_time_dt_end = datetime.datetime.now()
|
||||
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
await asyncio.sleep(86400)
|
||||
|
||||
except Exception as e:
|
||||
@@ -481,41 +540,27 @@ class PersonInfoManager:
|
||||
"""
|
||||
根据 platform 和 user_id 获取 person_id。
|
||||
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
user_id: 用户在该平台上的ID
|
||||
nickname: 用户的昵称 (可选,用于创建新用户)
|
||||
user_cardname: 用户的群名片 (可选,用于创建新用户)
|
||||
user_avatar: 用户的头像信息 (可选,用于创建新用户)
|
||||
|
||||
Returns:
|
||||
对应的 person_id。
|
||||
"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
# 检查用户是否已存在
|
||||
# 使用静态方法 get_person_id,因此可以直接调用 db
|
||||
document = db.person_info.find_one({"person_id": person_id})
|
||||
def _db_check_exists_sync(p_id: str):
|
||||
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
|
||||
if document is None:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
|
||||
record = await asyncio.to_thread(_db_check_exists_sync, person_id)
|
||||
|
||||
if record is None:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
|
||||
initial_data = {
|
||||
"platform": platform,
|
||||
"user_id": user_id,
|
||||
"user_id": str(user_id),
|
||||
"nickname": nickname,
|
||||
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间
|
||||
# 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中
|
||||
# 如果需要存储它们,需要先在 person_info_default 中定义
|
||||
"konw_time": int(datetime.datetime.now().timestamp()),
|
||||
}
|
||||
# 过滤掉值为 None 的初始数据
|
||||
initial_data = {k: v for k, v in initial_data.items() if v is not None}
|
||||
model_fields = PersonInfo._meta.fields.keys()
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
# 注意:create_person_info 是静态方法
|
||||
await PersonInfoManager.create_person_info(person_id, data=initial_data)
|
||||
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
|
||||
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
|
||||
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
|
||||
await self.create_person_info(person_id, data=filtered_initial_data)
|
||||
logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
|
||||
|
||||
return person_id
|
||||
|
||||
@@ -525,35 +570,49 @@ class PersonInfoManager:
|
||||
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
|
||||
return None
|
||||
|
||||
# 优先从内存缓存查找 person_id
|
||||
found_person_id = None
|
||||
for pid, name in self.person_name_list.items():
|
||||
if name == person_name:
|
||||
for pid, name_in_cache in self.person_name_list.items():
|
||||
if name_in_cache == person_name:
|
||||
found_person_id = pid
|
||||
break # 找到第一个匹配就停止
|
||||
break
|
||||
|
||||
if not found_person_id:
|
||||
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
|
||||
document = db.person_info.find_one({"person_name": person_name})
|
||||
if document:
|
||||
found_person_id = document.get("person_id")
|
||||
else:
|
||||
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
|
||||
return None # 数据库也找不到
|
||||
def _db_find_by_name_sync(p_name_to_find: str):
|
||||
return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
|
||||
|
||||
# 根据找到的 person_id 获取所需信息
|
||||
if found_person_id:
|
||||
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
|
||||
person_data = await self.get_values(found_person_id, required_fields)
|
||||
if person_data: # 确保 get_values 成功返回
|
||||
return person_data
|
||||
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
|
||||
if record:
|
||||
found_person_id = record.person_id
|
||||
if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name:
|
||||
self.person_name_list[found_person_id] = person_name
|
||||
else:
|
||||
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败")
|
||||
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
|
||||
return None
|
||||
else:
|
||||
# 这理论上不应该发生,因为上面已经处理了找不到的情况
|
||||
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id")
|
||||
return None
|
||||
|
||||
if found_person_id:
|
||||
required_fields = [
|
||||
"person_id",
|
||||
"platform",
|
||||
"user_id",
|
||||
"nickname",
|
||||
"user_cardname",
|
||||
"user_avatar",
|
||||
"person_name",
|
||||
"name_reason",
|
||||
]
|
||||
valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default]
|
||||
|
||||
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||
|
||||
if person_data:
|
||||
final_result = {key: person_data.get(key) for key in required_fields}
|
||||
return final_result
|
||||
else:
|
||||
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
|
||||
return None
|
||||
|
||||
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
|
||||
return None
|
||||
|
||||
|
||||
person_info_manager = PersonInfoManager()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
|
||||
from common.database.database import db
|
||||
from src.common.database.database_model import Messages, ThinkingLog
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
|
||||
class InfoCatcher:
|
||||
@@ -60,8 +61,6 @@ class InfoCatcher:
|
||||
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||
|
||||
# def catch_shf
|
||||
|
||||
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
||||
self.timing_results["sub_heartflow_step_time"] = step_duration
|
||||
if len(past_mind) > 1:
|
||||
@@ -72,25 +71,10 @@ class InfoCatcher:
|
||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||
|
||||
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
|
||||
# self.reasoning_data["thinking_log"] = reasoning_content
|
||||
# self.reasoning_data["prompt"] = prompt
|
||||
# self.reasoning_data["response"] = response
|
||||
# self.reasoning_data["model"] = model_name
|
||||
|
||||
# 直接记录信息喵~
|
||||
self.reasoning_data["thinking_log"] = reasoning_content
|
||||
self.reasoning_data["prompt"] = prompt
|
||||
self.reasoning_data["response"] = response
|
||||
self.reasoning_data["model"] = model_name
|
||||
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
|
||||
self.response_text = response
|
||||
|
||||
@@ -102,6 +86,7 @@ class InfoCatcher:
|
||||
):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
self.response_time = time.time()
|
||||
self.response_messages = []
|
||||
for msg in response_message:
|
||||
self.response_messages.append(msg)
|
||||
|
||||
@@ -112,107 +97,110 @@ class InfoCatcher:
|
||||
@staticmethod
|
||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||
try:
|
||||
# 从数据库中获取消息的时间戳
|
||||
time_start = message_start.message_info.time
|
||||
time_end = message_end.message_info.time
|
||||
chat_id = message_start.chat_stream.stream_id
|
||||
|
||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
||||
messages_between = db.messages.find(
|
||||
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
|
||||
).sort("time", -1)
|
||||
messages_between_query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > time_start) &
|
||||
(Messages.time < time_end)
|
||||
).order_by(Messages.time.desc())
|
||||
|
||||
result = list(messages_between)
|
||||
result = list(messages_between_query)
|
||||
print(f"查询结果数量: {len(result)}")
|
||||
if result:
|
||||
print(f"第一条消息时间: {result[0]['time']}")
|
||||
print(f"最后一条消息时间: {result[-1]['time']}")
|
||||
print(f"第一条消息时间: {result[0].time}")
|
||||
print(f"最后一条消息时间: {result[-1].time}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取消息时出错: {str(e)}")
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||
# 从数据库中获取消息
|
||||
message_id = message.message_info.message_id
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id_val = message.message_info.message_id
|
||||
chat_id_val = message.chat_stream.stream_id
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
||||
messages_before = (
|
||||
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
|
||||
.sort("time", -1)
|
||||
.limit(self.context_length * 3)
|
||||
) # 获取更多历史信息
|
||||
messages_before_query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id_val) &
|
||||
(Messages.message_id < message_id_val)
|
||||
).order_by(Messages.time.desc()).limit(self.context_length * 3)
|
||||
|
||||
return list(messages_before)
|
||||
return list(messages_before_query)
|
||||
|
||||
def message_list_to_dict(self, message_list):
|
||||
# 存储简化的聊天记录
|
||||
result = []
|
||||
for message in message_list:
|
||||
if not isinstance(message, dict):
|
||||
message = self.message_to_dict(message)
|
||||
# print(message)
|
||||
for msg_item in message_list:
|
||||
processed_msg_item = msg_item
|
||||
if not isinstance(msg_item, dict):
|
||||
processed_msg_item = self.message_to_dict(msg_item)
|
||||
|
||||
if not processed_msg_item:
|
||||
continue
|
||||
|
||||
lite_message = {
|
||||
"time": message["time"],
|
||||
"user_nickname": message["user_info"]["user_nickname"],
|
||||
"processed_plain_text": message["processed_plain_text"],
|
||||
"time": processed_msg_item.get("time"),
|
||||
"user_nickname": processed_msg_item.get("user_nickname"),
|
||||
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
|
||||
}
|
||||
result.append(lite_message)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(message):
|
||||
if not message:
|
||||
def message_to_dict(msg_obj):
|
||||
if not msg_obj:
|
||||
return None
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
return {
|
||||
# "message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time,
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
# "detailed_plain_text": message.detailed_plain_text
|
||||
}
|
||||
if isinstance(msg_obj, dict):
|
||||
return msg_obj
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
|
||||
try:
|
||||
# 将消息对象转换为可序列化的字典喵~
|
||||
|
||||
thinking_log_data = {
|
||||
"chat_id": self.chat_id,
|
||||
"trigger_text": self.trigger_response_text,
|
||||
"response_text": self.response_text,
|
||||
"trigger_info": {
|
||||
"time": self.trigger_response_time,
|
||||
"message": self.message_to_dict(self.trigger_response_message),
|
||||
},
|
||||
"response_info": {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
},
|
||||
"timing_results": self.timing_results,
|
||||
"chat_history": self.message_list_to_dict(self.chat_history),
|
||||
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
||||
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
|
||||
"heartflow_data": self.heartflow_data,
|
||||
"reasoning_data": self.reasoning_data,
|
||||
if isinstance(msg_obj, Messages):
|
||||
return {
|
||||
"time": msg_obj.time,
|
||||
"user_id": msg_obj.user_id,
|
||||
"user_nickname": msg_obj.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
|
||||
# if self.response_mode == "heart_flow":
|
||||
# thinking_log_data["mode_specific_data"] = self.heartflow_data
|
||||
# elif self.response_mode == "reasoning":
|
||||
# thinking_log_data["mode_specific_data"] = self.reasoning_data
|
||||
if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'):
|
||||
return {
|
||||
"time": msg_obj.message_info.time,
|
||||
"user_id": msg_obj.message_info.user_info.user_id,
|
||||
"user_nickname": msg_obj.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
# 将数据插入到 thinking_log 集合中喵~
|
||||
db.thinking_log.insert_one(thinking_log_data)
|
||||
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
|
||||
return {}
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
|
||||
try:
|
||||
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
|
||||
response_info_dict = {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
}
|
||||
chat_history_list = self.message_list_to_dict(self.chat_history)
|
||||
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
|
||||
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
|
||||
|
||||
log_entry = ThinkingLog(
|
||||
chat_id=self.chat_id,
|
||||
trigger_text=self.trigger_response_text,
|
||||
response_text=self.response_text,
|
||||
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
|
||||
response_info_json=json.dumps(response_info_dict),
|
||||
timing_results_json=json.dumps(self.timing_results),
|
||||
chat_history_json=json.dumps(chat_history_list),
|
||||
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
|
||||
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
|
||||
heartflow_data_json=json.dumps(self.heartflow_data),
|
||||
reasoning_data_json=json.dumps(self.reasoning_data)
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,7 +5,8 @@ from typing import Any, Dict, Tuple, List
|
||||
from src.common.logger import get_module_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database import db # This db is the Peewee database instance
|
||||
from ...common.database.database_model import OnlineTime # Import the Peewee model
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_module_logger("maibot_statistic")
|
||||
@@ -39,7 +40,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: str | None = None
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
@@ -47,53 +48,46 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
if "online_time" not in db.list_collection_names():
|
||||
# 初始化数据库(在线时长)
|
||||
db.create_collection("online_time")
|
||||
# 创建索引
|
||||
if ("end_timestamp", 1) not in db.online_time.list_indexes():
|
||||
db.online_time.create_index([("end_timestamp", 1)])
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
extended_end_time = current_time + timedelta(minutes=1)
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": datetime.now() + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
|
||||
updated_rows = query.execute()
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
current_time = datetime.now()
|
||||
if recent_record := db.online_time.find_one(
|
||||
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
|
||||
):
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = OnlineTime.select().where(
|
||||
OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))
|
||||
).order_by(OnlineTime.end_timestamp.desc()).first()
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record["_id"]
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
self.record_id = db.online_time.insert_one(
|
||||
{
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
).inserted_id
|
||||
new_record = OnlineTime.create(
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
|
||||
|
||||
def _format_online_time(online_seconds: int) -> str:
|
||||
"""
|
||||
格式化在线时间
|
||||
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import Images, ImageDescriptions
|
||||
from ...config.config import global_config
|
||||
from ..models.utils_model import LLMRequest
|
||||
|
||||
@@ -32,40 +33,21 @@ class ImageManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._ensure_image_collection()
|
||||
self._ensure_description_collection()
|
||||
self._ensure_image_dir()
|
||||
self._initialized = True
|
||||
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_image_collection():
|
||||
"""确保images集合存在并创建索引"""
|
||||
if "images" not in db.list_collection_names():
|
||||
db.create_collection("images")
|
||||
|
||||
# 删除旧索引
|
||||
db.images.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
db.images.create_index([("url", 1)])
|
||||
db.images.create_index([("path", 1)])
|
||||
|
||||
@staticmethod
|
||||
def _ensure_description_collection():
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if "image_descriptions" not in db.list_collection_names():
|
||||
db.create_collection("image_descriptions")
|
||||
|
||||
# 删除旧索引
|
||||
db.image_descriptions.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
@@ -77,8 +59,15 @@ class ImageManager:
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
|
||||
return result["description"] if result else None
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.hash == image_hash) &
|
||||
(ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
@@ -90,20 +79,22 @@ class ImageManager:
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
try:
|
||||
db.image_descriptions.update_one(
|
||||
{"hash": image_hash, "type": description_type},
|
||||
{
|
||||
"$set": {
|
||||
"description": description,
|
||||
"timestamp": int(time.time()),
|
||||
"hash": image_hash, # 确保hash字段存在
|
||||
"type": description_type, # 确保type字段存在
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
current_timestamp = time.time()
|
||||
defaults = {
|
||||
'description': description,
|
||||
'timestamp': current_timestamp
|
||||
}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
hash=image_hash,
|
||||
type=description_type,
|
||||
defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败: {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,带查重和保存功能"""
|
||||
@@ -116,18 +107,25 @@ class ImageManager:
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
# logger.debug(f"缓存表情包描述: {cached_description}")
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = self.transform_gif(image_base64)
|
||||
image_base64_processed = self.transform_gif(image_base64)
|
||||
if image_base64_processed is None:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
||||
else:
|
||||
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成表情包描述")
|
||||
return "[表情包(描述生成失败)]"
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
@@ -136,31 +134,37 @@ class ImageManager:
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.save_emoji:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
|
||||
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
|
||||
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
file_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
"hash": image_hash,
|
||||
"path": file_path,
|
||||
"type": "emoji",
|
||||
"description": description,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.trace(f"保存表情包: {file_path}")
|
||||
# 保存到数据库 (Images表)
|
||||
try:
|
||||
img_obj = Images.get((Images.hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
Images.create(
|
||||
hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存表情包元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "emoji")
|
||||
|
||||
return f"[表情包:{description}]"
|
||||
@@ -188,6 +192,11 @@ class ImageManager:
|
||||
)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片(描述生成失败)]"
|
||||
|
||||
# 再次检查缓存
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
|
||||
@@ -195,38 +204,40 @@ class ImageManager:
|
||||
|
||||
logger.debug(f"描述是{description}")
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.save_pic:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
|
||||
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
|
||||
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
"hash": image_hash,
|
||||
"path": file_path,
|
||||
"type": "image",
|
||||
"description": description,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.trace(f"保存图片: {file_path}")
|
||||
# 保存到数据库 (Images表)
|
||||
try:
|
||||
img_obj = Images.get((Images.hash == image_hash) & (Images.type == "image"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
Images.create(
|
||||
hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存图片元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件失败: {str(e)}")
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
return f"[图片:{description}]"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
from peewee import SqliteDatabase
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -57,4 +58,15 @@ class DBWrapper:
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
db: Database = DBWrapper()
|
||||
memory_db: Database = DBWrapper()
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
_DB_DIR = os.path.join(ROOT_PATH, "data")
|
||||
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(_DB_DIR, exist_ok=True)
|
||||
|
||||
# 全局 Peewee SQLite 数据库访问点
|
||||
db = SqliteDatabase(_DB_FILE)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from peewee import Model, DoubleField, IntegerField, SqliteDatabase, BooleanField, TextField, FloatField
|
||||
|
||||
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
|
||||
from .database import db
|
||||
import datetime
|
||||
# 请在此处定义您的数据库实例。
|
||||
# 您需要取消注释并配置适合您的数据库的部分。
|
||||
# 例如,对于 SQLite:
|
||||
db = SqliteDatabase('my_application.db')
|
||||
# db = SqliteDatabase('MaiBot.db')
|
||||
#
|
||||
# 对于 PostgreSQL:
|
||||
# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
|
||||
@@ -69,17 +70,16 @@ class LLMUsage(BaseModel):
|
||||
"""
|
||||
用于存储 API 使用日志数据的模型。
|
||||
"""
|
||||
model_name = TextField()
|
||||
user_id = TextField()
|
||||
request_type = TextField()
|
||||
model_name = TextField(index=True) # 添加索引
|
||||
user_id = TextField(index=True) # 添加索引
|
||||
request_type = TextField(index=True) # 添加索引
|
||||
endpoint = TextField()
|
||||
prompt_tokens = IntegerField()
|
||||
completion_tokens = IntegerField()
|
||||
total_tokens = IntegerField()
|
||||
cost = DoubleField()
|
||||
status = TextField()
|
||||
# timestamp: "$date": "2025-05-01T18:52:50.870Z" (存储为字符串)
|
||||
timestamp = TextField()
|
||||
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||
|
||||
class Meta:
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
@@ -177,6 +177,8 @@ class OnlineTime(BaseModel):
|
||||
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
|
||||
timestamp = TextField()
|
||||
duration = IntegerField() # 时长,单位分钟
|
||||
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
||||
end_timestamp = DateTimeField(index=True)
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
@@ -202,3 +204,39 @@ class PersonInfo(BaseModel):
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'person_info'
|
||||
|
||||
class Knowledges(BaseModel):
|
||||
"""
|
||||
用于存储知识库条目的模型。
|
||||
"""
|
||||
content = TextField() # 知识内容的文本
|
||||
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||
# 可以添加其他元数据字段,如 source, create_time 等
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'knowledges'
|
||||
|
||||
|
||||
class ThinkingLog(BaseModel):
|
||||
chat_id = TextField(index=True)
|
||||
trigger_text = TextField(null=True)
|
||||
response_text = TextField(null=True)
|
||||
|
||||
# Store complex dicts/lists as JSON strings
|
||||
trigger_info_json = TextField(null=True)
|
||||
response_info_json = TextField(null=True)
|
||||
timing_results_json = TextField(null=True)
|
||||
chat_history_json = TextField(null=True)
|
||||
chat_history_in_thinking_json = TextField(null=True)
|
||||
chat_history_after_response_json = TextField(null=True)
|
||||
heartflow_data_json = TextField(null=True)
|
||||
reasoning_data_json = TextField(null=True)
|
||||
|
||||
# Add a timestamp for the log entry itself
|
||||
# Ensure you have: from peewee import DateTimeField
|
||||
# And: import datetime
|
||||
created_at = DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta:
|
||||
table_name = 'thinking_logs'
|
||||
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
from common.database.database import db
|
||||
from src.common.database.database_model import Messages # 更改导入
|
||||
from src.common.logger import get_module_logger
|
||||
import traceback
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
|
||||
logger = get_module_logger(__name__)
|
||||
|
||||
|
||||
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
"""
|
||||
return model_instance.__data__
|
||||
|
||||
|
||||
def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
@@ -16,39 +24,72 @@ def find_messages(
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
|
||||
Args:
|
||||
message_filter: MongoDB 查询过滤器。
|
||||
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。
|
||||
message_filter: 查询过滤器字典,键为模型字段名,值为期望值。
|
||||
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
|
||||
limit: 返回的最大文档数,0表示不限制。
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
||||
|
||||
Returns:
|
||||
消息文档列表,如果出错则返回空列表。
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
query = db.messages.find(message_filter)
|
||||
query = Messages.select()
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
conditions = []
|
||||
for key, value in message_filter.items():
|
||||
if hasattr(Messages, key):
|
||||
conditions.append(getattr(Messages, key) == value)
|
||||
else:
|
||||
logger.warning(
|
||||
f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
|
||||
)
|
||||
if conditions:
|
||||
# 使用 *conditions 将所有条件以 AND 连接
|
||||
query = query.where(*conditions)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.sort([("time", 1)]).limit(limit)
|
||||
results = list(query)
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
peewee_results = list(query)
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.sort([("time", -1)]).limit(limit)
|
||||
latest_results = list(query)
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
# 假设消息文档中总是有 'time' 字段且可排序
|
||||
results = sorted(latest_results, key=lambda msg: msg.get("time"))
|
||||
peewee_results = sorted(
|
||||
latest_results_peewee, key=lambda msg: msg.time
|
||||
)
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
query = query.sort(sort)
|
||||
results = list(query)
|
||||
peewee_sort_terms = []
|
||||
for field_name, direction in sort:
|
||||
if hasattr(Messages, field_name):
|
||||
field = getattr(Messages, field_name)
|
||||
if direction == 1: # ASC
|
||||
peewee_sort_terms.append(field.asc())
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
else:
|
||||
logger.warning(
|
||||
f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。"
|
||||
)
|
||||
if peewee_sort_terms:
|
||||
query = query.order_by(*peewee_sort_terms)
|
||||
peewee_results = list(query)
|
||||
|
||||
results = [_model_to_dict(msg) for msg in peewee_results]
|
||||
return results
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
@@ -60,18 +101,35 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
Args:
|
||||
message_filter: MongoDB 查询过滤器。
|
||||
message_filter: 查询过滤器字典,键为模型字段名,值为期望值。
|
||||
|
||||
Returns:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
count = db.messages.count_documents(message_filter)
|
||||
query = Messages.select()
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
conditions = []
|
||||
for key, value in message_filter.items():
|
||||
if hasattr(Messages, key):
|
||||
conditions.append(getattr(Messages, key) == value)
|
||||
else:
|
||||
logger.warning(
|
||||
f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
|
||||
)
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = query.count()
|
||||
return count
|
||||
except Exception as e:
|
||||
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
|
||||
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
|
||||
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
|
||||
|
||||
@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
|
||||
create_new_message_notification,
|
||||
create_cold_chat_notification,
|
||||
)
|
||||
from src.experimental.PFC.message_storage import MongoDBMessageStorage
|
||||
from src.experimental.PFC.message_storage import PeeweeMessageStorage
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -53,7 +53,7 @@ class ChatObserver:
|
||||
|
||||
self.stream_id = stream_id
|
||||
self.private_name = private_name
|
||||
self.message_storage = MongoDBMessageStorage()
|
||||
self.message_storage = PeeweeMessageStorage()
|
||||
|
||||
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from common.database.database import db
|
||||
# from src.common.database.database import db # Peewee db 导入
|
||||
from src.common.database.database_model import Messages # Peewee Messages 模型导入
|
||||
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
|
||||
|
||||
|
||||
class MessageStorage(ABC):
|
||||
@@ -47,28 +49,35 @@ class MessageStorage(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class MongoDBMessageStorage(MessageStorage):
|
||||
"""MongoDB消息存储实现"""
|
||||
class PeeweeMessageStorage(MessageStorage):
|
||||
"""Peewee消息存储实现"""
|
||||
|
||||
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
|
||||
# print(f"storage_check_message: {message_time}")
|
||||
query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > message_time)
|
||||
).order_by(Messages.time.asc())
|
||||
|
||||
return list(db.messages.find(query).sort("time", 1))
|
||||
# print(f"storage_check_message: {message_time}")
|
||||
messages_models = list(query)
|
||||
return [model_to_dict(msg) for msg in messages_models]
|
||||
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
|
||||
|
||||
messages = list(db.messages.find(query).sort("time", -1).limit(limit))
|
||||
query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time < time_point)
|
||||
).order_by(Messages.time.desc()).limit(limit)
|
||||
|
||||
messages_models = list(query)
|
||||
# 将消息按时间正序排列
|
||||
messages.reverse()
|
||||
return messages
|
||||
messages_models.reverse()
|
||||
return [model_to_dict(msg) for msg in messages_models]
|
||||
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
|
||||
|
||||
return db.messages.find_one(query) is not None
|
||||
return Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > after_time)
|
||||
).exists()
|
||||
|
||||
|
||||
# # 创建一个内存消息存储实现,用于测试
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from common.database.database import db
|
||||
from src.common.database.database_model import Knowledges # Updated import
|
||||
from src.common.logger_manager import get_logger
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, List # Added List
|
||||
import json # Added for parsing embedding
|
||||
import math # Added for cosine similarity
|
||||
|
||||
logger = get_logger("get_knowledge_tool")
|
||||
|
||||
@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
query = "" # Initialize query to ensure it's defined in except block
|
||||
try:
|
||||
query = function_args.get("query")
|
||||
threshold = function_args.get("threshold", 0.4)
|
||||
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||
"""计算两个向量之间的余弦相似度"""
|
||||
dot_product = sum(p * q for p, q in zip(vec1, vec2))
|
||||
magnitude1 = math.sqrt(sum(p * p for p in vec1))
|
||||
magnitude2 = math.sqrt(sum(q * q for q in vec2))
|
||||
if magnitude1 == 0 or magnitude2 == 0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude1 * magnitude2)
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
"""从数据库中获取相关信息
|
||||
|
||||
@@ -66,66 +79,49 @@ class SearchKnowledgeTool(BaseTool):
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
"$addFields": {
|
||||
"dotProduct": {
|
||||
"$reduce": {
|
||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
"initialValue": 0,
|
||||
"in": {
|
||||
"$add": [
|
||||
"$$value",
|
||||
{
|
||||
"$multiply": [
|
||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
{"$arrayElemAt": [query_embedding, "$$this"]},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
},
|
||||
"magnitude1": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": "$embedding",
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude2": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": query_embedding,
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
||||
{
|
||||
"$match": {
|
||||
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||
}
|
||||
},
|
||||
{"$sort": {"similarity": -1}},
|
||||
{"$limit": limit},
|
||||
{"$project": {"content": 1, "similarity": 1}},
|
||||
]
|
||||
similar_items = []
|
||||
try:
|
||||
all_knowledges = Knowledges.select()
|
||||
for item in all_knowledges:
|
||||
try:
|
||||
item_embedding_str = item.embedding
|
||||
if not item_embedding_str:
|
||||
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
|
||||
continue
|
||||
item_embedding = json.loads(item_embedding_str)
|
||||
if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding):
|
||||
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
|
||||
continue
|
||||
except AttributeError:
|
||||
logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
|
||||
continue
|
||||
|
||||
results = list(db.knowledges.aggregate(pipeline))
|
||||
logger.debug(f"知识库查询结果数量: {len(results)}")
|
||||
similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
|
||||
|
||||
if similarity >= threshold:
|
||||
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
|
||||
|
||||
# 按相似度降序排序
|
||||
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
# 应用限制
|
||||
results = similar_items[:limit]
|
||||
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
|
||||
return "" if not return_raw else []
|
||||
|
||||
if not results:
|
||||
return "" if not return_raw else []
|
||||
|
||||
if return_raw:
|
||||
return results
|
||||
# Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
|
||||
# 这里返回包含内容和相似度的字典列表
|
||||
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
|
||||
else:
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
Reference in New Issue
Block a user