重构数据库交互以使用 Peewee ORM
- 更新数据库连接和模型定义,以便使用 Peewee for SQLite。 - 在消息存储和检索功能中,用 Peewee ORM 查询替换 MongoDB 查询。 - 为 Messages、ThinkingLog 和 OnlineTime 引入了新的模型,以方便结构化数据存储。 - 增强了数据库操作的错误处理和日志记录。 - 删除了过时的 MongoDB 集合管理代码。 - 通过利用 Peewee 内置的查询和数据操作方法来提升性能。
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
|
||||
from common.database.database import db
|
||||
from src.common.database.database_model import Messages, ThinkingLog
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
|
||||
class InfoCatcher:
|
||||
@@ -60,8 +61,6 @@ class InfoCatcher:
|
||||
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||
|
||||
# def catch_shf
|
||||
|
||||
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
||||
self.timing_results["sub_heartflow_step_time"] = step_duration
|
||||
if len(past_mind) > 1:
|
||||
@@ -72,25 +71,10 @@ class InfoCatcher:
|
||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||
|
||||
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
|
||||
# self.reasoning_data["thinking_log"] = reasoning_content
|
||||
# self.reasoning_data["prompt"] = prompt
|
||||
# self.reasoning_data["response"] = response
|
||||
# self.reasoning_data["model"] = model_name
|
||||
|
||||
# 直接记录信息喵~
|
||||
self.reasoning_data["thinking_log"] = reasoning_content
|
||||
self.reasoning_data["prompt"] = prompt
|
||||
self.reasoning_data["response"] = response
|
||||
self.reasoning_data["model"] = model_name
|
||||
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
|
||||
# self.heartflow_data["prompt"] = prompt
|
||||
# self.heartflow_data["response"] = response
|
||||
# self.heartflow_data["model"] = model_name
|
||||
|
||||
self.response_text = response
|
||||
|
||||
@@ -102,6 +86,7 @@ class InfoCatcher:
|
||||
):
|
||||
self.timing_results["make_response_time"] = response_duration
|
||||
self.response_time = time.time()
|
||||
self.response_messages = []
|
||||
for msg in response_message:
|
||||
self.response_messages.append(msg)
|
||||
|
||||
@@ -112,107 +97,110 @@ class InfoCatcher:
|
||||
@staticmethod
|
||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||
try:
|
||||
# 从数据库中获取消息的时间戳
|
||||
time_start = message_start.message_info.time
|
||||
time_end = message_end.message_info.time
|
||||
chat_id = message_start.chat_stream.stream_id
|
||||
|
||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
||||
messages_between = db.messages.find(
|
||||
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
|
||||
).sort("time", -1)
|
||||
messages_between_query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > time_start) &
|
||||
(Messages.time < time_end)
|
||||
).order_by(Messages.time.desc())
|
||||
|
||||
result = list(messages_between)
|
||||
result = list(messages_between_query)
|
||||
print(f"查询结果数量: {len(result)}")
|
||||
if result:
|
||||
print(f"第一条消息时间: {result[0]['time']}")
|
||||
print(f"最后一条消息时间: {result[-1]['time']}")
|
||||
print(f"第一条消息时间: {result[0].time}")
|
||||
print(f"最后一条消息时间: {result[-1].time}")
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取消息时出错: {str(e)}")
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||
# 从数据库中获取消息
|
||||
message_id = message.message_info.message_id
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id_val = message.message_info.message_id
|
||||
chat_id_val = message.chat_stream.stream_id
|
||||
|
||||
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
||||
messages_before = (
|
||||
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
|
||||
.sort("time", -1)
|
||||
.limit(self.context_length * 3)
|
||||
) # 获取更多历史信息
|
||||
messages_before_query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id_val) &
|
||||
(Messages.message_id < message_id_val)
|
||||
).order_by(Messages.time.desc()).limit(self.context_length * 3)
|
||||
|
||||
return list(messages_before)
|
||||
return list(messages_before_query)
|
||||
|
||||
def message_list_to_dict(self, message_list):
|
||||
# 存储简化的聊天记录
|
||||
result = []
|
||||
for message in message_list:
|
||||
if not isinstance(message, dict):
|
||||
message = self.message_to_dict(message)
|
||||
# print(message)
|
||||
for msg_item in message_list:
|
||||
processed_msg_item = msg_item
|
||||
if not isinstance(msg_item, dict):
|
||||
processed_msg_item = self.message_to_dict(msg_item)
|
||||
|
||||
if not processed_msg_item:
|
||||
continue
|
||||
|
||||
lite_message = {
|
||||
"time": message["time"],
|
||||
"user_nickname": message["user_info"]["user_nickname"],
|
||||
"processed_plain_text": message["processed_plain_text"],
|
||||
"time": processed_msg_item.get("time"),
|
||||
"user_nickname": processed_msg_item.get("user_nickname"),
|
||||
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
|
||||
}
|
||||
result.append(lite_message)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(message):
|
||||
if not message:
|
||||
def message_to_dict(msg_obj):
|
||||
if not msg_obj:
|
||||
return None
|
||||
if isinstance(message, dict):
|
||||
return message
|
||||
return {
|
||||
# "message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time,
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
# "detailed_plain_text": message.detailed_plain_text
|
||||
}
|
||||
if isinstance(msg_obj, dict):
|
||||
return msg_obj
|
||||
|
||||
if isinstance(msg_obj, Messages):
|
||||
return {
|
||||
"time": msg_obj.time,
|
||||
"user_id": msg_obj.user_id,
|
||||
"user_nickname": msg_obj.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'):
|
||||
return {
|
||||
"time": msg_obj.message_info.time,
|
||||
"user_id": msg_obj.message_info.user_info.user_id,
|
||||
"user_nickname": msg_obj.message_info.user_info.user_nickname,
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
|
||||
return {}
|
||||
|
||||
def done_catch(self):
|
||||
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
|
||||
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
|
||||
try:
|
||||
# 将消息对象转换为可序列化的字典喵~
|
||||
|
||||
thinking_log_data = {
|
||||
"chat_id": self.chat_id,
|
||||
"trigger_text": self.trigger_response_text,
|
||||
"response_text": self.response_text,
|
||||
"trigger_info": {
|
||||
"time": self.trigger_response_time,
|
||||
"message": self.message_to_dict(self.trigger_response_message),
|
||||
},
|
||||
"response_info": {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
},
|
||||
"timing_results": self.timing_results,
|
||||
"chat_history": self.message_list_to_dict(self.chat_history),
|
||||
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
||||
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
|
||||
"heartflow_data": self.heartflow_data,
|
||||
"reasoning_data": self.reasoning_data,
|
||||
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
|
||||
response_info_dict = {
|
||||
"time": self.response_time,
|
||||
"message": self.response_messages,
|
||||
}
|
||||
chat_history_list = self.message_list_to_dict(self.chat_history)
|
||||
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
|
||||
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
|
||||
|
||||
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
|
||||
# if self.response_mode == "heart_flow":
|
||||
# thinking_log_data["mode_specific_data"] = self.heartflow_data
|
||||
# elif self.response_mode == "reasoning":
|
||||
# thinking_log_data["mode_specific_data"] = self.reasoning_data
|
||||
|
||||
# 将数据插入到 thinking_log 集合中喵~
|
||||
db.thinking_log.insert_one(thinking_log_data)
|
||||
log_entry = ThinkingLog(
|
||||
chat_id=self.chat_id,
|
||||
trigger_text=self.trigger_response_text,
|
||||
response_text=self.response_text,
|
||||
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
|
||||
response_info_json=json.dumps(response_info_dict),
|
||||
timing_results_json=json.dumps(self.timing_results),
|
||||
chat_history_json=json.dumps(chat_history_list),
|
||||
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
|
||||
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
|
||||
heartflow_data_json=json.dumps(self.heartflow_data),
|
||||
reasoning_data_json=json.dumps(self.reasoning_data)
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,7 +5,8 @@ from typing import Any, Dict, Tuple, List
|
||||
from src.common.logger import get_module_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database import db # This db is the Peewee database instance
|
||||
from ...common.database.database_model import OnlineTime # Import the Peewee model
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_module_logger("maibot_statistic")
|
||||
@@ -39,7 +40,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: str | None = None
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
@@ -47,53 +48,46 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
if "online_time" not in db.list_collection_names():
|
||||
# 初始化数据库(在线时长)
|
||||
db.create_collection("online_time")
|
||||
# 创建索引
|
||||
if ("end_timestamp", 1) not in db.online_time.list_indexes():
|
||||
db.online_time.create_index([("end_timestamp", 1)])
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
extended_end_time = current_time + timedelta(minutes=1)
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": datetime.now() + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
|
||||
updated_rows = query.execute()
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
current_time = datetime.now()
|
||||
if recent_record := db.online_time.find_one(
|
||||
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
|
||||
):
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = OnlineTime.select().where(
|
||||
OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))
|
||||
).order_by(OnlineTime.end_timestamp.desc()).first()
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record["_id"]
|
||||
db.online_time.update_one(
|
||||
{"_id": self.record_id},
|
||||
{
|
||||
"$set": {
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
},
|
||||
)
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
self.record_id = db.online_time.insert_one(
|
||||
{
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": current_time + timedelta(minutes=1),
|
||||
}
|
||||
).inserted_id
|
||||
new_record = OnlineTime.create(
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
|
||||
|
||||
def _format_online_time(online_seconds: int) -> str:
|
||||
"""
|
||||
格式化在线时间
|
||||
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import Images, ImageDescriptions
|
||||
from ...config.config import global_config
|
||||
from ..models.utils_model import LLMRequest
|
||||
|
||||
@@ -32,40 +33,21 @@ class ImageManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._ensure_image_collection()
|
||||
self._ensure_description_collection()
|
||||
self._ensure_image_dir()
|
||||
self._initialized = True
|
||||
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_image_collection():
|
||||
"""确保images集合存在并创建索引"""
|
||||
if "images" not in db.list_collection_names():
|
||||
db.create_collection("images")
|
||||
|
||||
# 删除旧索引
|
||||
db.images.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
db.images.create_index([("url", 1)])
|
||||
db.images.create_index([("path", 1)])
|
||||
|
||||
@staticmethod
|
||||
def _ensure_description_collection():
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if "image_descriptions" not in db.list_collection_names():
|
||||
db.create_collection("image_descriptions")
|
||||
|
||||
# 删除旧索引
|
||||
db.image_descriptions.drop_indexes()
|
||||
# 创建新的复合索引
|
||||
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
@@ -77,8 +59,15 @@ class ImageManager:
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
|
||||
return result["description"] if result else None
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.hash == image_hash) &
|
||||
(ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
@@ -90,20 +79,22 @@ class ImageManager:
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
try:
|
||||
db.image_descriptions.update_one(
|
||||
{"hash": image_hash, "type": description_type},
|
||||
{
|
||||
"$set": {
|
||||
"description": description,
|
||||
"timestamp": int(time.time()),
|
||||
"hash": image_hash, # 确保hash字段存在
|
||||
"type": description_type, # 确保type字段存在
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
current_timestamp = time.time()
|
||||
defaults = {
|
||||
'description': description,
|
||||
'timestamp': current_timestamp
|
||||
}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
hash=image_hash,
|
||||
type=description_type,
|
||||
defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败: {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,带查重和保存功能"""
|
||||
@@ -116,18 +107,25 @@ class ImageManager:
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
# logger.debug(f"缓存表情包描述: {cached_description}")
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = self.transform_gif(image_base64)
|
||||
image_base64_processed = self.transform_gif(image_base64)
|
||||
if image_base64_processed is None:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
||||
else:
|
||||
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成表情包描述")
|
||||
return "[表情包(描述生成失败)]"
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
@@ -136,31 +134,37 @@ class ImageManager:
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.save_emoji:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
|
||||
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
|
||||
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
file_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
"hash": image_hash,
|
||||
"path": file_path,
|
||||
"type": "emoji",
|
||||
"description": description,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.trace(f"保存表情包: {file_path}")
|
||||
# 保存到数据库 (Images表)
|
||||
try:
|
||||
img_obj = Images.get((Images.hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
Images.create(
|
||||
hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存表情包元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "emoji")
|
||||
|
||||
return f"[表情包:{description}]"
|
||||
@@ -187,7 +191,12 @@ class ImageManager:
|
||||
"请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。"
|
||||
)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片(描述生成失败)]"
|
||||
|
||||
# 再次检查缓存
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
|
||||
@@ -195,38 +204,40 @@ class ImageManager:
|
||||
|
||||
logger.debug(f"描述是{description}")
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.save_pic:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
|
||||
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
|
||||
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
"hash": image_hash,
|
||||
"path": file_path,
|
||||
"type": "image",
|
||||
"description": description,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
||||
logger.trace(f"保存图片: {file_path}")
|
||||
# 保存到数据库 (Images表)
|
||||
try:
|
||||
img_obj = Images.get((Images.hash == image_hash) & (Images.type == "image"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist:
|
||||
Images.create(
|
||||
hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
logger.trace(f"保存图片元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件失败: {str(e)}")
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
return f"[图片:{description}]"
|
||||
|
||||
Reference in New Issue
Block a user