修复代码格式和文件名大小写问题
This commit is contained in:
@@ -10,22 +10,26 @@ class APIProvider(ValidatedConfigBase):
|
||||
name: str = Field(..., min_length=1, description="API提供商名称")
|
||||
base_url: str = Field(..., description="API基础URL")
|
||||
api_key: str = Field(..., min_length=1, description="API密钥")
|
||||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(default="openai", description="客户端类型(如openai/google等,默认为openai)")
|
||||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
|
||||
default="openai", description="客户端类型(如openai/google等,默认为openai)"
|
||||
)
|
||||
max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)")
|
||||
timeout: int = Field(default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)")
|
||||
timeout: int = Field(
|
||||
default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)"
|
||||
)
|
||||
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
||||
enable_content_obfuscation: bool = Field(default=False, description="是否启用内容混淆(用于特定场景下的内容处理)")
|
||||
obfuscation_intensity: int = Field(default=1, ge=1, le=3, description="混淆强度(1-3级,数值越高混淆程度越强)")
|
||||
|
||||
@field_validator('base_url')
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v):
|
||||
"""验证base_url,确保URL格式正确"""
|
||||
if v and not (v.startswith('http://') or v.startswith('https://')):
|
||||
if v and not (v.startswith("http://") or v.startswith("https://")):
|
||||
raise ValueError("base_url必须以http://或https://开头")
|
||||
return v
|
||||
|
||||
@field_validator('api_key')
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def validate_api_key(cls, v):
|
||||
"""验证API密钥不能为空"""
|
||||
@@ -49,7 +53,7 @@ class ModelInfo(ValidatedConfigBase):
|
||||
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)")
|
||||
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
|
||||
|
||||
@field_validator('price_in', 'price_out')
|
||||
@field_validator("price_in", "price_out")
|
||||
@classmethod
|
||||
def validate_prices(cls, v):
|
||||
"""验证价格必须为非负数"""
|
||||
@@ -57,18 +61,18 @@ class ModelInfo(ValidatedConfigBase):
|
||||
raise ValueError("价格不能为负数")
|
||||
return v
|
||||
|
||||
@field_validator('model_identifier')
|
||||
@field_validator("model_identifier")
|
||||
@classmethod
|
||||
def validate_model_identifier(cls, v):
|
||||
"""验证模型标识符不能为空且不能包含特殊字符"""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("模型标识符不能为空")
|
||||
# 检查是否包含危险字符
|
||||
if any(char in v for char in [' ', '\n', '\t', '\r']):
|
||||
if any(char in v for char in [" ", "\n", "\t", "\r"]):
|
||||
raise ValueError("模型标识符不能包含空格或换行符")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
"""验证模型名称不能为空"""
|
||||
@@ -85,7 +89,7 @@ class TaskConfig(ValidatedConfigBase):
|
||||
temperature: float = Field(default=0.7, description="模型温度")
|
||||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||||
|
||||
@field_validator('model_list')
|
||||
@field_validator("model_list")
|
||||
@classmethod
|
||||
def validate_model_list(cls, v):
|
||||
"""验证模型列表不能为空"""
|
||||
@@ -118,7 +122,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
||||
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
||||
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
||||
|
||||
|
||||
# 处理配置文件中命名不一致的问题
|
||||
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
||||
|
||||
@@ -132,7 +136,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
# 处理向后兼容性:如果请求video_analysis,返回utils_video
|
||||
if task_name == "video_analysis":
|
||||
task_name = "utils_video"
|
||||
|
||||
|
||||
if hasattr(self, task_name):
|
||||
config = getattr(self, task_name)
|
||||
if config is None:
|
||||
@@ -153,37 +157,37 @@ class APIAdapterConfig(ValidatedConfigBase):
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
|
||||
@field_validator('models')
|
||||
@field_validator("models")
|
||||
@classmethod
|
||||
def validate_models_list(cls, v):
|
||||
"""验证模型列表"""
|
||||
if not v:
|
||||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||||
|
||||
|
||||
# 检查模型名称是否重复
|
||||
model_names = [model.name for model in v]
|
||||
if len(model_names) != len(set(model_names)):
|
||||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||||
|
||||
|
||||
# 检查模型标识符是否有效
|
||||
for model in v:
|
||||
if not model.model_identifier:
|
||||
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
|
||||
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('api_providers')
|
||||
@field_validator("api_providers")
|
||||
@classmethod
|
||||
def validate_api_providers_list(cls, v):
|
||||
"""验证API提供商列表"""
|
||||
if not v:
|
||||
raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
|
||||
|
||||
|
||||
# 检查API提供商名称是否重复
|
||||
provider_names = [provider.name for provider in v]
|
||||
if len(provider_names) != len(set(provider_names)):
|
||||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||||
|
||||
|
||||
return v
|
||||
|
||||
def get_model_info(self, model_name: str) -> ModelInfo:
|
||||
|
||||
Reference in New Issue
Block a user