Refactor config system to use Pydantic validation
Refactored configuration classes to inherit from a new ValidatedConfigBase using Pydantic for robust validation and error reporting. Updated api_ada_configs.py, config.py, config_base.py, and official_configs.py to replace dataclasses with Pydantic models, add field validation, and improve error messages. This change enhances configuration reliability and developer feedback for misconfigurations. Also includes minor code cleanups and removal of unused variables in other modules.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
@@ -133,3 +134,99 @@ class ConfigBase:
|
||||
def __str__(self):
|
||||
"""返回配置类的字符串表示"""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
|
||||
class ValidatedConfigBase(BaseModel):
|
||||
"""带验证的配置基类,继承自Pydantic BaseModel"""
|
||||
|
||||
model_config = {
|
||||
"extra": "allow", # 允许额外字段
|
||||
"validate_assignment": True, # 验证赋值
|
||||
"arbitrary_types_allowed": True, # 允许任意类型
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
"""兼容原有的from_dict方法,增强错误信息"""
|
||||
try:
|
||||
return cls.model_validate(data)
|
||||
except ValidationError as e:
|
||||
enhanced_message = cls._create_enhanced_error_message(e, data)
|
||||
|
||||
raise ValueError(enhanced_message) from e
|
||||
|
||||
@classmethod
|
||||
def _create_enhanced_error_message(cls, e: ValidationError, data: dict) -> str:
|
||||
"""创建增强的错误信息"""
|
||||
enhanced_messages = []
|
||||
|
||||
for error in e.errors():
|
||||
error_type = error.get('type', '')
|
||||
field_path = error.get('loc', ())
|
||||
input_value = error.get('input')
|
||||
|
||||
# 构建字段路径字符串
|
||||
field_path_str = '.'.join(str(p) for p in field_path)
|
||||
|
||||
# 处理字符串类型错误
|
||||
if error_type == 'string_type' and len(field_path) >= 2:
|
||||
parent_field = field_path[0]
|
||||
element_index = field_path[1]
|
||||
|
||||
# 尝试获取父字段的类型信息
|
||||
parent_field_info = cls.model_fields.get(parent_field)
|
||||
|
||||
if parent_field_info and hasattr(parent_field_info, 'annotation'):
|
||||
expected_type = parent_field_info.annotation
|
||||
|
||||
# 获取实际的父字段值
|
||||
actual_parent_value = data.get(parent_field)
|
||||
|
||||
# 检查是否是列表类型错误
|
||||
if get_origin(expected_type) is list and isinstance(actual_parent_value, list):
|
||||
list_element_type = get_args(expected_type)[0] if get_args(expected_type) else str
|
||||
actual_item_type = type(input_value).__name__
|
||||
expected_element_name = getattr(list_element_type, '__name__', str(list_element_type))
|
||||
|
||||
enhanced_messages.append(
|
||||
f"字段 '{field_path_str}' 类型错误: "
|
||||
f"期待类型 List[{expected_element_name}],"
|
||||
f"但列表中第 {element_index} 个元素类型为 {actual_item_type} (值: {input_value})"
|
||||
)
|
||||
else:
|
||||
# 其他嵌套字段错误
|
||||
actual_name = type(input_value).__name__
|
||||
enhanced_messages.append(
|
||||
f"字段 '{field_path_str}' 类型错误: "
|
||||
f"期待字符串类型,实际类型 {actual_name} (值: {input_value})"
|
||||
)
|
||||
else:
|
||||
# 回退到原始错误信息
|
||||
enhanced_messages.append(f"字段 '{field_path_str}': {error.get('msg', str(error))}")
|
||||
|
||||
# 处理缺失字段错误
|
||||
elif error_type == 'missing':
|
||||
enhanced_messages.append(f"缺少必需字段: '{field_path_str}'")
|
||||
|
||||
# 处理模型类型错误
|
||||
elif error_type in ['model_type', 'dict_type', 'is_instance_of']:
|
||||
field_name = field_path[0] if field_path else 'unknown'
|
||||
field_info = cls.model_fields.get(field_name)
|
||||
|
||||
if field_info and hasattr(field_info, 'annotation'):
|
||||
expected_type = field_info.annotation
|
||||
expected_name = getattr(expected_type, '__name__', str(expected_type))
|
||||
actual_name = type(input_value).__name__
|
||||
|
||||
enhanced_messages.append(
|
||||
f"字段 '{field_name}' 类型错误: "
|
||||
f"期待类型 {expected_name},实际类型 {actual_name} (值: {input_value})"
|
||||
)
|
||||
else:
|
||||
enhanced_messages.append(f"字段 '{field_path_str}': {error.get('msg', str(error))}")
|
||||
|
||||
# 处理其他类型错误
|
||||
else:
|
||||
enhanced_messages.append(f"字段 '{field_path_str}': {error.get('msg', str(error))}")
|
||||
|
||||
return "配置验证失败:\n" + "\n".join(f" - {msg}" for msg in enhanced_messages)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user