增加一些校验

This commit is contained in:
UnCLAS-Prommer
2025-07-31 22:32:02 +08:00
parent 303931e680
commit 9c818b78a2
2 changed files with 23 additions and 1 deletions

View File

@@ -31,6 +31,15 @@ class APIProvider(ConfigBase):
def get_api_key(self) -> str: def get_api_key(self) -> str:
return self.api_key return self.api_key
def __post_init__(self):
"""确保api_key在repr中不被显示"""
if not self.api_key:
raise ValueError("API密钥不能为空请在配置中设置有效的API密钥。")
if not self.base_url:
raise ValueError("API基础URL不能为空请在配置中设置有效的基础URL。")
if not self.name:
raise ValueError("API提供商名称不能为空请在配置中设置有效的名称。")
@dataclass @dataclass
class ModelInfo(ConfigBase): class ModelInfo(ConfigBase):
@@ -57,6 +66,14 @@ class ModelInfo(ConfigBase):
extra_params: dict = field(default_factory=dict) extra_params: dict = field(default_factory=dict)
"""额外参数用于API调用时的额外配置""" """额外参数用于API调用时的额外配置"""
def __post_init__(self):
if not self.model_identifier:
raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
if not self.name:
raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
if not self.api_provider:
raise ValueError("API提供商不能为空请在配置中设置有效的API提供商。")
@dataclass @dataclass
class TaskConfig(ConfigBase): class TaskConfig(ConfigBase):

View File

@@ -364,6 +364,11 @@ class APIAdapterConfig(ConfigBase):
"""API提供商列表""" """API提供商列表"""
def __post_init__(self): def __post_init__(self):
if not self.models:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
if not self.api_providers:
raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
# 检查API提供商名称是否重复 # 检查API提供商名称是否重复
provider_names = [provider.name for provider in self.api_providers] provider_names = [provider.name for provider in self.api_providers]
if len(provider_names) != len(set(provider_names)): if len(provider_names) != len(set(provider_names)):
@@ -376,7 +381,7 @@ class APIAdapterConfig(ConfigBase):
self.api_providers_dict = {provider.name: provider for provider in self.api_providers} self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models} self.models_dict = {model.name: model for model in self.models}
for model in self.models: for model in self.models:
if not model.model_identifier: if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空") raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")