diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 5e53eec4b..cc25d0646 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,5 +1,6 @@ -from typing import List, Dict, Any, Literal +from typing import List, Dict, Any, Literal, Union from pydantic import Field, field_validator +from threading import Lock from src.config.config_base import ValidatedConfigBase @@ -9,7 +10,7 @@ class APIProvider(ValidatedConfigBase): name: str = Field(..., min_length=1, description="API提供商名称") base_url: str = Field(..., description="API基础URL") - api_key: str = Field(..., min_length=1, description="API密钥") + api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询") client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field( default="openai", description="客户端类型(如openai/google等,默认为openai)" ) @@ -33,12 +34,33 @@ class APIProvider(ValidatedConfigBase): @classmethod def validate_api_key(cls, v): """验证API密钥不能为空""" - if not v or not v.strip(): - raise ValueError("API密钥不能为空") + if isinstance(v, str): + if not v.strip(): + raise ValueError("API密钥不能为空") + elif isinstance(v, list): + if not v: + raise ValueError("API密钥列表不能为空") + for key in v: + if not isinstance(key, str) or not key.strip(): + raise ValueError("API密钥列表中的密钥不能为空") + else: + raise ValueError("API密钥必须是字符串或字符串列表") return v + def __init__(self, **data): + super().__init__(**data) + self._api_key_lock = Lock() + self._api_key_index = 0 + def get_api_key(self) -> str: - return self.api_key + with self._api_key_lock: + if isinstance(self.api_key, str): + return self.api_key + if not self.api_key: + raise ValueError("API密钥列表为空") + key = self.api_key[self._api_key_index] + self._api_key_index = (self._api_key_index + 1) % len(self.api_key) + return key class ModelInfo(ValidatedConfigBase): diff --git a/template/model_config_template.toml b/template/model_config_template.toml index fab3ee509..c5f2a2947 100644 --- a/template/model_config_template.toml +++ b/template/model_config_template.toml @@ -6,7 +6,7 @@ version = "1.3.1" [[api_providers]] # API服务提供商(可以配置多个) name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) base_url = "https://api.deepseek.com/v1" # API服务商的BaseURL -api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥) +api_key = ["your-api-key-here-1", "your-api-key-here-2"] # API密钥(支持单个密钥或密钥列表轮询) client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) timeout = 30 # API请求超时时间(单位:秒) @@ -24,7 +24,7 @@ retry_interval = 10 [[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"aiohttp_gemini" name = "Google" base_url = "https://api.google.com/v1" -api_key = "your-google-api-key-1" +api_key = ["your-google-api-key-1", "your-google-api-key-2"] client_type = "aiohttp_gemini" # 官方的gemini客户端现在已经死了 max_retry = 2 timeout = 30