优化配置类,添加元信息和日志配置,调整验证策略以禁止额外字段

This commit is contained in:
Windpicker-owo
2025-12-13 22:35:34 +08:00
parent 7fbe90de95
commit 2f38d220c3
4 changed files with 179 additions and 27 deletions

View File

@@ -1,9 +1,10 @@
from threading import Lock
from typing import Any, Literal
from pydantic import Field
from pydantic import Field, PrivateAttr
from src.config.config_base import ValidatedConfigBase
from src.config.official_configs import InnerConfig
class APIProvider(ValidatedConfigBase):
@@ -21,6 +22,9 @@ class APIProvider(ValidatedConfigBase):
)
retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位")
_api_key_lock: Lock = PrivateAttr(default_factory=Lock)
_api_key_index: int = PrivateAttr(default=0)
@classmethod
def validate_base_url(cls, v):
"""验证base_url确保URL格式正确"""
@@ -44,11 +48,6 @@ class APIProvider(ValidatedConfigBase):
raise ValueError("API密钥必须是字符串或字符串列表")
return v
def __init__(self, **data):
super().__init__(**data)
self._api_key_lock = Lock()
self._api_key_index = 0
def get_api_key(self) -> str:
with self._api_key_lock:
if isinstance(self.api_key, str):
@@ -134,6 +133,7 @@ class ModelTaskConfig(ValidatedConfigBase):
replyer_private: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置私聊使用")
maizone: TaskConfig = Field(..., description="maizone专用模型")
emotion: TaskConfig = Field(..., description="情绪模型配置")
mood: TaskConfig = Field(..., description="心情模型配置")
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
voice: TaskConfig = Field(..., description="语音识别模型配置")
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
@@ -178,14 +178,26 @@ class ModelTaskConfig(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
@classmethod
def validate_models_list(cls, v):

View File

@@ -1,10 +1,14 @@
import os
import shutil
import sys
import typing
import types
from datetime import datetime
from pathlib import Path
from typing import Any, get_args, get_origin
import tomlkit
from pydantic import Field
from pydantic import BaseModel, Field, PrivateAttr
from rich.traceback import install
from tomlkit import TOMLDocument
from tomlkit.items import KeyType, Table
@@ -25,6 +29,8 @@ from src.config.official_configs import (
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
InnerConfig,
LogConfig,
KokoroFlowChatterConfig,
LPMMKnowledgeConfig,
MemoryConfig,
@@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]):
"""
基于 Pydantic Schema 递归移除未知配置键(含可重复的 AoT 表)。
说明:
- 只移除 schema 中不存在的键,避免跨版本遗留废弃配置项。
- 对于 list[BaseModel] 字段TOML 的 [[...]]),会遍历每个元素并递归清理。
- 对于 dict[str, Any] 等自由结构字段,不做键级裁剪。
"""
def _strip_optional(annotation: Any) -> Any:
origin = get_origin(annotation)
if origin is None:
return annotation
# 兼容 | None 与 Union[..., None]
union_type = getattr(types, "UnionType", None)
if origin is union_type or origin is typing.Union:
args = [a for a in get_args(annotation) if a is not type(None)]
if len(args) == 1:
return args[0]
return annotation
def _is_model_type(annotation: Any) -> bool:
return isinstance(annotation, type) and issubclass(annotation, BaseModel)
def _prune_table(table: TOMLDocument | Table, model: type[BaseModel]):
name_by_key: dict[str, str] = {}
allowed_keys: set[str] = set()
for field_name, field_info in model.model_fields.items():
allowed_keys.add(field_name)
name_by_key[field_name] = field_name
alias = getattr(field_info, "alias", None)
if isinstance(alias, str) and alias:
allowed_keys.add(alias)
name_by_key[alias] = field_name
for key in list(table.keys()):
if key not in allowed_keys:
del table[key]
continue
field_name = name_by_key[key]
field_info = model.model_fields[field_name]
annotation = _strip_optional(getattr(field_info, "annotation", Any))
value = table.get(key)
if value is None:
continue
if _is_model_type(annotation) and isinstance(value, (TOMLDocument, Table)):
_prune_table(value, annotation)
continue
origin = get_origin(annotation)
if origin is list:
args = get_args(annotation)
elem_ann = _strip_optional(args[0]) if args else Any
# list[BaseModel] 对应 TOML 的 AoT[[...]]
if _is_model_type(elem_ann) and hasattr(value, "__iter__"):
for item in value:
if isinstance(item, (TOMLDocument, Table)):
_prune_table(item, elem_ann)
_prune_table(target, schema_model)
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中
@@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
target[key] = value
def _update_config_generic(config_name: str, template_name: str):
def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
schema_model: 用于裁剪未知键的 Pydantic 模型(避免跨版本遗留废弃配置项)
"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
@@ -355,11 +432,14 @@ def _update_config_generic(config_name: str, template_name: str):
logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config)
# 移除在新模板中已不存在的旧配置项
# 移除未知/废弃的旧配置项(尤其是可重复的 [[...]] 段落)
logger.info(f"开始移除{config_name}中已废弃的配置项...")
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
if schema_model is not None:
_prune_unknown_keys_by_schema(new_config, schema_model)
else:
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
logger.info(f"已移除{config_name}中已废弃的配置项")
# 保存更新后的配置(保留注释和格式)
@@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str):
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template")
_update_config_generic("bot_config", "bot_config_template", schema_model=Config)
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template")
_update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig)
class Config(ValidatedConfigBase):
"""总配置类"""
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
inner: InnerConfig = Field(..., description="配置元信息")
database: DatabaseConfig = Field(..., description="数据库配置")
bot: BotConfig = Field(..., description="机器人基本配置")
@@ -397,6 +477,7 @@ class Config(ValidatedConfigBase):
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
log: LogConfig = Field(..., description="日志配置")
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置")
message_bus: MessageBusConfig = Field(..., description="消息总线配置")
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
@@ -433,18 +514,34 @@ class Config(ValidatedConfigBase):
default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置"
)
@property
def MMC_VERSION(self) -> str: # noqa: N802
return MMC_VERSION
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
@classmethod
def validate_models_list(cls, v):
@@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config:
Returns:
Config对象
"""
# 读取配置文件
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 读取配置文件(会自动删除未知/废弃配置项)
original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, Config)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
@@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
Returns:
APIAdapterConfig对象
"""
# 读取配置文件
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 读取配置文件(会自动删除未知/废弃配置项)
original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, APIAdapterConfig)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
config_dict = dict(config_data)
config_dict = config_data.unwrap()
try:
logger.debug("正在解析和验证API适配器配置文件...")

View File

@@ -142,7 +142,7 @@ class ValidatedConfigBase(BaseModel):
"""带验证的配置基类继承自Pydantic BaseModel"""
model_config = {
"extra": "allow", # 允许额外字段
"extra": "forbid", # 禁止额外字段(防止跨版本遗留废弃配置项)
"validate_assignment": True, # 验证赋值
"arbitrary_types_allowed": True, # 允许任意类型
"strict": True, # 如果设为 True 会完全禁用类型转换

View File

@@ -13,6 +13,12 @@ from src.config.config_base import ValidatedConfigBase
"""
class InnerConfig(ValidatedConfigBase):
"""配置文件元信息"""
version: str = Field(..., description="配置文件版本号(用于配置文件升级与兼容性检查)")
class DatabaseConfig(ValidatedConfigBase):
"""数据库配置类"""
@@ -588,6 +594,20 @@ class ResponseSplitterConfig(ValidatedConfigBase):
enable_kaomoji_protection: bool = Field(default=False, description="启用颜文字保护")
class LogConfig(ValidatedConfigBase):
"""日志配置类"""
date_style: str = Field(default="m-d H:i:s", description="日期格式")
log_level_style: str = Field(default="lite", description="日志级别样式")
color_text: str = Field(default="full", description="日志文本颜色")
log_level: str = Field(default="INFO", description="全局日志级别(向下兼容,优先级低于分别设置)")
file_retention_days: int = Field(default=7, description="文件日志保留天数0=禁用文件日志,-1=永不删除")
console_log_level: str = Field(default="INFO", description="控制台日志级别")
file_log_level: str = Field(default="DEBUG", description="文件日志级别")
suppress_libraries: list[str] = Field(default_factory=list, description="完全屏蔽日志的第三方库列表")
library_log_levels: dict[str, str] = Field(default_factory=dict, description="设置特定库的日志级别")
class DebugConfig(ValidatedConfigBase):
"""调试配置类"""
@@ -703,6 +723,7 @@ class WebSearchConfig(ValidatedConfigBase):
enable_url_tool: bool = Field(default=True, description="启用URL工具")
tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API密钥列表支持轮询机制")
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表支持轮询机制")
metaso_api_keys: list[str] = Field(default_factory=lambda: [], description="Metaso API密钥列表支持轮询机制")
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表")
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表")
serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表")
@@ -988,6 +1009,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
description="开启后KFC将接管所有私聊消息关闭后私聊消息将由AFC处理"
)
# --- 工作模式 ---
mode: Literal["unified", "split"] = Field(
default="split",
description='工作模式: "unified"(单次调用) 或 "split"(planner+replyer两次调用)',
)
# --- 核心行为配置 ---
max_wait_seconds_default: int = Field(
default=300, ge=30, le=3600,
@@ -998,6 +1025,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
description="是否在等待期间启用心理活动更新"
)
# --- 自定义决策提示词 ---
custom_decision_prompt: str = Field(
default="",
description="自定义KFC决策行为指导提示词unified影响整体split仅影响planner",
)
waiting: KokoroFlowChatterWaitingConfig = Field(
default_factory=KokoroFlowChatterWaitingConfig,
description="等待策略配置(默认等待时间、倍率等)",