Refactor config system to use Pydantic validation
Refactored configuration classes to inherit from a new ValidatedConfigBase using Pydantic for robust validation and error reporting. Updated api_ada_configs.py, config.py, config_base.py, and official_configs.py to replace dataclasses with Pydantic models, add field validation, and improve error messages. This change enhances configuration reliability and developer feedback for misconfigurations. Also includes minor code cleanups and removal of unused variables in other modules.
This commit is contained in:
@@ -6,12 +6,12 @@ import sys
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table, KeyType
|
||||
from dataclasses import field, dataclass
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
from src.config.official_configs import (
|
||||
DatabaseConfig,
|
||||
BotConfig,
|
||||
@@ -329,83 +329,90 @@ def update_model_config():
|
||||
_update_config_generic("model_config", "model_config_template")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
class Config(ValidatedConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
|
||||
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
|
||||
|
||||
database: DatabaseConfig
|
||||
bot: BotConfig
|
||||
personality: PersonalityConfig
|
||||
relationship: RelationshipConfig
|
||||
chat: ChatConfig
|
||||
message_receive: MessageReceiveConfig
|
||||
normal_chat: NormalChatConfig
|
||||
emoji: EmojiConfig
|
||||
expression: ExpressionConfig
|
||||
memory: MemoryConfig
|
||||
mood: MoodConfig
|
||||
keyword_reaction: KeywordReactionConfig
|
||||
chinese_typo: ChineseTypoConfig
|
||||
response_post_process: ResponsePostProcessConfig
|
||||
response_splitter: ResponseSplitterConfig
|
||||
telemetry: TelemetryConfig
|
||||
experimental: ExperimentalConfig
|
||||
maim_message: MaimMessageConfig
|
||||
lpmm_knowledge: LPMMKnowledgeConfig
|
||||
tool: ToolConfig
|
||||
debug: DebugConfig
|
||||
custom_prompt: CustomPromptConfig
|
||||
voice: VoiceConfig
|
||||
schedule: ScheduleConfig
|
||||
database: DatabaseConfig = Field(..., description="数据库配置")
|
||||
bot: BotConfig = Field(..., description="机器人基本配置")
|
||||
personality: PersonalityConfig = Field(..., description="个性配置")
|
||||
relationship: RelationshipConfig = Field(..., description="关系配置")
|
||||
chat: ChatConfig = Field(..., description="聊天配置")
|
||||
message_receive: MessageReceiveConfig = Field(..., description="消息接收配置")
|
||||
normal_chat: NormalChatConfig = Field(..., description="普通聊天配置")
|
||||
emoji: EmojiConfig = Field(..., description="表情配置")
|
||||
expression: ExpressionConfig = Field(..., description="表达配置")
|
||||
memory: MemoryConfig = Field(..., description="记忆配置")
|
||||
mood: MoodConfig = Field(..., description="情绪配置")
|
||||
keyword_reaction: KeywordReactionConfig = Field(..., description="关键词反应配置")
|
||||
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
|
||||
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
|
||||
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
|
||||
telemetry: TelemetryConfig = Field(..., description="遥测配置")
|
||||
experimental: ExperimentalConfig = Field(..., description="实验性功能配置")
|
||||
maim_message: MaimMessageConfig = Field(..., description="Maim消息配置")
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
|
||||
tool: ToolConfig = Field(..., description="工具配置")
|
||||
debug: DebugConfig = Field(..., description="调试配置")
|
||||
custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置")
|
||||
voice: VoiceConfig = Field(..., description="语音配置")
|
||||
schedule: ScheduleConfig = Field(..., description="调度配置")
|
||||
|
||||
# 有默认值的字段放在后面
|
||||
anti_prompt_injection: AntiPromptInjectionConfig = field(default_factory=lambda: AntiPromptInjectionConfig())
|
||||
video_analysis: VideoAnalysisConfig = field(default_factory=lambda: VideoAnalysisConfig())
|
||||
dependency_management: DependencyManagementConfig = field(default_factory=lambda: DependencyManagementConfig())
|
||||
exa: ExaConfig = field(default_factory=lambda: ExaConfig())
|
||||
web_search: WebSearchConfig = field(default_factory=lambda: WebSearchConfig())
|
||||
tavily: TavilyConfig = field(default_factory=lambda: TavilyConfig())
|
||||
plugins: PluginsConfig = field(default_factory=lambda: PluginsConfig())
|
||||
anti_prompt_injection: AntiPromptInjectionConfig = Field(default_factory=lambda: AntiPromptInjectionConfig(), description="反提示注入配置")
|
||||
video_analysis: VideoAnalysisConfig = Field(default_factory=lambda: VideoAnalysisConfig(), description="视频分析配置")
|
||||
dependency_management: DependencyManagementConfig = Field(default_factory=lambda: DependencyManagementConfig(), description="依赖管理配置")
|
||||
exa: ExaConfig = Field(default_factory=lambda: ExaConfig(), description="Exa配置")
|
||||
web_search: WebSearchConfig = Field(default_factory=lambda: WebSearchConfig(), description="网络搜索配置")
|
||||
tavily: TavilyConfig = Field(default_factory=lambda: TavilyConfig(), description="Tavily配置")
|
||||
plugins: PluginsConfig = Field(default_factory=lambda: PluginsConfig(), description="插件配置")
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIAdapterConfig(ConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
models: List[ModelInfo]
|
||||
"""模型列表"""
|
||||
|
||||
model_task_config: ModelTaskConfig
|
||||
"""模型任务配置"""
|
||||
|
||||
api_providers: List[APIProvider] = field(default_factory=list)
|
||||
"""API提供商列表"""
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.models:
|
||||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||||
if not self.api_providers:
|
||||
raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
|
||||
|
||||
# 检查API提供商名称是否重复
|
||||
provider_names = [provider.name for provider in self.api_providers]
|
||||
if len(provider_names) != len(set(provider_names)):
|
||||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||||
|
||||
# 检查模型名称是否重复
|
||||
model_names = [model.name for model in self.models]
|
||||
if len(model_names) != len(set(model_names)):
|
||||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||||
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
|
||||
for model in self.models:
|
||||
@field_validator('models')
|
||||
@classmethod
|
||||
def validate_models_list(cls, v):
|
||||
"""验证模型列表"""
|
||||
if not v:
|
||||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||||
|
||||
# 检查模型名称是否重复
|
||||
model_names = [model.name for model in v]
|
||||
if len(model_names) != len(set(model_names)):
|
||||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||||
|
||||
# 检查模型标识符是否有效
|
||||
for model in v:
|
||||
if not model.model_identifier:
|
||||
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
|
||||
if not model.api_provider or model.api_provider not in self.api_providers_dict:
|
||||
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('api_providers')
|
||||
@classmethod
|
||||
def validate_api_providers_list(cls, v):
|
||||
"""验证API提供商列表"""
|
||||
if not v:
|
||||
raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
|
||||
|
||||
# 检查API提供商名称是否重复
|
||||
provider_names = [provider.name for provider in v]
|
||||
if len(provider_names) != len(set(provider_names)):
|
||||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||||
|
||||
return v
|
||||
|
||||
def get_model_info(self, model_name: str) -> ModelInfo:
|
||||
"""根据模型名称获取模型信息"""
|
||||
@@ -436,11 +443,14 @@ def load_config(config_path: str) -> Config:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
# 创建Config对象(各个配置类会自动进行 Pydantic 验证)
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
logger.info("正在解析和验证配置文件...")
|
||||
config = Config.from_dict(config_data)
|
||||
logger.info("配置文件解析和验证完成")
|
||||
return config
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
logger.critical(f"配置文件解析失败: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@@ -456,11 +466,14 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建APIAdapterConfig对象
|
||||
# 创建APIAdapterConfig对象(各个配置类会自动进行 Pydantic 验证)
|
||||
try:
|
||||
return APIAdapterConfig.from_dict(config_data)
|
||||
logger.info("正在解析和验证API适配器配置文件...")
|
||||
config = APIAdapterConfig.from_dict(config_data)
|
||||
logger.info("API适配器配置文件解析和验证完成")
|
||||
return config
|
||||
except Exception as e:
|
||||
logger.critical("API适配器配置文件解析失败")
|
||||
logger.critical(f"API适配器配置文件解析失败: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user