合并openai兼容,过ruff

This commit is contained in:
UnCLAS-Prommer
2025-04-21 15:25:29 +08:00
parent 6cb317123d
commit 7d2f5b51a7
2 changed files with 28 additions and 35 deletions

View File

@@ -1,10 +1,10 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
import strawberry import strawberry
from packaging.version import Version, InvalidVersion # from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet, InvalidSpecifier # from packaging.specifiers import SpecifierSet, InvalidSpecifier
from ..config.config import global_config # from ..config.config import global_config
import os # import os
from packaging.version import Version
@strawberry.type @strawberry.type
class BotConfig: class BotConfig:

View File

@@ -79,8 +79,7 @@ class LLMRequest:
"o3", "o3",
"o3-2025-04-16", "o3-2025-04-16",
"o3-mini", "o3-mini",
"o3-mini-2025-01-31" "o3-mini-2025-01-31o4-mini",
"o4-mini",
"o4-mini-2025-04-16", "o4-mini-2025-04-16",
] ]
@@ -806,10 +805,8 @@ class LLMRequest:
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]: ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
policy = request_content["policy"] policy = request_content["policy"]
payload = request_content["payload"] payload = request_content["payload"]
keep_request = False wait_time = policy["base_wait"] * (2**retry_count)
wait_time = 0.1
if retry_count < policy["max_retries"] - 1: if retry_count < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry_count)
keep_request = True keep_request = True
if isinstance(exception, RequestAbortException): if isinstance(exception, RequestAbortException):
response = exception.response response = exception.response
@@ -989,30 +986,27 @@ class LLMRequest:
# 复制一份参数,避免直接修改 self.params # 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params) params_copy = await self._transform_parameters(self.params)
if image_base64: if image_base64:
payload = { messages = [
"model": self.model_name, {
"messages": [ "role": "user",
{ "content": [
"role": "user", {"type": "text", "text": prompt},
"content": [ {
{"type": "text", "text": prompt}, "type": "image_url",
{ "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
"type": "image_url", },
"image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}, ],
}, }
], ]
}
],
"max_tokens": global_config.max_response_length,
**params_copy,
}
else: else:
payload = { messages = [{"role": "user", "content": prompt}]
"model": self.model_name, payload = {
"messages": [{"role": "user", "content": prompt}], "model": self.model_name,
"max_tokens": global_config.max_response_length, "messages": messages,
**params_copy, **params_copy,
} }
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
payload["max_tokens"] = global_config.max_response_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens") payload["max_completion_tokens"] = payload.pop("max_tokens")
@@ -1105,11 +1099,10 @@ class LLMRequest:
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
"""异步方式根据输入的提示生成模型的响应""" """异步方式根据输入的提示生成模型的响应"""
# 构建请求体 # 构建请求体不硬编码max_tokens
data = { data = {
"model": self.model_name, "model": self.model_name,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
**self.params, **self.params,
**kwargs, **kwargs,
} }