Merge pull request #251 from HYY1116/debug

fix: 尝试修复所有图片都被保存为jpg的问题,并以正确的格式请求识图api
This commit is contained in:
HYY
2025-03-12 17:05:43 +08:00
committed by GitHub
3 changed files with 27 additions and 17 deletions

View File

@@ -6,6 +6,8 @@ import random
import time import time
import traceback import traceback
from typing import Optional, Tuple from typing import Optional, Tuple
from PIL import Image
import io
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -192,11 +194,11 @@ class EmojiManager:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def _check_emoji(self, image_base64: str) -> str: async def _check_emoji(self, image_base64: str, image_format: str) -> str:
try: try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
logger.debug(f"输出描述: {content}") logger.debug(f"输出描述: {content}")
return content return content
@@ -237,7 +239,7 @@ class EmojiManager:
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过 # 检查是否已经注册过
existing_emoji = self.db['emoji'].find_one({'filename': filename}) existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None description = None
@@ -278,7 +280,7 @@ class EmojiManager:
if global_config.EMOJI_CHECK: if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64) check = await self._check_emoji(image_base64, image_format)
if '' not in check: if '' not in check:
os.remove(image_path) os.remove(image_path)
logger.info(f"描述: {description}") logger.info(f"描述: {description}")

View File

@@ -4,6 +4,8 @@ import time
import aiohttp import aiohttp
import hashlib import hashlib
from typing import Optional, Union from typing import Optional, Union
from PIL import Image
import io
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -119,6 +121,7 @@ class ImageManager:
# 计算哈希值 # 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重 # 查重
existing = self.db.images.find_one({'hash': image_hash}) existing = self.db.images.find_one({'hash': image_hash})
@@ -127,7 +130,7 @@ class ImageManager:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg" filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, filename) file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件 # 保存文件
@@ -238,6 +241,7 @@ class ImageManager:
# 计算图片哈希 # 计算图片哈希
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji') cached_description = self._get_description_from_db(image_hash, 'emoji')
@@ -247,13 +251,13 @@ class ImageManager:
# 调用AI获取描述 # 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感" prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.EMOJI_SAVE: if global_config.EMOJI_SAVE:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg" filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename) file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
try: try:
@@ -293,6 +297,7 @@ class ImageManager:
# 计算图片哈希 # 计算图片哈希
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image') cached_description = self._get_description_from_db(image_hash, 'image')
@@ -302,7 +307,7 @@ class ImageManager:
# 调用AI获取描述 # 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
print(f"描述是{description}") print(f"描述是{description}")
@@ -314,7 +319,7 @@ class ImageManager:
if global_config.EMOJI_SAVE: if global_config.EMOJI_SAVE:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg" filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR,'image', filename) file_path = os.path.join(self.IMAGE_DIR,'image', filename)
try: try:

View File

@@ -104,6 +104,7 @@ class LLM_request:
endpoint: str, endpoint: str,
prompt: str = None, prompt: str = None,
image_base64: str = None, image_base64: str = None,
image_format: str = None,
payload: dict = None, payload: dict = None,
retry_policy: dict = None, retry_policy: dict = None,
response_handler: callable = None, response_handler: callable = None,
@@ -115,6 +116,7 @@ class LLM_request:
endpoint: API端点路径 (如 "chat/completions") endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本 prompt: prompt文本
image_base64: 图片的base64编码 image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据 payload: 请求体数据
retry_policy: 自定义重试策略 retry_policy: 自定义重试策略
response_handler: 自定义响应处理器 response_handler: 自定义响应处理器
@@ -151,7 +153,7 @@ class LLM_request:
# 构建请求体 # 构建请求体
if image_base64: if image_base64:
payload = await self._build_payload(prompt, image_base64) payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None: elif payload is None:
payload = await self._build_payload(prompt) payload = await self._build_payload(prompt)
@@ -172,7 +174,7 @@ class LLM_request:
if response.status == 413: if response.status == 413:
logger.warning("请求体过大,尝试压缩...") logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64) image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64) payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]: elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ") raise RuntimeError("服务器负载过高模型恢复失败QAQ")
@@ -294,7 +296,7 @@ class LLM_request:
new_params["max_completion_tokens"] = new_params.pop("max_tokens") new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params return new_params
async def _build_payload(self, prompt: str, image_base64: str = None) -> dict: async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体""" """构建请求体"""
# 复制一份参数,避免直接修改 self.params # 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params) params_copy = await self._transform_parameters(self.params)
@@ -306,7 +308,7 @@ class LLM_request:
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": prompt}, {"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}} {"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}}
] ]
} }
], ],
@@ -391,13 +393,14 @@ class LLM_request:
) )
return content, reasoning_content return content, reasoning_content
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
"""根据输入的提示和图片生成模型的异步响应""" """根据输入的提示和图片生成模型的异步响应"""
content, reasoning_content = await self._execute_request( content, reasoning_content = await self._execute_request(
endpoint="/chat/completions", endpoint="/chat/completions",
prompt=prompt, prompt=prompt,
image_base64=image_base64 image_base64=image_base64,
image_format=image_format
) )
return content, reasoning_content return content, reasoning_content