fix: attempt to fix all images being saved as .jpg, and request the image-recognition API with the correct format

HYY1116
2025-03-12 09:53:01 +08:00
parent b934d473ab
commit 1840599156
2 changed files with 21 additions and 13 deletions
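
The change in a nutshell: detect the real image format from the raw bytes with Pillow instead of hard-coding .jpg, then reuse that format both for the saved file's extension and for the data:image/... URL sent to the vision endpoint. A minimal, self-contained sketch of the two pieces (detect_format and image_data_url are illustrative names, not the project's actual helpers):

import io

from PIL import Image


def detect_format(image_bytes: bytes) -> str:
    # Pillow exposes the container format as a plain attribute (not a method),
    # e.g. "JPEG", "PNG", "GIF" or "WEBP"; it is None if the data is unrecognized.
    return Image.open(io.BytesIO(image_bytes)).format


def image_data_url(image_base64: str, image_format: str) -> str:
    # The MIME subtype of the data URL now follows the real encoding instead of always "jpeg".
    return f"data:image/{image_format.lower()};base64,{image_base64}"

In the diff below, the detection happens inline in ImageManager and the data URL is built in LLM_request._build_payload.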

View File

@@ -4,6 +4,8 @@ import time
 import aiohttp
 import hashlib
 from typing import Optional, Union
+from PIL import Image
+import io
 from loguru import logger
 from nonebot import get_driver
@@ -119,6 +121,7 @@ class ImageManager:
         # Compute the image hash
         image_hash = hashlib.md5(image_bytes).hexdigest()
+        img_format = Image.open(io.BytesIO(image_bytes)).format
 
         # Check for duplicates
         existing = self.db.images.find_one({'hash': image_hash})
@@ -127,7 +130,7 @@ class ImageManager:
         # Generate the file name and path
         timestamp = int(time.time())
-        filename = f"{timestamp}_{image_hash[:8]}.jpg"
+        filename = f"{timestamp}_{image_hash[:8]}.{img_format}"
         file_path = os.path.join(self.IMAGE_DIR, filename)
 
         # Save the file
@@ -238,7 +241,8 @@ class ImageManager:
         # Compute the image hash
         image_bytes = base64.b64decode(image_base64)
         image_hash = hashlib.md5(image_bytes).hexdigest()
+        image_format = Image.open(io.BytesIO(image_bytes)).format
 
         # Look up a cached description
         cached_description = self._get_description_from_db(image_hash, 'emoji')
         if cached_description:
@@ -247,13 +251,13 @@ class ImageManager:
         # Call the AI to get a description
         prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
-        description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
+        description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
 
         # Decide whether to save the image based on config
         if global_config.EMOJI_SAVE:
             # Generate the file name and path
             timestamp = int(time.time())
-            filename = f"{timestamp}_{image_hash[:8]}.jpg"
+            filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
             file_path = os.path.join(self.IMAGE_DIR, 'emoji', filename)
 
             try:
@@ -292,7 +296,8 @@ class ImageManager:
         # Compute the image hash
         image_bytes = base64.b64decode(image_base64)
         image_hash = hashlib.md5(image_bytes).hexdigest()
+        image_format = Image.open(io.BytesIO(image_bytes)).format
 
         # Look up a cached description
         cached_description = self._get_description_from_db(image_hash, 'image')
         if cached_description:
@@ -300,7 +305,7 @@ class ImageManager:
         # Call the AI to get a description
         prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
-        description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
+        description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
 
         if description is None:
             logger.warning("AI未能生成图片描述")
@@ -310,7 +315,7 @@ class ImageManager:
         if global_config.EMOJI_SAVE:
             # Generate the file name and path
             timestamp = int(time.time())
-            filename = f"{timestamp}_{image_hash[:8]}.jpg"
+            filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
             file_path = os.path.join(self.IMAGE_DIR, 'image', filename)
 
             try:
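
One side effect of reusing Pillow's format attribute directly in the file name is that the extension comes out uppercase ("JPEG", "PNG", "GIF", "WEBP") and can be None when Pillow cannot identify the data. A small sketch of a lowercased, fallback-protected variant of the same idea (build_save_path is an illustrative helper, not part of this diff):

import io
import os
import time

from PIL import Image


def build_save_path(image_dir: str, image_bytes: bytes, image_hash: str) -> str:
    # Pillow returns uppercase format names and None for unrecognized data,
    # so fall back to "jpeg" and lowercase before using it as an extension.
    img_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
    filename = f"{int(time.time())}_{image_hash[:8]}.{img_format}"
    return os.path.join(image_dir, filename)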

View File

@@ -104,6 +104,7 @@ class LLM_request:
         endpoint: str,
         prompt: str = None,
         image_base64: str = None,
+        image_format: str = None,
         payload: dict = None,
         retry_policy: dict = None,
         response_handler: callable = None,
@@ -115,6 +116,7 @@ class LLM_request:
             endpoint: API endpoint path (e.g. "chat/completions")
             prompt: prompt text
             image_base64: base64-encoded image
+            image_format: image format
             payload: request body data
             retry_policy: custom retry policy
             response_handler: custom response handler
@@ -151,7 +153,7 @@ class LLM_request:
         # Build the request body
         if image_base64:
-            payload = await self._build_payload(prompt, image_base64)
+            payload = await self._build_payload(prompt, image_base64, image_format)
         elif payload is None:
             payload = await self._build_payload(prompt)
@@ -172,7 +174,7 @@ class LLM_request:
                     if response.status == 413:
                         logger.warning("请求体过大,尝试压缩...")
                         image_base64 = compress_base64_image_by_scale(image_base64)
-                        payload = await self._build_payload(prompt, image_base64)
+                        payload = await self._build_payload(prompt, image_base64, image_format)
                     elif response.status in [500, 503]:
                         logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
                         raise RuntimeError("服务器负载过高模型恢复失败QAQ")
@@ -294,7 +296,7 @@ class LLM_request:
             new_params["max_completion_tokens"] = new_params.pop("max_tokens")
         return new_params
 
-    async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
+    async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
         """Build the request body"""
         # Copy the parameters to avoid modifying self.params directly
         params_copy = await self._transform_parameters(self.params)
@@ -306,7 +308,7 @@ class LLM_request:
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": prompt}, {"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}} {"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}}
] ]
} }
], ],
@@ -391,13 +393,14 @@ class LLM_request:
         )
         return content, reasoning_content
 
-    async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
+    async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
         """Generate the model's asynchronous response from the given prompt and image"""
         content, reasoning_content = await self._execute_request(
             endpoint="/chat/completions",
             prompt=prompt,
-            image_base64=image_base64
+            image_base64=image_base64,
+            image_format=image_format
         )
         return content, reasoning_content
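
With image_format threaded through, the multimodal message that _build_payload assembles follows the OpenAI-style chat-completions schema, with the MIME subtype derived from the detected format. A minimal sketch of the resulting message structure (build_vision_message is an illustrative stand-in; the diff itself calls image_format.lower() directly, so callers are expected to pass a non-None format):

def build_vision_message(prompt: str, image_base64: str, image_format: str = None) -> dict:
    # Guard against a missing format so .lower() cannot fail on None; fall back to
    # "jpeg", which was the hard-coded behaviour before this change.
    mime_subtype = (image_format or "jpeg").lower()
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url",
             "image_url": {"url": f"data:image/{mime_subtype};base64,{image_base64}"}},
        ],
    }

Matching the subtype to the real bytes can matter for endpoints that validate the data URL, and GIF/WEBP are common for emoji stickers, which is where the hard-coded image/jpeg was most misleading.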