137 lines
5.5 KiB
Python
137 lines
5.5 KiB
Python
from dataclasses import dataclass, fields, MISSING
|
||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
|
||
|
||
T = TypeVar("T", bound="ConfigBase")
|
||
|
||
TOML_DICT_TYPE = {
|
||
int,
|
||
float,
|
||
str,
|
||
bool,
|
||
list,
|
||
dict,
|
||
}
|
||
|
||
|
||
@dataclass
|
||
class ConfigBase:
|
||
"""配置类的基类"""
|
||
|
||
@classmethod
|
||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||
"""从字典加载配置字段"""
|
||
if not isinstance(data, dict):
|
||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||
|
||
init_args: Dict[str, Any] = {}
|
||
|
||
for f in fields(cls):
|
||
field_name = f.name
|
||
field_type = f.type
|
||
if field_name.startswith("_"):
|
||
# 跳过以 _ 开头的字段
|
||
continue
|
||
|
||
if field_name not in data:
|
||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||
# 跳过未提供且有默认值/默认构造方法的字段
|
||
continue
|
||
else:
|
||
raise ValueError(f"Missing required field: '{field_name}'")
|
||
|
||
value = data[field_name]
|
||
try:
|
||
init_args[field_name] = cls._convert_field(value, field_type)
|
||
except TypeError as e:
|
||
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
|
||
except Exception as e:
|
||
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
|
||
|
||
return cls(**init_args)
|
||
|
||
@classmethod
|
||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||
"""
|
||
转换字段值为指定类型
|
||
|
||
1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
|
||
2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
|
||
3. 对于基础类型(int, str, float, bool),直接转换
|
||
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
|
||
"""
|
||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
|
||
return field_type.from_dict(value)
|
||
|
||
field_origin_type = get_origin(field_type)
|
||
field_args_type = get_args(field_type)
|
||
|
||
# 处理泛型集合类型(list, set, tuple)
|
||
if field_origin_type in {list, set, tuple}:
|
||
# 检查提供的value是否为list
|
||
if not isinstance(value, list):
|
||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||
|
||
if field_origin_type is list:
|
||
return [cls._convert_field(item, field_args_type[0]) for item in value]
|
||
if field_origin_type is set:
|
||
return {cls._convert_field(item, field_args_type[0]) for item in value}
|
||
if field_origin_type is tuple:
|
||
# 检查提供的value长度是否与类型参数一致
|
||
if len(value) != len(field_args_type):
|
||
raise TypeError(
|
||
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
|
||
)
|
||
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
|
||
|
||
if field_origin_type is dict:
|
||
# 检查提供的value是否为dict
|
||
if not isinstance(value, dict):
|
||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||
|
||
# 检查字典的键值类型
|
||
if len(field_args_type) != 2:
|
||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||
key_type, value_type = field_args_type
|
||
|
||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||
|
||
# 处理Optional类型
|
||
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
|
||
if value is None:
|
||
return None
|
||
# 如果有数据,检查实际类型
|
||
if type(value) not in field_args_type:
|
||
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
|
||
return cls._convert_field(value, field_args_type[0])
|
||
|
||
# 处理int, str, float, bool等基础类型
|
||
if field_origin_type is None:
|
||
if isinstance(value, field_type):
|
||
return field_type(value)
|
||
else:
|
||
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
|
||
|
||
# 处理Literal类型
|
||
if field_origin_type is Literal:
|
||
# 获取Literal的允许值
|
||
allowed_values = get_args(field_type)
|
||
if value in allowed_values:
|
||
return value
|
||
else:
|
||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||
|
||
# 处理其他类型
|
||
if field_type is Any:
|
||
return value
|
||
|
||
# 其他类型直接转换
|
||
try:
|
||
return field_type(value)
|
||
except (ValueError, TypeError) as e:
|
||
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
|
||
|
||
def __str__(self):
|
||
"""返回配置类的字符串表示"""
|
||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|