Files
Mofox-Core/src/config/config_base.py
雅诺狐 921d07e30a Enforce strict type validation and update config types
Enabled strict type checking in ValidatedConfigBase to fully disable type coercion. Updated MessageReceiveConfig and MemoryConfig fields from set/tuple to list types for compatibility with strict validation.
2025-08-20 19:27:47 +08:00

234 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import types
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, Union, get_origin, get_args, Literal

from pydantic import BaseModel, ValidationError
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
# 如果列表元素类型是ConfigBase的子类则对每个元素调用from_dict
if (
field_type_args
and isinstance(field_type_args[0], type)
and issubclass(field_type_args[0], ConfigBase)
):
return [field_type_args[0].from_dict(item) for item in value]
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None
# 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal:
# 获取Literal的允许值
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
class ValidatedConfigBase(BaseModel):
    """Configuration base class with validation, built on Pydantic's BaseModel."""

    model_config = {
        "extra": "allow",  # accept fields that the model does not declare
        "validate_assignment": True,  # validate values on attribute assignment too
        "arbitrary_types_allowed": True,  # allow arbitrary (non-Pydantic) field types
        "strict": True,  # when True, type coercion is disabled entirely
    }

    @classmethod
    def from_dict(cls, data: dict):
        """Validate ``data`` into a model instance.

        Mirrors the ``from_dict`` entry point of the dataclass-based config
        classes. A Pydantic ``ValidationError`` is re-raised as a
        ``ValueError`` carrying the message built by
        ``_create_enhanced_error_message``.
        """
        try:
            instance = cls.model_validate(data)
        except ValidationError as e:
            raise ValueError(cls._create_enhanced_error_message(e, data)) from e
        return instance

    @classmethod
    def _create_enhanced_error_message(cls, e: ValidationError, data: dict) -> str:
        """Render every error of ``e`` as one bullet line of a single message.

        Args:
            e: The validation error raised by ``model_validate``.
            data: The raw input that was validated; used to look at the
                actual value of the parent field of a failing list element.

        Returns:
            A multi-line string: a header line followed by one line per error.
        """
        model_type_errors = ('model_type', 'dict_type', 'is_instance_of')
        report = []

        for err in e.errors():
            kind = err.get('type', '')
            loc = err.get('loc', ())
            bad_value = err.get('input')
            dotted_path = '.'.join(str(part) for part in loc)

            if kind == 'missing':
                # A required field was not supplied at all.
                report.append(f"缺少必需字段: '{dotted_path}'")
                continue

            if kind == 'string_type' and len(loc) >= 2:
                # A string was expected somewhere below a top-level field.
                owner, position = loc[0], loc[1]
                owner_info = cls.model_fields.get(owner)
                if owner_info and hasattr(owner_info, 'annotation'):
                    annotation = owner_info.annotation
                    owner_value = data.get(owner)
                    if get_origin(annotation) is list and isinstance(owner_value, list):
                        # The offending value is an element of a list field.
                        type_args = get_args(annotation)
                        element_type = type_args[0] if type_args else str
                        found_name = type(bad_value).__name__
                        element_name = getattr(element_type, '__name__', str(element_type))
                        report.append(
                            f"字段 '{dotted_path}' 类型错误: "
                            f"期待类型 List[{element_name}]"
                            f"但列表中第 {position} 个元素类型为 {found_name} (值: {bad_value})"
                        )
                    else:
                        # Any other nested location that should hold a string.
                        found_name = type(bad_value).__name__
                        report.append(
                            f"字段 '{dotted_path}' 类型错误: "
                            f"期待字符串类型,实际类型 {found_name} (值: {bad_value})"
                        )
                    continue
            elif kind in model_type_errors:
                # The value is not of the model / dict / instance type expected.
                top_name = loc[0] if loc else 'unknown'
                top_info = cls.model_fields.get(top_name)
                if top_info and hasattr(top_info, 'annotation'):
                    wanted = top_info.annotation
                    wanted_name = getattr(wanted, '__name__', str(wanted))
                    found_name = type(bad_value).__name__
                    report.append(
                        f"字段 '{top_name}' 类型错误: "
                        f"期待类型 {wanted_name},实际类型 {found_name} (值: {bad_value})"
                    )
                    continue

            # Fallback: no field info was found, or the error kind has no
            # special handling -- pass Pydantic's own message through.
            report.append(f"字段 '{dotted_path}': {err.get('msg', str(err))}")

        bullet_lines = [f" - {msg}" for msg in report]
        return "配置验证失败:\n" + "\n".join(bullet_lines)