Refactor config system to use Pydantic validation

Refactored configuration classes to inherit from a new ValidatedConfigBase using Pydantic for robust validation and error reporting. Updated api_ada_configs.py, config.py, config_base.py, and official_configs.py to replace dataclasses with Pydantic models, add field validation, and improve error messages. This change enhances configuration reliability and developer feedback for misconfigurations. Also includes minor code cleanups and removal of unused variables in other modules.
This commit is contained in:
雅诺狐
2025-08-19 15:33:43 +08:00
committed by Windpicker-owo
parent 97ece6524c
commit bb4592846c
19 changed files with 717 additions and 1288 deletions

View File

@@ -1,174 +1,268 @@
from dataclasses import dataclass, field
from typing import List, Dict, Any
from pydantic import Field, field_validator
from .config_base import ConfigBase
from src.config.config_base import ValidatedConfigBase
@dataclass
class APIProvider(ConfigBase):
class APIProvider(ValidatedConfigBase):
"""API提供商配置类"""
name: str
"""API提供商名称"""
name: str = Field(..., min_length=1, description="API提供商名称")
base_url: str = Field(..., description="API基础URL")
api_key: str = Field(..., min_length=1, description="API密钥")
client_type: str = Field(default="openai", description="客户端类型如openai/google等默认为openai")
max_retry: int = Field(default=2, ge=0, description="最大重试次数单个模型API调用失败最多重试的次数")
timeout: int = Field(default=10, ge=1, description="API调用的超时时长超过这个时长本次请求将被视为'请求超时',单位:秒)")
retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位")
enable_content_obfuscation: bool = Field(default=False, description="是否启用内容混淆(用于特定场景下的内容处理)")
obfuscation_intensity: int = Field(default=1, ge=1, le=3, description="混淆强度1-3级数值越高混淆程度越强")
base_url: str
"""API基础URL"""
@field_validator('base_url')
@classmethod
def validate_base_url(cls, v):
"""验证base_url确保URL格式正确"""
if v and not (v.startswith('http://') or v.startswith('https://')):
raise ValueError("base_url必须以http://或https://开头")
return v
api_key: str = field(default_factory=str, repr=False)
"""API密钥列表"""
@field_validator('api_key')
@classmethod
def validate_api_key(cls, v):
"""验证API密钥不能为空"""
if not v or not v.strip():
raise ValueError("API密钥不能为空")
return v
client_type: str = field(default="openai")
"""客户端类型如openai/google等默认为openai"""
max_retry: int = 2
"""最大重试次数单个模型API调用失败最多重试的次数"""
timeout: int = 10
"""API调用的超时时长超过这个时长本次请求将被视为"请求超时",单位:秒)"""
retry_interval: int = 10
"""重试间隔如果API调用失败重试的间隔时间单位"""
enable_content_obfuscation: bool = field(default=False)
"""是否启用内容混淆(用于特定场景下的内容处理)"""
obfuscation_intensity: int = field(default=1)
"""混淆强度1-3级数值越高混淆程度越强"""
@field_validator('client_type')
@classmethod
def validate_client_type(cls, v):
"""验证客户端类型"""
allowed_types = ["openai", "gemini"]
if v not in allowed_types:
raise ValueError(f"客户端类型必须是以下之一: {allowed_types}")
return v
def get_api_key(self) -> str:
return self.api_key
def __post_init__(self):
"""确保api_key在repr中不被显示"""
if not self.api_key:
raise ValueError("API密钥不能为空请在配置中设置有效的API密钥。")
if not self.base_url and self.client_type != "gemini":
raise ValueError("API基础URL不能为空请在配置中设置有效的基础URL。")
if not self.name:
raise ValueError("API提供商名称不能为空请在配置中设置有效的名称。")
@dataclass
class ModelInfo(ConfigBase):
class ModelInfo(ValidatedConfigBase):
"""单个模型信息配置类"""
model_identifier: str
"""模型标识符用于URL调用)"""
model_identifier: str = Field(..., min_length=1, description="模型标识符用于URL调用")
name: str = Field(..., min_length=1, description="模型名称(用于模块调用)")
api_provider: str = Field(..., min_length=1, description="API提供商如OpenAI、Azure等")
price_in: float = Field(default=0.0, ge=0, description="每M token输入价格")
price_out: float = Field(default=0.0, ge=0, description="每M token输出价格")
force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式")
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数用于API调用时的额外配置")
name: str
"""模型名称(用于模块调用)"""
@field_validator('price_in', 'price_out')
@classmethod
def validate_prices(cls, v):
"""验证价格必须为非负数"""
if v < 0:
raise ValueError("价格不能为负数")
return v
api_provider: str
"""API提供商如OpenAI、Azure等"""
@field_validator('model_identifier')
@classmethod
def validate_model_identifier(cls, v):
"""验证模型标识符不能为空且不能包含特殊字符"""
if not v or not v.strip():
raise ValueError("模型标识符不能为空")
# 检查是否包含危险字符
if any(char in v for char in [' ', '\n', '\t', '\r']):
raise ValueError("模型标识符不能包含空格或换行符")
return v
price_in: float = field(default=0.0)
"""每M token输入价格"""
price_out: float = field(default=0.0)
"""每M token输出价格"""
force_stream_mode: bool = field(default=False)
"""是否强制使用流式输出模式"""
extra_params: dict = field(default_factory=dict)
"""额外参数用于API调用时的额外配置"""
def __post_init__(self):
if not self.model_identifier:
raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
if not self.name:
raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
if not self.api_provider:
raise ValueError("API提供商不能为空请在配置中设置有效的API提供商。")
@field_validator('name')
@classmethod
def validate_name(cls, v):
"""验证模型名称不能为空"""
if not v or not v.strip():
raise ValueError("模型名称不能为空")
return v
@dataclass
class TaskConfig(ConfigBase):
class TaskConfig(ValidatedConfigBase):
"""任务配置类"""
model_list: list[str] = field(default_factory=list)
"""任务使用的模型列表"""
model_list: List[str] = Field(default_factory=list, description="任务使用的模型列表")
max_tokens: int = Field(default=1024, ge=1, le=100000, description="任务最大输出token数")
temperature: float = Field(default=0.3, ge=0.0, le=2.0, description="模型温度")
concurrency_count: int = Field(default=1, ge=1, le=10, description="并发请求数量默认为1不并发")
max_tokens: int = 1024
"""任务最大输出token数"""
@field_validator('model_list')
@classmethod
def validate_model_list(cls, v):
"""验证模型列表不能为空"""
if not v:
raise ValueError("模型列表不能为空")
if len(v) != len(set(v)):
raise ValueError("模型列表中不能有重复的模型")
return v
temperature: float = 0.3
"""模型温度"""
concurrency_count: int = 1
"""并发请求数量默认为1不并发"""
@field_validator('max_tokens')
@classmethod
def validate_max_tokens(cls, v):
"""验证最大token数"""
if v <= 0:
raise ValueError("最大token数必须大于0")
if v > 100000:
raise ValueError("最大token数不能超过100000")
return v
@dataclass
class ModelTaskConfig(ConfigBase):
class ModelTaskConfig(ValidatedConfigBase):
"""模型配置类"""
utils: TaskConfig
"""组件模型配置"""
utils: TaskConfig = Field(..., description="组件模型配置")
# 可选配置项(有默认值)
utils_small: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen3-8b"],
max_tokens=800,
temperature=0.7
),
description="组件小模型配置"
)
replyer_1: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.2
),
description="normal_chat首要回复模型模型配置"
)
replyer_2: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.7
),
description="normal_chat次要回复模型配置"
)
maizone: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.3
),
description="maizone专用模型"
)
emotion: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.7
),
description="情绪模型配置"
)
vlm: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=1500,
temperature=0.3
),
description="视觉语言模型配置"
)
voice: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.3
),
description="语音识别模型配置"
)
tool_use: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.1
),
description="专注工具使用模型配置"
)
planner: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=800,
temperature=0.3
),
description="规划模型配置"
)
embedding: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["text-embedding-3-large"],
max_tokens=1024,
temperature=0.0
),
description="嵌入模型配置"
)
lpmm_entity_extract: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=2000,
temperature=0.1
),
description="LPMM实体提取模型配置"
)
lpmm_rdf_build: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=2000,
temperature=0.1
),
description="LPMM RDF构建模型配置"
)
lpmm_qa: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=2000,
temperature=0.3
),
description="LPMM问答模型配置"
)
schedule_generator: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["siliconflow-deepseek-v3"],
max_tokens=1500,
temperature=0.3
),
description="日程生成模型配置"
)
utils_small: TaskConfig
"""组件小模型配置"""
replyer: TaskConfig
"""normal_chat首要回复模型模型配置"""
replyer_2: TaskConfig
"""normal_chat次要回复模型配置"""
maizone : TaskConfig
"""maizone专用模型"""
emotion: TaskConfig
"""情绪模型配置"""
vlm: TaskConfig
"""视觉语言模型配置"""
voice: TaskConfig
"""语音识别模型配置"""
tool_use: TaskConfig
"""专注工具使用模型配置"""
planner: TaskConfig
"""规划模型配置"""
embedding: TaskConfig
"""嵌入模型配置"""
lpmm_entity_extract: TaskConfig
"""LPMM实体提取模型配置"""
lpmm_rdf_build: TaskConfig
"""LPMM RDF构建模型配置"""
lpmm_qa: TaskConfig
"""LPMM问答模型配置"""
schedule_generator: TaskConfig
"""日程生成模型配置"""
video_analysis: TaskConfig = field(default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=1500,
temperature=0.3
))
"""视频分析模型配置"""
emoji_vlm: TaskConfig = field(default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=800
))
"""表情包识别模型配置"""
anti_injection: TaskConfig = field(default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=200,
temperature=0.1
))
"""反注入检测专用模型配置"""
# 可选配置项(有默认值)
video_analysis: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=1500,
temperature=0.3
),
description="视频分析模型配置"
)
emoji_vlm: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=800
),
description="表情包识别模型配置"
)
anti_injection: TaskConfig = Field(
default_factory=lambda: TaskConfig(
model_list=["qwen2.5-vl-72b"],
max_tokens=200,
temperature=0.1
),
description="反注入检测专用模型配置"
)
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
if hasattr(self, task_name):
return getattr(self, task_name)
config = getattr(self, task_name)
if config is None:
raise ValueError(f"任务 '{task_name}' 未配置")
return config
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")