合并openai兼容,过ruff
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from typing import Dict, List, Optional
|
||||
import strawberry
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||
from ..config.config import global_config
|
||||
import os
|
||||
|
||||
# from packaging.version import Version, InvalidVersion
|
||||
# from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||
# from ..config.config import global_config
|
||||
# import os
|
||||
from packaging.version import Version
|
||||
|
||||
@strawberry.type
|
||||
class BotConfig:
|
||||
|
||||
@@ -79,8 +79,7 @@ class LLMRequest:
|
||||
"o3",
|
||||
"o3-2025-04-16",
|
||||
"o3-mini",
|
||||
"o3-mini-2025-01-31"
|
||||
"o4-mini",
|
||||
"o3-mini-2025-01-31o4-mini",
|
||||
"o4-mini-2025-04-16",
|
||||
]
|
||||
|
||||
@@ -806,10 +805,8 @@ class LLMRequest:
|
||||
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
|
||||
policy = request_content["policy"]
|
||||
payload = request_content["payload"]
|
||||
keep_request = False
|
||||
wait_time = 0.1
|
||||
if retry_count < policy["max_retries"] - 1:
|
||||
wait_time = policy["base_wait"] * (2**retry_count)
|
||||
if retry_count < policy["max_retries"] - 1:
|
||||
keep_request = True
|
||||
if isinstance(exception, RequestAbortException):
|
||||
response = exception.response
|
||||
@@ -989,9 +986,7 @@ class LLMRequest:
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
params_copy = await self._transform_parameters(self.params)
|
||||
if image_base64:
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -1002,17 +997,16 @@ class LLMRequest:
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
**params_copy,
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
"messages": messages,
|
||||
**params_copy,
|
||||
}
|
||||
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||
payload["max_tokens"] = global_config.max_response_length
|
||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||
@@ -1105,11 +1099,10 @@ class LLMRequest:
|
||||
|
||||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
# 构建请求体
|
||||
# 构建请求体,不硬编码max_tokens
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": global_config.max_response_length,
|
||||
**self.params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user