优化了代码规范
This commit is contained in:
@@ -155,11 +155,11 @@ class MessageRecv(Message):
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
@@ -169,11 +169,13 @@ class MessageRecv(Message):
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
@@ -222,10 +224,12 @@ class MessageRecvS4U(MessageRecv):
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
@@ -252,11 +256,13 @@ class MessageRecvS4U(MessageRecv):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
@@ -271,6 +277,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
|
||||
@@ -1,17 +1,5 @@
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
@@ -310,7 +310,7 @@ class LLMRequest:
|
||||
prompt: prompt文本
|
||||
image_base64: 图片的base64编码
|
||||
image_format: 图片格式
|
||||
file_base64: 文件的二进制数据
|
||||
file_bytes: 文件的二进制数据
|
||||
file_format: 文件格式
|
||||
payload: 请求体数据
|
||||
retry_policy: 自定义重试策略
|
||||
@@ -335,23 +335,21 @@ class LLMRequest:
|
||||
if request_content["stream_mode"]:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||
post_kwargs = {"headers": headers}
|
||||
#form-data数据上传方式不同
|
||||
if file_bytes:
|
||||
#form-data数据上传方式不同
|
||||
async with session.post(
|
||||
request_content["api_url"], headers=headers, data=request_content["payload"]
|
||||
) as response:
|
||||
handled_result = await self._handle_response(
|
||||
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||
)
|
||||
return handled_result
|
||||
post_kwargs["data"] = request_content["payload"]
|
||||
else:
|
||||
async with session.post(
|
||||
request_content["api_url"], headers=headers, json=request_content["payload"]
|
||||
) as response:
|
||||
handled_result = await self._handle_response(
|
||||
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||
)
|
||||
return handled_result
|
||||
post_kwargs["json"] = request_content["payload"]
|
||||
|
||||
async with session.post(
|
||||
request_content["api_url"], **post_kwargs
|
||||
) as response:
|
||||
handled_result = await self._handle_response(
|
||||
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||
)
|
||||
return handled_result
|
||||
|
||||
except Exception as e:
|
||||
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
|
||||
retry += count_delta # 降级不计入重试次数
|
||||
@@ -666,7 +664,7 @@ class LLMRequest:
|
||||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||||
return new_params
|
||||
|
||||
async def _build_formdata_payload(self, file_bytes: str, file_format: str):
|
||||
async def _build_formdata_payload(self, file_bytes: str, file_format: str) -> aiohttp.FormData:
|
||||
"""构建form-data请求体"""
|
||||
# 目前只适配了音频文件
|
||||
# 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑
|
||||
@@ -678,11 +676,15 @@ class LLMRequest:
|
||||
"flac": "audio/flac",
|
||||
"aac": "audio/aac",
|
||||
}
|
||||
|
||||
|
||||
content_type = content_type_list.get(file_format)
|
||||
if not content_type:
|
||||
logger.warning(f"暂不支持的文件类型: {file_format}")
|
||||
|
||||
data.add_field(
|
||||
"file",io.BytesIO(file_bytes),
|
||||
filename=f"file.{file_format}",
|
||||
content_type=f'audio/{content_type_list[file_format]}' # 根据实际文件类型设置
|
||||
content_type=f'{content_type_list[file_format]}' # 根据实际文件类型设置
|
||||
)
|
||||
data.add_field(
|
||||
"model", self.model_name
|
||||
|
||||
Reference in New Issue
Block a user