Merge pull request #251 from HYY1116/debug

fix: 尝试修复所有图片都被保存为jpg的问题,并以正确的格式请求识图api
This commit is contained in:
HYY
2025-03-12 17:05:43 +08:00
committed by GitHub
3 changed files with 27 additions and 17 deletions

View File

@@ -6,6 +6,8 @@ import random
import time
import traceback
from typing import Optional, Tuple
from PIL import Image
import io
from loguru import logger
from nonebot import get_driver
@@ -192,11 +194,11 @@ class EmojiManager:
logger.error(f"获取标签失败: {str(e)}")
return None
async def _check_emoji(self, image_base64: str) -> str:
async def _check_emoji(self, image_base64: str, image_format: str) -> str:
try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
logger.debug(f"输出描述: {content}")
return content
@@ -237,7 +239,7 @@ class EmojiManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过
existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None
@@ -278,7 +280,7 @@ class EmojiManager:
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
check = await self._check_emoji(image_base64, image_format)
if '' not in check:
os.remove(image_path)
logger.info(f"描述: {description}")

View File

@@ -4,6 +4,8 @@ import time
import aiohttp
import hashlib
from typing import Optional, Union
from PIL import Image
import io
from loguru import logger
from nonebot import get_driver
@@ -119,6 +121,7 @@ class ImageManager:
# 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重
existing = self.db.images.find_one({'hash': image_hash})
@@ -127,7 +130,7 @@ class ImageManager:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件
@@ -238,6 +241,7 @@ class ImageManager:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji')
@@ -247,13 +251,13 @@ class ImageManager:
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
try:
@@ -293,6 +297,7 @@ class ImageManager:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image')
@@ -302,7 +307,7 @@ class ImageManager:
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
print(f"描述是{description}")
@@ -314,7 +319,7 @@ class ImageManager:
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR,'image', filename)
try:

View File

@@ -104,6 +104,7 @@ class LLM_request:
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
@@ -115,6 +116,7 @@ class LLM_request:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
response_handler: 自定义响应处理器
@@ -151,7 +153,7 @@ class LLM_request:
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None:
payload = await self._build_payload(prompt)
@@ -172,7 +174,7 @@ class LLM_request:
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
@@ -294,7 +296,7 @@ class LLM_request:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params
async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体"""
# 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params)
@@ -306,7 +308,7 @@ class LLM_request:
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
{"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}}
]
}
],
@@ -391,13 +393,14 @@ class LLM_request:
)
return content, reasoning_content
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
"""根据输入的提示和图片生成模型的异步响应"""
content, reasoning_content = await self._execute_request(
endpoint="/chat/completions",
prompt=prompt,
image_base64=image_base64
image_base64=image_base64,
image_format=image_format
)
return content, reasoning_content