Files
Mofox-Core/src/chat/utils/utils_image.py
明天好像没什么 30658afdb4 ruff归零
2025-11-01 21:32:41 +08:00

496 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import base64
import hashlib
import io
import os
import time
import uuid
from typing import Any
import aiofiles
import numpy as np
from PIL import Image
from rich.traceback import install
from sqlalchemy import and_, select
from src.common.database.core import get_db_session
from src.common.database.core.models import ImageDescriptions, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
install(extra_lines=3)
logger = get_logger("chat_image")
def is_image_message(message: dict[str, Any]) -> bool:
"""
判断消息是否为图片消息
Args:
message: 消息字典
Returns:
bool: 是否为图片消息
"""
return message.get("type") == "image" or (
isinstance(message.get("content"), dict) and message["content"].get("type") == "image"
)
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self._ensure_image_dir()
self._initialized = True
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
# try:
# db.connect(reuse_if_open=True)
# # 使用SQLAlchemy创建表已在初始化时完成
# logger.debug("使用SQLAlchemy进行表管理")
# except Exception as e:
# logger.error(f"数据库连接失败: {e}")
self._initialized = True
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
async def _get_description_from_db(image_hash: str, description_type: str) -> str | None:
"""从数据库获取图片描述
Args:
image_hash: 图片哈希值
description_type: 描述类型 ('emoji''image')
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
async with get_db_session() as session:
record = (
await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
)
).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {e!s}")
return None
@staticmethod
async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
image_hash: 图片哈希值
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
try:
current_timestamp = time.time()
async with get_db_session() as session:
# 查找现有记录
existing = (
await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
)
).scalar()
if existing:
# 更新现有记录
existing.description = description
existing.timestamp = current_timestamp
else:
# 创建新记录
new_desc = ImageDescriptions(
image_description_hash=image_hash,
type=description_type,
description=description,
timestamp=current_timestamp,
)
session.add(new_desc)
await session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {e!s}")
@staticmethod
async def get_emoji_tag(image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
if not emoji:
return "[表情包:未知]"
emotion_list = emoji.emotion
tag_str = ",".join(emotion_list)
return f"[表情包:{tag_str}]"
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述统一使用EmojiManager中的逻辑进行处理和缓存"""
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
# 1. 计算图片哈希
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 2. 优先查询已注册表情的缓存Emoji表
if full_description := await emoji_manager.get_emoji_description_by_hash(image_hash):
logger.info("[缓存命中] 使用已注册表情包(Emoji表)的完整描述")
refined_part = full_description.split(" Keywords:")[0]
return f"[表情包:{refined_part}]"
# 3. 查询通用图片描述缓存ImageDescriptions表
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.info("[缓存命中] 使用通用图片缓存(ImageDescriptions表)中的描述")
refined_part = cached_description.split(" Keywords:")[0]
return f"[表情包:{refined_part}]"
# 4. 如果都未命中,则调用新逻辑生成描述
logger.info(f"[新表情识别] 表情包未注册且无缓存 (Hash: {image_hash[:8]}...),调用新逻辑生成描述")
full_description, emotions = await emoji_manager.build_emoji_description(image_base64)
if not full_description:
logger.warning("未能通过新逻辑生成有效描述")
return "[表情包(描述生成失败)]"
# 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区
if global_config.emoji.steal_emoji:
logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}")
try:
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(emoji_dir, filename)
async with aiofiles.open(file_path, "wb") as f:
await f.write(image_bytes)
logger.info(f"新表情包已保存至待注册目录: {file_path}")
except Exception as e:
logger.error(f"保存待注册表情包文件失败: {e!s}")
# 5. 将新生成的完整描述存入通用缓存ImageDescriptions表
await self._save_description_to_db(image_hash, full_description, "emoji")
logger.info(f"新生成的表情包描述已存入通用缓存 (Hash: {image_hash[:8]}...)")
# 6. 返回新生成的描述中用于显示的“精炼描述”部分
refined_part = full_description.split(" Keywords:")[0]
return f"[表情包:{refined_part}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {e!s}", exc_info=True)
return "[表情包(处理失败)]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,采用同步识别+缓存策略"""
try:
# 1. 计算图片哈希
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 2. 优先查询 Images 表缓存
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
existing_image = result.scalar()
if existing_image and existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
# 3. 其次查询 ImageDescriptions 表缓存
if cached_description := await self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[图片:{cached_description}]"
# 4. 如果都未命中则同步调用VLM生成新描述
logger.info(f"[新图片识别] 无缓存 (Hash: {image_hash[:8]}...)调用VLM生成描述")
description = None
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[识图VLM调用] Prompt: {prompt}")
for i in range(3): # 重试3次
try:
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
logger.info(f"[VLM调用] 正在为图片生成描述 (第 {i+1}/3 次)...")
description, response_tuple = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
# response_tuple is (reasoning, model_name, tool_calls)
model_name_used = response_tuple[1]
logger.info(f"[VLM调用成功] 使用模型: {model_name_used}")
if description and description.strip():
break # 成功获取描述则跳出循环
except Exception as e:
logger.error(f"VLM调用失败 (第 {i+1}/3 次): {e}", exc_info=True)
if i < 2: # 如果不是最后一次则等待1秒
logger.warning("识图失败将在1秒后重试...")
await asyncio.sleep(1)
if not description or not description.strip():
logger.warning("VLM未能生成有效描述")
return "[图片(描述生成失败)]"
logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...")
# 5. 将新描述存入两个缓存表
await self._save_description_to_db(image_hash, description, "image")
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
existing_image_for_update = result.scalar()
if existing_image_for_update:
existing_image_for_update.description = description
existing_image_for_update.vlm_processed = True
logger.debug(f"[数据库] 为现有图片记录补充描述: {image_hash[:8]}...")
# 注意这里不创建新的Images记录因为process_image会负责创建
await session.commit()
logger.info(f"新生成的图片描述已存入缓存 (Hash: {image_hash[:8]}...)")
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述时发生严重错误: {e!s}", exc_info=True)
return "[图片(处理失败)]"
@staticmethod
def transform_gif(gif_base64: str) -> str | None:
# sourcery skip: use-contextlib-suppress
"""将GIF转换为水平拼接的静态图像, 均匀抽取4帧。
Args:
gif_base64: GIF的base64编码字符串
Returns:
Optional[str]: 拼接后的JPG图像的base64编码字符串, 或者在失败时返回None
"""
try:
# 确保base64字符串只包含ASCII字符
if isinstance(gif_base64, str):
gif_base64 = gif_base64.encode("ascii", errors="ignore").decode("ascii")
# 解码base64
gif_data = base64.b64decode(gif_base64)
gif = Image.open(io.BytesIO(gif_data))
# 收集所有帧
all_frames = []
try:
while True:
gif.seek(len(all_frames))
# 确保是RGB格式方便比较
frame = gif.convert("RGB")
all_frames.append(frame.copy())
except EOFError:
... # 读完啦
if not all_frames:
logger.warning("GIF中没有找到任何帧")
return None # 空的GIF直接返回None
# --- 新的帧选择逻辑均匀抽取4帧 ---
num_frames = len(all_frames)
if num_frames <= 4:
# 如果总帧数小于等于4则全部选中
selected_frames = all_frames
else:
# 使用linspace计算4个均匀分布的索引
indices = np.linspace(0, num_frames - 1, 4, dtype=int)
selected_frames = [all_frames[i] for i in indices]
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}")
# --- 帧选择逻辑结束 ---
# 如果选择后连一帧都没有比如GIF只有一帧且后续处理失败或者原始GIF就没帧也返回None
if not selected_frames:
logger.warning("处理后没有选中任何帧")
return None
# logger.debug(f"总帧数: {len(all_frames)}, 选中帧数: {len(selected_frames)}")
# 获取选中的第一帧的尺寸(假设所有帧尺寸一致)
frame_width, frame_height = selected_frames[0].size
# 计算目标尺寸,保持宽高比
target_height = 200 # 固定高度
# 防止除以零
if frame_height == 0:
logger.error("帧高度为0无法计算缩放尺寸")
return None
target_width = int((target_height / frame_height) * frame_width)
# 宽度也不能是0
if target_width == 0:
logger.warning(f"计算出的目标宽度为0 (原始尺寸 {frame_width}x{frame_height})调整为1")
target_width = 1
# 调整所有选中帧的大小
resized_frames = [
frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
]
# 创建拼接图像
total_width = target_width * len(resized_frames)
# 防止总宽度为0
if total_width == 0 and resized_frames:
logger.warning("计算出的总宽度为0但有选中帧可能目标宽度太小")
# 至少给点宽度吧
total_width = len(resized_frames)
elif total_width == 0:
logger.error("计算出的总宽度为0且无选中帧")
return None
combined_image = Image.new("RGB", (total_width, target_height))
# 水平拼接图像
for idx, frame in enumerate(resized_frames):
combined_image.paste(frame, (idx * target_width, 0))
# 转换为base64
buffer = io.BytesIO()
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
return base64.b64encode(buffer.getvalue()).decode("utf-8")
except MemoryError:
logger.error("GIF转换失败: 内存不足可能是GIF太大或帧数太多")
return None # 内存不够啦
except Exception as e:
logger.error(f"GIF转换失败: {e!s}", exc_info=True) # 记录详细错误信息
return None # 其他错误也返回None
async def process_image(self, image_base64: str) -> tuple[str, str]:
"""处理图片并返回图片ID和描述采用同步识别流程"""
try:
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_id = ""
description = ""
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
existing_image = result.scalar()
if existing_image and existing_image.image_id:
image_id = existing_image.image_id
existing_image.count += 1
logger.debug(f"图片记录已存在 (ID: {image_id}),使用次数 +1")
if existing_image.description and existing_image.description.strip():
description = f"[图片:{existing_image.description}]"
logger.debug("缓存命中,直接返回数据库中已有的完整描述")
return image_id, description
else:
logger.warning(f"图片记录 (ID: {image_id}) 描述为空,将同步生成")
description = await self.get_image_description(image_base64)
existing_image.description = description.replace("[图片:", "").replace("]", "")
existing_image.vlm_processed = True
else:
logger.debug(f"新图片 (Hash: {image_hash[:8]}...),将同步生成描述并创建新记录")
image_id = str(uuid.uuid4())
description = await self.get_image_description(image_base64)
# 如果描述生成失败,则不存入数据库,直接返回失败信息
if "(处理失败)" in description or "(描述生成失败)" in description:
logger.warning("图片描述生成失败,不创建数据库记录,直接返回失败信息。")
return "", description
clean_description = description.replace("[图片:", "").replace("]", "")
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower()
filename = f"{image_id}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "images")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
async with aiofiles.open(file_path, "wb") as f:
await f.write(image_bytes)
new_img = Images(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
type="image",
description=clean_description,
timestamp=time.time(),
vlm_processed=True,
count=1,
)
session.add(new_img)
logger.info(f"新图片记录已创建 (ID: {image_id})")
await session.commit()
# 无论是新图片还是旧图片,只要成功获取描述,就直接返回描述
return image_id, description
except Exception as e:
logger.error(f"处理图片时发生严重错误: {e!s}", exc_info=True)
return "", "[图片(处理失败)]"
# 创建全局单例
image_manager = None
def get_image_manager() -> ImageManager:
"""获取全局图片管理器单例"""
global image_manager
if image_manager is None:
image_manager = ImageManager()
return image_manager
def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码
Args:
image_path: 图片文件路径
Returns:
str: base64编码的图片数据
Raises:
FileNotFoundError: 当图片文件不存在时
IOError: 当读取图片文件失败时
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"图片文件不存在: {image_path}")
with open(image_path, "rb") as f:
if image_data := f.read():
return base64.b64encode(image_data).decode("utf-8")
else:
raise OSError(f"读取图片文件失败: {image_path}")