4次修改
This commit is contained in:
@@ -69,7 +69,7 @@ class ImageManager:
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
@@ -80,22 +80,22 @@ class ImageManager:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = session.execute(
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(
|
||||
select(ImageDescriptions).where(
|
||||
and_(
|
||||
ImageDescriptions.image_description_hash == image_hash,
|
||||
ImageDescriptions.type == description_type,
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
)).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
"""保存图片描述到数据库
|
||||
|
||||
Args:
|
||||
@@ -105,16 +105,16 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查找现有记录
|
||||
existing = session.execute(
|
||||
existing = (await session.execute(
|
||||
select(ImageDescriptions).where(
|
||||
and_(
|
||||
ImageDescriptions.image_description_hash == image_hash,
|
||||
ImageDescriptions.type == description_type,
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
)).scalar()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
@@ -129,7 +129,7 @@ class ImageManager:
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_desc)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
# 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
@@ -175,7 +175,7 @@ class ImageManager:
|
||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
@@ -239,7 +239,7 @@ class ImageManager:
|
||||
|
||||
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
@@ -261,10 +261,10 @@ class ImageManager:
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(
|
||||
async with get_db_session() as session:
|
||||
existing_img = (await session.execute(
|
||||
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
|
||||
).scalar()
|
||||
)).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
@@ -279,7 +279,7 @@ class ImageManager:
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
@@ -289,7 +289,7 @@ class ImageManager:
|
||||
logger.debug("偷取表情包功能已关闭,跳过保存。")
|
||||
|
||||
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||
await self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
@@ -306,9 +306,9 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
async with get_db_session() as session:
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
@@ -318,34 +318,34 @@ class ImageManager:
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
await session.commit()
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||
return f"[图片:{cached_description}]"
|
||||
# 如果没有描述,继续在当前会话中操作
|
||||
if cached_description := await self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片(描述生成失败)]"
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片(描述生成失败)]"
|
||||
|
||||
# 保存图片和描述
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
# 保存图片和描述
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
@@ -358,7 +358,6 @@ class ImageManager:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
new_img = Images(
|
||||
@@ -372,13 +371,15 @@ class ImageManager:
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
await session.commit()
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
await self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||
return f"[图片:{description}]"
|
||||
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||
return f"[图片:{description}]"
|
||||
@@ -525,8 +526,8 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
async with get_db_session() as session:
|
||||
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||
if existing_image:
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
@@ -546,6 +547,7 @@ class ImageManager:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
existing_image.count += 1
|
||||
await session.commit()
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description and existing_image.description.strip():
|
||||
@@ -556,6 +558,7 @@ class ImageManager:
|
||||
# 更新数据库中的描述
|
||||
existing_image.description = description.replace("[图片:", "").replace("]", "")
|
||||
existing_image.vlm_processed = True
|
||||
await session.commit()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
@@ -588,7 +591,7 @@ class ImageManager:
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return image_id, f"[picid:{image_id}]"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user