diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index ff8359738..5f3398e0e 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -31,6 +31,15 @@ class APIProvider(ConfigBase): def get_api_key(self) -> str: return self.api_key + def __post_init__(self): + """确保api_key在repr中不被显示""" + if not self.api_key: + raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。") + if not self.base_url: + raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。") + if not self.name: + raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。") + @dataclass class ModelInfo(ConfigBase): @@ -57,6 +66,14 @@ class ModelInfo(ConfigBase): extra_params: dict = field(default_factory=dict) """额外参数(用于API调用时的额外配置)""" + def __post_init__(self): + if not self.model_identifier: + raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。") + if not self.name: + raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。") + if not self.api_provider: + raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。") + @dataclass class TaskConfig(ConfigBase): diff --git a/src/config/config.py b/src/config/config.py index 868739436..1fee71a1e 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -364,6 +364,11 @@ class APIAdapterConfig(ConfigBase): """API提供商列表""" def __post_init__(self): + if not self.models: + raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。") + if not self.api_providers: + raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。") + # 检查API提供商名称是否重复 provider_names = [provider.name for provider in self.api_providers] if len(provider_names) != len(set(provider_names)): @@ -376,7 +381,7 @@ class APIAdapterConfig(ConfigBase): self.api_providers_dict = {provider.name: provider for provider in self.api_providers} self.models_dict = {model.name: model for model in self.models} - + for model in self.models: if not model.model_identifier: raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")