typing
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Union, Dict, Any
|
||||
from typing import Tuple, Union, Dict, Any, Callable
|
||||
import aiohttp
|
||||
from aiohttp.client import ClientResponse
|
||||
from src.common.logger import get_logger
|
||||
@@ -300,7 +300,7 @@ class LLMRequest:
|
||||
file_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
response_handler: Callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = None,
|
||||
):
|
||||
@@ -336,19 +336,17 @@ class LLMRequest:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||
post_kwargs = {"headers": headers}
|
||||
#form-data数据上传方式不同
|
||||
# form-data数据上传方式不同
|
||||
if file_bytes:
|
||||
post_kwargs["data"] = request_content["payload"]
|
||||
else:
|
||||
post_kwargs["json"] = request_content["payload"]
|
||||
|
||||
async with session.post(
|
||||
request_content["api_url"], **post_kwargs
|
||||
) as response:
|
||||
async with session.post(request_content["api_url"], **post_kwargs) as response:
|
||||
handled_result = await self._handle_response(
|
||||
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||
)
|
||||
return handled_result
|
||||
return handled_result
|
||||
|
||||
except Exception as e:
|
||||
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
|
||||
@@ -366,11 +364,11 @@ class LLMRequest:
|
||||
response: ClientResponse,
|
||||
request_content: Dict[str, Any],
|
||||
retry_count: int,
|
||||
response_handler: callable,
|
||||
response_handler: Callable,
|
||||
user_id,
|
||||
request_type,
|
||||
endpoint,
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
):
|
||||
policy = request_content["policy"]
|
||||
stream_mode = request_content["stream_mode"]
|
||||
if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
|
||||
@@ -477,9 +475,7 @@ class LLMRequest:
|
||||
}
|
||||
return result
|
||||
|
||||
async def _handle_error_response(
|
||||
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
|
||||
) -> Union[Dict[str, any]]:
|
||||
async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]):
|
||||
if response.status in policy["retry_codes"]:
|
||||
wait_time = policy["base_wait"] * (2**retry_count)
|
||||
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
|
||||
@@ -629,7 +625,9 @@ class LLMRequest:
|
||||
)
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||
logger.critical(
|
||||
f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||
)
|
||||
@@ -643,7 +641,9 @@ class LLMRequest:
|
||||
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||
logger.critical(
|
||||
f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
|
||||
)
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
||||
|
||||
async def _transform_parameters(self, params: dict) -> dict:
|
||||
@@ -682,15 +682,14 @@ class LLMRequest:
|
||||
logger.warning(f"暂不支持的文件类型: {file_format}")
|
||||
|
||||
data.add_field(
|
||||
"file",io.BytesIO(file_bytes),
|
||||
"file",
|
||||
io.BytesIO(file_bytes),
|
||||
filename=f"file.{file_format}",
|
||||
content_type=f'{content_type}' # 根据实际文件类型设置
|
||||
)
|
||||
data.add_field(
|
||||
"model", self.model_name
|
||||
content_type=f"{content_type}", # 根据实际文件类型设置
|
||||
)
|
||||
data.add_field("model", self.model_name)
|
||||
return data
|
||||
|
||||
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||
"""构建请求体"""
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
@@ -819,9 +818,11 @@ class LLMRequest:
|
||||
|
||||
async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
|
||||
"""根据输入的语音文件生成模型的异步响应"""
|
||||
response = await self._execute_request(endpoint="/audio/transcriptions",file_bytes=voice_bytes, file_format='wav')
|
||||
response = await self._execute_request(
|
||||
endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
# 构建请求体,不硬编码max_tokens
|
||||
|
||||
Reference in New Issue
Block a user