优化配置类,添加元信息和日志配置,调整验证策略以禁止额外字段

This commit is contained in:
Windpicker-owo
2025-12-13 22:35:34 +08:00
parent 7fbe90de95
commit 2f38d220c3
4 changed files with 179 additions and 27 deletions

View File

@@ -1,10 +1,14 @@
import os
import shutil
import sys
import typing
import types
from datetime import datetime
from pathlib import Path
from typing import Any, get_args, get_origin
import tomlkit
from pydantic import Field
from pydantic import BaseModel, Field, PrivateAttr
from rich.traceback import install
from tomlkit import TOMLDocument
from tomlkit.items import KeyType, Table
@@ -25,6 +29,8 @@ from src.config.official_configs import (
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
InnerConfig,
LogConfig,
KokoroFlowChatterConfig,
LPMMKnowledgeConfig,
MemoryConfig,
@@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]):
    """Recursively delete keys from a TOML tree that the Pydantic schema does not declare.

    The tree is modified in place and nothing is returned.

    Behaviour:
    - A key is kept when it matches a field name, the field's ``alias``, or a
      string ``validation_alias`` (including the string entries of an
      ``AliasChoices``-style object exposing ``.choices``). Every other key is
      deleted.
    - When a field is annotated with a ``BaseModel`` subclass and its value is a
      ``TOMLDocument``/``Table``, that sub-table is pruned against the nested model.
    - When a field is annotated ``list[<BaseModel subclass>]`` (TOML array of
      tables, ``[[...]]``), each ``TOMLDocument``/``Table`` element is pruned.
    - Fields with any other annotation (e.g. ``dict[str, Any]``) keep their
      value untouched; no key-level pruning happens inside them.

    Args:
        target: Parsed TOML document or table to clean up.
        schema_model: Pydantic model class describing the allowed keys of ``target``.
    """
    # ``types.UnionType`` (the ``X | None`` form) is looked up defensively, as the
    # original code did, so a missing attribute does not raise.
    union_type = getattr(types, "UnionType", None)

    def _strip_optional(annotation: Any) -> Any:
        """Return ``X`` for ``X | None`` / ``Optional[X]``; anything else is returned as is."""
        origin = get_origin(annotation)
        if origin is None:
            return annotation
        if origin is typing.Union or origin is union_type:
            non_none = [arg for arg in get_args(annotation) if arg is not type(None)]
            if len(non_none) == 1:
                return non_none[0]
        return annotation

    def _is_model_type(annotation: Any) -> bool:
        """Tell whether ``annotation`` is a ``BaseModel`` subclass, never raising."""
        # On Python 3.9/3.10 a parameterised generic such as ``list[Model]``
        # passes ``isinstance(..., type)`` and then makes ``issubclass`` raise
        # TypeError. Such annotations are simply "not a model".
        try:
            return isinstance(annotation, type) and issubclass(annotation, BaseModel)
        except TypeError:
            return False

    def _accepted_keys(field_name: str, field_info: Any) -> list[str]:
        """List every plain-string key under which the field may appear in the TOML."""
        keys = [field_name]
        for attr in ("alias", "validation_alias"):
            spec = getattr(field_info, attr, None)
            # An AliasChoices-like object carries its alternatives in ``.choices``;
            # a plain string (or None) is treated as a single candidate.
            # Non-string candidates (e.g. path-style aliases) are not handled here.
            for candidate in getattr(spec, "choices", None) or (spec,):
                if isinstance(candidate, str) and candidate:
                    keys.append(candidate)
        return keys

    def _prune_table(table: TOMLDocument | Table, model: type[BaseModel]):
        """Prune one table against one model, descending into nested models."""
        field_by_key: dict[str, Any] = {}
        for field_name, field_info in model.model_fields.items():
            for accepted in _accepted_keys(field_name, field_info):
                field_by_key[accepted] = field_info

        # Iterate over a snapshot of the keys because entries are deleted on the fly.
        for key in list(table.keys()):
            if key not in field_by_key:
                del table[key]
                continue

            value = table.get(key)
            if value is None:
                continue

            annotation = _strip_optional(getattr(field_by_key[key], "annotation", Any))

            if _is_model_type(annotation) and isinstance(value, (TOMLDocument, Table)):
                _prune_table(value, annotation)
                continue

            if get_origin(annotation) is not list:
                continue

            args = get_args(annotation)
            elem_ann = _strip_optional(args[0]) if args else Any
            # list[BaseModel] corresponds to a TOML array of tables ([[...]]).
            if _is_model_type(elem_ann) and hasattr(value, "__iter__"):
                for item in value:
                    if isinstance(item, (TOMLDocument, Table)):
                        _prune_table(item, elem_ann)

    _prune_table(target, schema_model)
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中
@@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
target[key] = value
def _update_config_generic(config_name: str, template_name: str):
def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config'、'model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template'、'model_config_template'
schema_model: 用于裁剪未知键的 Pydantic 模型(避免跨版本遗留废弃配置项)
"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
@@ -355,11 +432,14 @@ def _update_config_generic(config_name: str, template_name: str):
logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config)
# 移除在新模板中已不存在的旧配置项
# 移除未知/废弃的旧配置项(尤其是可重复的 [[...]] 段落)
logger.info(f"开始移除{config_name}中已废弃的配置项...")
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
if schema_model is not None:
_prune_unknown_keys_by_schema(new_config, schema_model)
else:
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
logger.info(f"已移除{config_name}中已废弃的配置项")
# 保存更新后的配置(保留注释和格式)
@@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str):
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template")
_update_config_generic("bot_config", "bot_config_template", schema_model=Config)
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template")
_update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig)
class Config(ValidatedConfigBase):
"""总配置类"""
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
inner: InnerConfig = Field(..., description="配置元信息")
database: DatabaseConfig = Field(..., description="数据库配置")
bot: BotConfig = Field(..., description="机器人基本配置")
@@ -397,6 +477,7 @@ class Config(ValidatedConfigBase):
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
log: LogConfig = Field(..., description="日志配置")
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置")
message_bus: MessageBusConfig = Field(..., description="消息总线配置")
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
@@ -433,18 +514,34 @@ class Config(ValidatedConfigBase):
default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置"
)
@property
def MMC_VERSION(self) -> str: # noqa: N802
return MMC_VERSION
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
@classmethod
def validate_models_list(cls, v):
@@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config:
Returns:
Config对象
"""
# 读取配置文件
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 读取配置文件(会自动删除未知/废弃配置项)
original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, Config)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
@@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
Returns:
APIAdapterConfig对象
"""
# 读取配置文件
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 读取配置文件(会自动删除未知/废弃配置项)
original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, APIAdapterConfig)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
config_dict = dict(config_data)
config_dict = config_data.unwrap()
try:
logger.debug("正在解析和验证API适配器配置文件...")