Use Literal types for config field validation

Replaced manual string validation with Python's Literal type for 'client_type' in APIProvider and 'search_strategy' in WebSearchConfig. This simplifies validation and improves type safety by restricting allowed values at the type level.
This commit is contained in:
雅诺狐
2025-08-20 19:38:37 +08:00
parent 921d07e30a
commit f959ca6bb2
2 changed files with 3 additions and 12 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Dict, Any from typing import List, Dict, Any, Literal
from pydantic import Field, field_validator from pydantic import Field, field_validator
from src.config.config_base import ValidatedConfigBase from src.config.config_base import ValidatedConfigBase
@@ -10,7 +10,7 @@ class APIProvider(ValidatedConfigBase):
name: str = Field(..., min_length=1, description="API提供商名称") name: str = Field(..., min_length=1, description="API提供商名称")
base_url: str = Field(..., description="API基础URL") base_url: str = Field(..., description="API基础URL")
api_key: str = Field(..., min_length=1, description="API密钥") api_key: str = Field(..., min_length=1, description="API密钥")
client_type: str = Field(default="openai", description="客户端类型如openai/google等默认为openai") client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(default="openai", description="客户端类型如openai/google等默认为openai")
max_retry: int = Field(default=2, ge=0, description="最大重试次数单个模型API调用失败最多重试的次数") max_retry: int = Field(default=2, ge=0, description="最大重试次数单个模型API调用失败最多重试的次数")
timeout: int = Field(default=10, ge=1, description="API调用的超时时长超过这个时长本次请求将被视为'请求超时',单位:秒)") timeout: int = Field(default=10, ge=1, description="API调用的超时时长超过这个时长本次请求将被视为'请求超时',单位:秒)")
retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位") retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位")
@@ -33,15 +33,6 @@ class APIProvider(ValidatedConfigBase):
raise ValueError("API密钥不能为空") raise ValueError("API密钥不能为空")
return v return v
@field_validator('client_type')
@classmethod
def validate_client_type(cls, v):
"""验证客户端类型"""
allowed_types = ["openai", "gemini","aiohttp_gemini"]
if v not in allowed_types:
raise ValueError(f"客户端类型必须是以下之一: {allowed_types}")
return v
def get_api_key(self) -> str: def get_api_key(self) -> str:
return self.api_key return self.api_key

View File

@@ -618,7 +618,7 @@ class WebSearchConfig(ValidatedConfigBase):
enable_web_search_tool: bool = Field(default=True, description="启用网络搜索工具") enable_web_search_tool: bool = Field(default=True, description="启用网络搜索工具")
enable_url_tool: bool = Field(default=True, description="启用URL工具") enable_url_tool: bool = Field(default=True, description="启用URL工具")
enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎") enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎")
search_strategy: str = Field(default="single", description="搜索策略") search_strategy: Literal["fallback","single","parallel"] = Field(default="single", description="搜索策略")
class AntiPromptInjectionConfig(ValidatedConfigBase): class AntiPromptInjectionConfig(ValidatedConfigBase):