4次修改

This commit is contained in:
tt-P607
2025-09-20 11:57:22 +08:00
parent 898208f425
commit a8992cdd51
6 changed files with 69 additions and 60 deletions

View File

@@ -69,7 +69,7 @@ class ImageManager:
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
@@ -80,22 +80,22 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
with get_db_session() as session:
record = session.execute(
async with get_db_session() as session:
record = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
).scalar()
)).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
return None
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
@@ -105,16 +105,16 @@ class ImageManager:
"""
try:
current_timestamp = time.time()
with get_db_session() as session:
async with get_db_session() as session:
# 查找现有记录
existing = session.execute(
existing = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
).scalar()
)).scalar()
if existing:
# 更新现有记录
@@ -129,7 +129,7 @@ class ImageManager:
timestamp=current_timestamp,
)
session.add(new_desc)
session.commit()
await session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
@@ -175,7 +175,7 @@ class ImageManager:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "emoji"):
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[表情包:{cached_description}]"
@@ -239,7 +239,7 @@ class ImageManager:
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
if cached_description := self._get_description_from_db(image_hash, "emoji"):
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
@@ -261,10 +261,10 @@ class ImageManager:
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
existing_img = session.execute(
async with get_db_session() as session:
existing_img = (await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
).scalar()
)).scalar()
if existing_img:
existing_img.path = file_path
@@ -279,7 +279,7 @@ class ImageManager:
timestamp=current_timestamp,
)
session.add(new_img)
session.commit()
await session.commit()
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
@@ -289,7 +289,7 @@ class ImageManager:
logger.debug("偷取表情包功能已关闭,跳过保存。")
# 保存最终的情感标签到缓存 (ImageDescriptions表)
self._save_description_to_db(image_hash, final_emotion, "emoji")
await self._save_description_to_db(image_hash, final_emotion, "emoji")
return f"[表情包:{final_emotion}]"
@@ -306,9 +306,9 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
async with get_db_session() as session:
# 优先检查Images表中是否已有完整的描述
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
@@ -318,34 +318,34 @@ class ImageManager:
# 如果已有描述,直接返回
if existing_image.description:
await session.commit()
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
return f"[图片:{existing_image.description}]"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[图片:{cached_description}]"
# 如果没有描述,继续在当前会话中操作
if cached_description := await self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 保存图片和描述
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
# 保存图片和描述
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
@@ -358,7 +358,6 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
new_img = Images(
@@ -372,13 +371,15 @@ class ImageManager:
count=1,
)
session.add(new_img)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
await session.commit()
# 保存描述到ImageDescriptions表作为备用缓存
await self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]"
logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]"
@@ -525,8 +526,8 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
async with get_db_session() as session:
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
@@ -546,6 +547,7 @@ class ImageManager:
existing_image.vlm_processed = False
existing_image.count += 1
await session.commit()
# 如果已有描述,直接返回
if existing_image.description and existing_image.description.strip():
@@ -556,6 +558,7 @@ class ImageManager:
# 更新数据库中的描述
existing_image.description = description.replace("[图片:", "").replace("]", "")
existing_image.vlm_processed = True
await session.commit()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
# print(f"图片不存在: {image_hash}")
@@ -588,7 +591,7 @@ class ImageManager:
count=1,
)
session.add(new_img)
session.commit()
await session.commit()
return image_id, f"[picid:{image_id}]"