优化配置类,添加元信息和日志配置,调整验证策略以禁止额外字段

This commit is contained in:
Windpicker-owo
2025-12-13 22:35:34 +08:00
parent 7fbe90de95
commit 2f38d220c3
4 changed files with 179 additions and 27 deletions

View File

@@ -1,9 +1,10 @@
from threading import Lock from threading import Lock
from typing import Any, Literal from typing import Any, Literal
from pydantic import Field from pydantic import Field, PrivateAttr
from src.config.config_base import ValidatedConfigBase from src.config.config_base import ValidatedConfigBase
from src.config.official_configs import InnerConfig
class APIProvider(ValidatedConfigBase): class APIProvider(ValidatedConfigBase):
@@ -21,6 +22,9 @@ class APIProvider(ValidatedConfigBase):
) )
retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位") retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位")
_api_key_lock: Lock = PrivateAttr(default_factory=Lock)
_api_key_index: int = PrivateAttr(default=0)
@classmethod @classmethod
def validate_base_url(cls, v): def validate_base_url(cls, v):
"""验证base_url确保URL格式正确""" """验证base_url确保URL格式正确"""
@@ -44,11 +48,6 @@ class APIProvider(ValidatedConfigBase):
raise ValueError("API密钥必须是字符串或字符串列表") raise ValueError("API密钥必须是字符串或字符串列表")
return v return v
def __init__(self, **data):
super().__init__(**data)
self._api_key_lock = Lock()
self._api_key_index = 0
def get_api_key(self) -> str: def get_api_key(self) -> str:
with self._api_key_lock: with self._api_key_lock:
if isinstance(self.api_key, str): if isinstance(self.api_key, str):
@@ -134,6 +133,7 @@ class ModelTaskConfig(ValidatedConfigBase):
replyer_private: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置私聊使用") replyer_private: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置私聊使用")
maizone: TaskConfig = Field(..., description="maizone专用模型") maizone: TaskConfig = Field(..., description="maizone专用模型")
emotion: TaskConfig = Field(..., description="情绪模型配置") emotion: TaskConfig = Field(..., description="情绪模型配置")
mood: TaskConfig = Field(..., description="心情模型配置")
vlm: TaskConfig = Field(..., description="视觉语言模型配置") vlm: TaskConfig = Field(..., description="视觉语言模型配置")
voice: TaskConfig = Field(..., description="语音识别模型配置") voice: TaskConfig = Field(..., description="语音识别模型配置")
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置") tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
@@ -178,14 +178,26 @@ class ModelTaskConfig(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类""" """API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers} self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models} self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
@classmethod @classmethod
def validate_models_list(cls, v): def validate_models_list(cls, v):

View File

@@ -1,10 +1,14 @@
import os import os
import shutil import shutil
import sys import sys
import typing
import types
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Any, get_args, get_origin
import tomlkit import tomlkit
from pydantic import Field from pydantic import BaseModel, Field, PrivateAttr
from rich.traceback import install from rich.traceback import install
from tomlkit import TOMLDocument from tomlkit import TOMLDocument
from tomlkit.items import KeyType, Table from tomlkit.items import KeyType, Table
@@ -25,6 +29,8 @@ from src.config.official_configs import (
EmojiConfig, EmojiConfig,
ExperimentalConfig, ExperimentalConfig,
ExpressionConfig, ExpressionConfig,
InnerConfig,
LogConfig,
KokoroFlowChatterConfig, KokoroFlowChatterConfig,
LPMMKnowledgeConfig, LPMMKnowledgeConfig,
MemoryConfig, MemoryConfig,
@@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
_remove_obsolete_keys(target[key], reference[key]) # type: ignore _remove_obsolete_keys(target[key], reference[key]) # type: ignore
def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]):
    """
    Recursively delete keys from a TOML document that the Pydantic schema does not declare.

    Rules:
    - A key survives only if it equals a field name, or a non-empty string alias,
      found in ``model_fields`` of the model for that nesting level. Every other
      key is deleted in place.
    - A field annotated with a ``BaseModel`` subclass (optionally wrapped as
      ``X | None``) is descended into when its value is a TOML table. Inline
      tables (``{ ... }``) are included as well as ``[section]`` tables.
    - A field annotated as ``list[BaseModel]`` (TOML ``[[...]]`` arrays of tables,
      or arrays of inline tables) has each table element pruned against the
      element model.
    - Free-form fields such as ``dict[str, Any]`` are not pruned key by key.

    Args:
        target: Parsed TOML document or table; it is modified in place.
        schema_model: Pydantic model class describing the allowed top-level keys.
    """
    # Function-scope import: keeps this helper independent of the module's import block.
    from collections.abc import MutableMapping

    def _strip_optional(annotation: Any) -> Any:
        """Unwrap ``X | None`` / ``Optional[X]`` to ``X``; return anything else unchanged."""
        origin = get_origin(annotation)
        if origin is None:
            return annotation

        # ``X | None`` reports ``types.UnionType`` (looked up with getattr in case
        # the attribute is missing); ``Optional[X]`` reports ``typing.Union``.
        union_type = getattr(types, "UnionType", None)
        if origin is union_type or origin is typing.Union:
            non_none = [arg for arg in get_args(annotation) if arg is not type(None)]
            # Only a single remaining member is unambiguous; wider unions are kept as-is.
            if len(non_none) == 1:
                return non_none[0]
        return annotation

    def _is_model_type(annotation: Any) -> bool:
        """Return True when *annotation* is a class derived from ``BaseModel``."""
        return isinstance(annotation, type) and issubclass(annotation, BaseModel)

    def _is_table(value: Any) -> bool:
        """Return True for any table-like value (document, table or inline table)."""
        # Checking against (TOMLDocument, Table) alone would skip inline tables,
        # leaving their unknown keys in place; all of these are mutable mappings.
        return isinstance(value, MutableMapping)

    def _prune_table(table: MutableMapping, model: type[BaseModel]):
        """Prune one table level against *model*, then recurse into nested models."""
        # Map every accepted TOML key (field name or alias) to its field name.
        name_by_key: dict[str, str] = {}
        for field_name, field_info in model.model_fields.items():
            name_by_key[field_name] = field_name
            alias = getattr(field_info, "alias", None)
            if isinstance(alias, str) and alias:
                name_by_key[alias] = field_name

        # Iterate over a snapshot of the keys because entries are deleted on the way.
        for key in list(table.keys()):
            if key not in name_by_key:
                del table[key]
                continue

            field_info = model.model_fields[name_by_key[key]]
            annotation = _strip_optional(getattr(field_info, "annotation", Any))

            value = table.get(key)
            if value is None:
                continue

            # Nested model stored as a table: recurse with the nested schema.
            if _is_model_type(annotation) and _is_table(value):
                _prune_table(value, annotation)
                continue

            # list[BaseModel]: prune every table element of the array.
            if get_origin(annotation) is list:
                args = get_args(annotation)
                elem_ann = _strip_optional(args[0]) if args else Any
                if _is_model_type(elem_ann) and hasattr(value, "__iter__"):
                    for item in value:
                        if _is_table(item):
                            _prune_table(item, elem_ann)

    _prune_table(target, schema_model)
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
""" """
将source字典的值更新到target字典中 将source字典的值更新到target字典中
@@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
target[key] = value target[key] = value
def _update_config_generic(config_name: str, template_name: str): def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None):
""" """
通用的配置文件更新函数 通用的配置文件更新函数
Args: Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config' config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template' template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
schema_model: 用于裁剪未知键的 Pydantic 模型(避免跨版本遗留废弃配置项)
""" """
# 获取根目录路径 # 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old") old_config_dir = os.path.join(CONFIG_DIR, "old")
@@ -355,8 +432,11 @@ def _update_config_generic(config_name: str, template_name: str):
logger.info(f"开始合并{config_name}新旧配置...") logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config) _update_dict(new_config, old_config)
# 移除在新模板中已不存在的旧配置项 # 移除未知/废弃的旧配置项(尤其是可重复的 [[...]] 段落)
logger.info(f"开始移除{config_name}中已废弃的配置项...") logger.info(f"开始移除{config_name}中已废弃的配置项...")
if schema_model is not None:
_prune_unknown_keys_by_schema(new_config, schema_model)
else:
with open(template_path, encoding="utf-8") as f: with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f) template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc) _remove_obsolete_keys(new_config, template_doc)
@@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str):
def update_config(): def update_config():
"""更新bot_config.toml配置文件""" """更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template") _update_config_generic("bot_config", "bot_config_template", schema_model=Config)
def update_model_config(): def update_model_config():
"""更新model_config.toml配置文件""" """更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template") _update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig)
class Config(ValidatedConfigBase): class Config(ValidatedConfigBase):
"""总配置类""" """总配置类"""
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号") inner: InnerConfig = Field(..., description="配置元信息")
database: DatabaseConfig = Field(..., description="数据库配置") database: DatabaseConfig = Field(..., description="数据库配置")
bot: BotConfig = Field(..., description="机器人基本配置") bot: BotConfig = Field(..., description="机器人基本配置")
@@ -397,6 +477,7 @@ class Config(ValidatedConfigBase):
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置") chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置") response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置") response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
log: LogConfig = Field(..., description="日志配置")
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置") experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置")
message_bus: MessageBusConfig = Field(..., description="消息总线配置") message_bus: MessageBusConfig = Field(..., description="消息总线配置")
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置") lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
@@ -433,18 +514,34 @@ class Config(ValidatedConfigBase):
default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置" default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置"
) )
@property
def MMC_VERSION(self) -> str: # noqa: N802
return MMC_VERSION
class APIAdapterConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类""" """API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers} self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models} self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
@classmethod @classmethod
def validate_models_list(cls, v): def validate_models_list(cls, v):
@@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config:
Returns: Returns:
Config对象 Config对象
""" """
# 读取配置文件 # 读取配置文件(会自动删除未知/废弃配置项)
with open(config_path, encoding="utf-8") as f: original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.load(f) config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, Config)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题 # 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型, # tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
@@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
Returns: Returns:
APIAdapterConfig对象 APIAdapterConfig对象
""" """
# 读取配置文件 # 读取配置文件(会自动删除未知/废弃配置项)
with open(config_path, encoding="utf-8") as f: original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.load(f) config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, APIAdapterConfig)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
config_dict = dict(config_data) config_dict = config_data.unwrap()
try: try:
logger.debug("正在解析和验证API适配器配置文件...") logger.debug("正在解析和验证API适配器配置文件...")

View File

@@ -142,7 +142,7 @@ class ValidatedConfigBase(BaseModel):
"""带验证的配置基类继承自Pydantic BaseModel""" """带验证的配置基类继承自Pydantic BaseModel"""
model_config = { model_config = {
"extra": "allow", # 允许额外字段 "extra": "forbid", # 禁止额外字段(防止跨版本遗留废弃配置项)
"validate_assignment": True, # 验证赋值 "validate_assignment": True, # 验证赋值
"arbitrary_types_allowed": True, # 允许任意类型 "arbitrary_types_allowed": True, # 允许任意类型
"strict": True, # 如果设为 True 会完全禁用类型转换 "strict": True, # 如果设为 True 会完全禁用类型转换

View File

@@ -13,6 +13,12 @@ from src.config.config_base import ValidatedConfigBase
""" """
class InnerConfig(ValidatedConfigBase):
    """Metadata describing the configuration file itself."""

    # Required field (no default). Per its description, the version string is
    # used for configuration upgrades and compatibility checks.
    version: str = Field(
        ...,
        description="配置文件版本号(用于配置文件升级与兼容性检查)",
    )
class DatabaseConfig(ValidatedConfigBase): class DatabaseConfig(ValidatedConfigBase):
"""数据库配置类""" """数据库配置类"""
@@ -588,6 +594,20 @@ class ResponseSplitterConfig(ValidatedConfigBase):
enable_kaomoji_protection: bool = Field(default=False, description="启用颜文字保护") enable_kaomoji_protection: bool = Field(default=False, description="启用颜文字保护")
class LogConfig(ValidatedConfigBase):
    """Logging configuration section; every field has a default."""

    # --- Presentation ---
    date_style: str = Field(
        default="m-d H:i:s",
        description="日期格式",
    )
    log_level_style: str = Field(
        default="lite",
        description="日志级别样式",
    )
    color_text: str = Field(
        default="full",
        description="日志文本颜色",
    )

    # --- Levels and retention ---
    # Global level; per its description it is kept for backward compatibility
    # and has lower priority than the console/file specific levels below.
    log_level: str = Field(
        default="INFO",
        description="全局日志级别(向下兼容,优先级低于分别设置)",
    )
    # Per its description: 0 disables file logging, -1 never deletes log files.
    file_retention_days: int = Field(
        default=7,
        description="文件日志保留天数0=禁用文件日志,-1=永不删除",
    )
    console_log_level: str = Field(
        default="INFO",
        description="控制台日志级别",
    )
    file_log_level: str = Field(
        default="DEBUG",
        description="文件日志级别",
    )

    # --- Third-party libraries ---
    # Libraries whose log output is suppressed entirely (empty list by default).
    suppress_libraries: list[str] = Field(
        default_factory=list,
        description="完全屏蔽日志的第三方库列表",
    )
    # Per-library log level overrides (empty mapping by default).
    library_log_levels: dict[str, str] = Field(
        default_factory=dict,
        description="设置特定库的日志级别",
    )
class DebugConfig(ValidatedConfigBase): class DebugConfig(ValidatedConfigBase):
"""调试配置类""" """调试配置类"""
@@ -703,6 +723,7 @@ class WebSearchConfig(ValidatedConfigBase):
enable_url_tool: bool = Field(default=True, description="启用URL工具") enable_url_tool: bool = Field(default=True, description="启用URL工具")
tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API密钥列表支持轮询机制") tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API密钥列表支持轮询机制")
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表支持轮询机制") exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表支持轮询机制")
metaso_api_keys: list[str] = Field(default_factory=lambda: [], description="Metaso API密钥列表支持轮询机制")
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表") searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表")
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表") searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表")
serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表") serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表")
@@ -988,6 +1009,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
description="开启后KFC将接管所有私聊消息关闭后私聊消息将由AFC处理" description="开启后KFC将接管所有私聊消息关闭后私聊消息将由AFC处理"
) )
# --- 工作模式 ---
mode: Literal["unified", "split"] = Field(
default="split",
description='工作模式: "unified"(单次调用) 或 "split"(planner+replyer两次调用)',
)
# --- 核心行为配置 --- # --- 核心行为配置 ---
max_wait_seconds_default: int = Field( max_wait_seconds_default: int = Field(
default=300, ge=30, le=3600, default=300, ge=30, le=3600,
@@ -998,6 +1025,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
description="是否在等待期间启用心理活动更新" description="是否在等待期间启用心理活动更新"
) )
# --- 自定义决策提示词 ---
custom_decision_prompt: str = Field(
default="",
description="自定义KFC决策行为指导提示词unified影响整体split仅影响planner",
)
waiting: KokoroFlowChatterWaitingConfig = Field( waiting: KokoroFlowChatterWaitingConfig = Field(
default_factory=KokoroFlowChatterWaitingConfig, default_factory=KokoroFlowChatterWaitingConfig,
description="等待策略配置(默认等待时间、倍率等)", description="等待策略配置(默认等待时间、倍率等)",