Refactor config system to use Pydantic validation

Refactored configuration classes to inherit from a new ValidatedConfigBase using Pydantic for robust validation and error reporting. Updated api_ada_configs.py, config.py, config_base.py, and official_configs.py to replace dataclasses with Pydantic models, add field validation, and improve error messages. This change enhances configuration reliability and developer feedback for misconfigurations. Also includes minor code cleanups and removal of unused variables in other modules.
This commit is contained in:
雅诺狐
2025-08-19 15:33:43 +08:00
committed by Windpicker-owo
parent 97ece6524c
commit bb4592846c
19 changed files with 717 additions and 1288 deletions

View File

@@ -6,12 +6,12 @@ import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from typing import List, Optional
from pydantic import Field, field_validator
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.config_base import ValidatedConfigBase
from src.config.official_configs import (
DatabaseConfig,
BotConfig,
@@ -342,82 +342,90 @@ def update_model_config():
_update_config_generic("model_config", "model_config_template")
@dataclass
class Config(ConfigBase):
class Config(ValidatedConfigBase):
"""总配置类"""
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
database: DatabaseConfig
bot: BotConfig
personality: PersonalityConfig
relationship: RelationshipConfig
chat: ChatConfig
message_receive: MessageReceiveConfig
emoji: EmojiConfig
expression: ExpressionConfig
memory: MemoryConfig
mood: MoodConfig
keyword_reaction: KeywordReactionConfig
chinese_typo: ChineseTypoConfig
response_post_process: ResponsePostProcessConfig
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
experimental: ExperimentalConfig
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
debug: DebugConfig
custom_prompt: CustomPromptConfig
voice: VoiceConfig
schedule: ScheduleConfig
database: DatabaseConfig = Field(..., description="数据库配置")
bot: BotConfig = Field(..., description="机器人基本配置")
personality: PersonalityConfig = Field(..., description="个性配置")
relationship: RelationshipConfig = Field(..., description="关系配置")
chat: ChatConfig = Field(..., description="聊天配置")
message_receive: MessageReceiveConfig = Field(..., description="消息接收配置")
normal_chat: NormalChatConfig = Field(..., description="普通聊天配置")
emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置")
memory: MemoryConfig = Field(..., description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置")
keyword_reaction: KeywordReactionConfig = Field(..., description="关键词反应配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
telemetry: TelemetryConfig = Field(..., description="遥测配置")
experimental: ExperimentalConfig = Field(..., description="实验性功能配置")
maim_message: MaimMessageConfig = Field(..., description="Maim消息配置")
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
tool: ToolConfig = Field(..., description="工具配置")
debug: DebugConfig = Field(..., description="调试配置")
custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置")
voice: VoiceConfig = Field(..., description="语音配置")
schedule: ScheduleConfig = Field(..., description="调度配置")
# 有默认值的字段放在后面
anti_prompt_injection: AntiPromptInjectionConfig = field(default_factory=lambda: AntiPromptInjectionConfig())
video_analysis: VideoAnalysisConfig = field(default_factory=lambda: VideoAnalysisConfig())
dependency_management: DependencyManagementConfig = field(default_factory=lambda: DependencyManagementConfig())
exa: ExaConfig = field(default_factory=lambda: ExaConfig())
web_search: WebSearchConfig = field(default_factory=lambda: WebSearchConfig())
tavily: TavilyConfig = field(default_factory=lambda: TavilyConfig())
plugins: PluginsConfig = field(default_factory=lambda: PluginsConfig())
anti_prompt_injection: AntiPromptInjectionConfig = Field(default_factory=lambda: AntiPromptInjectionConfig(), description="反提示注入配置")
video_analysis: VideoAnalysisConfig = Field(default_factory=lambda: VideoAnalysisConfig(), description="视频分析配置")
dependency_management: DependencyManagementConfig = Field(default_factory=lambda: DependencyManagementConfig(), description="依赖管理配置")
exa: ExaConfig = Field(default_factory=lambda: ExaConfig(), description="Exa配置")
web_search: WebSearchConfig = Field(default_factory=lambda: WebSearchConfig(), description="网络搜索配置")
tavily: TavilyConfig = Field(default_factory=lambda: TavilyConfig(), description="Tavily配置")
plugins: PluginsConfig = Field(default_factory=lambda: PluginsConfig(), description="插件配置")
@dataclass
class APIAdapterConfig(ConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo]
"""模型列表"""
model_task_config: ModelTaskConfig
"""模型任务配置"""
api_providers: List[APIProvider] = field(default_factory=list)
"""API提供商列表"""
def __post_init__(self):
if not self.models:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
if not self.api_providers:
raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
# 检查API提供商名称是否重复
provider_names = [provider.name for provider in self.api_providers]
if len(provider_names) != len(set(provider_names)):
raise ValueError("API提供商名称存在重复请检查配置文件。")
# 检查模型名称是否重复
model_names = [model.name for model in self.models]
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
for model in self.models:
@field_validator('models')
@classmethod
def validate_models_list(cls, v):
"""验证模型列表"""
if not v:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
# 检查模型名称是否重复
model_names = [model.name for model in v]
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
# 检查模型标识符是否有效
for model in v:
if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
if not model.api_provider or model.api_provider not in self.api_providers_dict:
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
return v
@field_validator('api_providers')
@classmethod
def validate_api_providers_list(cls, v):
"""验证API提供商列表"""
if not v:
raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
# 检查API提供商名称是否重复
provider_names = [provider.name for provider in v]
if len(provider_names) != len(set(provider_names)):
raise ValueError("API提供商名称存在重复请检查配置文件。")
return v
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
@@ -448,11 +456,14 @@ def load_config(config_path: str) -> Config:
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建Config对象
# 创建Config对象(各个配置类会自动进行 Pydantic 验证)
try:
return Config.from_dict(config_data)
logger.info("正在解析和验证配置文件...")
config = Config.from_dict(config_data)
logger.info("配置文件解析和验证完成")
return config
except Exception as e:
logger.critical("配置文件解析失败")
logger.critical(f"配置文件解析失败: {e}")
raise e
@@ -468,11 +479,14 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建APIAdapterConfig对象
# 创建APIAdapterConfig对象(各个配置类会自动进行 Pydantic 验证)
try:
return APIAdapterConfig.from_dict(config_data)
logger.info("正在解析和验证API适配器配置文件...")
config = APIAdapterConfig.from_dict(config_data)
logger.info("API适配器配置文件解析和验证完成")
return config
except Exception as e:
logger.critical("API适配器配置文件解析失败")
logger.critical(f"API适配器配置文件解析失败: {e}")
raise e