优化配置类,添加元信息和日志配置,调整验证策略以禁止额外字段
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from threading import Lock
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
from src.config.official_configs import InnerConfig
|
||||
|
||||
|
||||
class APIProvider(ValidatedConfigBase):
|
||||
@@ -21,6 +22,9 @@ class APIProvider(ValidatedConfigBase):
|
||||
)
|
||||
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
||||
|
||||
_api_key_lock: Lock = PrivateAttr(default_factory=Lock)
|
||||
_api_key_index: int = PrivateAttr(default=0)
|
||||
|
||||
@classmethod
|
||||
def validate_base_url(cls, v):
|
||||
"""验证base_url,确保URL格式正确"""
|
||||
@@ -44,11 +48,6 @@ class APIProvider(ValidatedConfigBase):
|
||||
raise ValueError("API密钥必须是字符串或字符串列表")
|
||||
return v
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._api_key_lock = Lock()
|
||||
self._api_key_index = 0
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
with self._api_key_lock:
|
||||
if isinstance(self.api_key, str):
|
||||
@@ -134,6 +133,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
replyer_private: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置(私聊使用)")
|
||||
maizone: TaskConfig = Field(..., description="maizone专用模型")
|
||||
emotion: TaskConfig = Field(..., description="情绪模型配置")
|
||||
mood: TaskConfig = Field(..., description="心情模型配置")
|
||||
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
|
||||
voice: TaskConfig = Field(..., description="语音识别模型配置")
|
||||
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
|
||||
@@ -178,14 +178,26 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
inner: InnerConfig = Field(..., description="配置元信息")
|
||||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
|
||||
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self._models_dict = {model.name: model for model in self.models}
|
||||
|
||||
@property
|
||||
def api_providers_dict(self) -> dict[str, APIProvider]:
|
||||
return self._api_providers_dict
|
||||
|
||||
@property
|
||||
def models_dict(self) -> dict[str, ModelInfo]:
|
||||
return self._models_dict
|
||||
|
||||
@classmethod
|
||||
def validate_models_list(cls, v):
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import typing
|
||||
import types
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
import tomlkit
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from rich.traceback import install
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import KeyType, Table
|
||||
@@ -25,6 +29,8 @@ from src.config.official_configs import (
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
InnerConfig,
|
||||
LogConfig,
|
||||
KokoroFlowChatterConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MemoryConfig,
|
||||
@@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
|
||||
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
|
||||
|
||||
|
||||
def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]):
|
||||
"""
|
||||
基于 Pydantic Schema 递归移除未知配置键(含可重复的 AoT 表)。
|
||||
|
||||
说明:
|
||||
- 只移除 schema 中不存在的键,避免跨版本遗留废弃配置项。
|
||||
- 对于 list[BaseModel] 字段(TOML 的 [[...]]),会遍历每个元素并递归清理。
|
||||
- 对于 dict[str, Any] 等自由结构字段,不做键级裁剪。
|
||||
"""
|
||||
|
||||
def _strip_optional(annotation: Any) -> Any:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return annotation
|
||||
|
||||
# 兼容 | None 与 Union[..., None]
|
||||
union_type = getattr(types, "UnionType", None)
|
||||
if origin is union_type or origin is typing.Union:
|
||||
args = [a for a in get_args(annotation) if a is not type(None)]
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return annotation
|
||||
|
||||
def _is_model_type(annotation: Any) -> bool:
|
||||
return isinstance(annotation, type) and issubclass(annotation, BaseModel)
|
||||
|
||||
def _prune_table(table: TOMLDocument | Table, model: type[BaseModel]):
|
||||
name_by_key: dict[str, str] = {}
|
||||
allowed_keys: set[str] = set()
|
||||
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
allowed_keys.add(field_name)
|
||||
name_by_key[field_name] = field_name
|
||||
|
||||
alias = getattr(field_info, "alias", None)
|
||||
if isinstance(alias, str) and alias:
|
||||
allowed_keys.add(alias)
|
||||
name_by_key[alias] = field_name
|
||||
|
||||
for key in list(table.keys()):
|
||||
if key not in allowed_keys:
|
||||
del table[key]
|
||||
continue
|
||||
|
||||
field_name = name_by_key[key]
|
||||
field_info = model.model_fields[field_name]
|
||||
annotation = _strip_optional(getattr(field_info, "annotation", Any))
|
||||
|
||||
value = table.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if _is_model_type(annotation) and isinstance(value, (TOMLDocument, Table)):
|
||||
_prune_table(value, annotation)
|
||||
continue
|
||||
|
||||
origin = get_origin(annotation)
|
||||
if origin is list:
|
||||
args = get_args(annotation)
|
||||
elem_ann = _strip_optional(args[0]) if args else Any
|
||||
|
||||
# list[BaseModel] 对应 TOML 的 AoT([[...]])
|
||||
if _is_model_type(elem_ann) and hasattr(value, "__iter__"):
|
||||
for item in value:
|
||||
if isinstance(item, (TOMLDocument, Table)):
|
||||
_prune_table(item, elem_ann)
|
||||
|
||||
_prune_table(target, schema_model)
|
||||
|
||||
|
||||
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中
|
||||
@@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_config_generic(config_name: str, template_name: str):
|
||||
def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None):
|
||||
"""
|
||||
通用的配置文件更新函数
|
||||
|
||||
Args:
|
||||
config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
|
||||
template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
|
||||
schema_model: 用于裁剪未知键的 Pydantic 模型(避免跨版本遗留废弃配置项)
|
||||
"""
|
||||
# 获取根目录路径
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
@@ -355,11 +432,14 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
logger.info(f"开始合并{config_name}新旧配置...")
|
||||
_update_dict(new_config, old_config)
|
||||
|
||||
# 移除在新模板中已不存在的旧配置项
|
||||
# 移除未知/废弃的旧配置项(尤其是可重复的 [[...]] 段落)
|
||||
logger.info(f"开始移除{config_name}中已废弃的配置项...")
|
||||
with open(template_path, encoding="utf-8") as f:
|
||||
template_doc = tomlkit.load(f)
|
||||
_remove_obsolete_keys(new_config, template_doc)
|
||||
if schema_model is not None:
|
||||
_prune_unknown_keys_by_schema(new_config, schema_model)
|
||||
else:
|
||||
with open(template_path, encoding="utf-8") as f:
|
||||
template_doc = tomlkit.load(f)
|
||||
_remove_obsolete_keys(new_config, template_doc)
|
||||
logger.info(f"已移除{config_name}中已废弃的配置项")
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
@@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
|
||||
def update_config():
|
||||
"""更新bot_config.toml配置文件"""
|
||||
_update_config_generic("bot_config", "bot_config_template")
|
||||
_update_config_generic("bot_config", "bot_config_template", schema_model=Config)
|
||||
|
||||
|
||||
def update_model_config():
|
||||
"""更新model_config.toml配置文件"""
|
||||
_update_config_generic("model_config", "model_config_template")
|
||||
_update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig)
|
||||
|
||||
|
||||
class Config(ValidatedConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
|
||||
inner: InnerConfig = Field(..., description="配置元信息")
|
||||
|
||||
database: DatabaseConfig = Field(..., description="数据库配置")
|
||||
bot: BotConfig = Field(..., description="机器人基本配置")
|
||||
@@ -397,6 +477,7 @@ class Config(ValidatedConfigBase):
|
||||
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
|
||||
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
|
||||
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
|
||||
log: LogConfig = Field(..., description="日志配置")
|
||||
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置")
|
||||
message_bus: MessageBusConfig = Field(..., description="消息总线配置")
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
|
||||
@@ -433,18 +514,34 @@ class Config(ValidatedConfigBase):
|
||||
default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置"
|
||||
)
|
||||
|
||||
@property
|
||||
def MMC_VERSION(self) -> str: # noqa: N802
|
||||
return MMC_VERSION
|
||||
|
||||
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
inner: InnerConfig = Field(..., description="配置元信息")
|
||||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
|
||||
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self._models_dict = {model.name: model for model in self.models}
|
||||
|
||||
@property
|
||||
def api_providers_dict(self) -> dict[str, APIProvider]:
|
||||
return self._api_providers_dict
|
||||
|
||||
@property
|
||||
def models_dict(self) -> dict[str, ModelInfo]:
|
||||
return self._models_dict
|
||||
|
||||
@classmethod
|
||||
def validate_models_list(cls, v):
|
||||
@@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config:
|
||||
Returns:
|
||||
Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
# 读取配置文件(会自动删除未知/废弃配置项)
|
||||
original_text = Path(config_path).read_text(encoding="utf-8")
|
||||
config_data = tomlkit.parse(original_text)
|
||||
_prune_unknown_keys_by_schema(config_data, Config)
|
||||
new_text = tomlkit.dumps(config_data)
|
||||
if new_text != original_text:
|
||||
Path(config_path).write_text(new_text, encoding="utf-8")
|
||||
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
|
||||
|
||||
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
|
||||
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
|
||||
@@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||
Returns:
|
||||
APIAdapterConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
# 读取配置文件(会自动删除未知/废弃配置项)
|
||||
original_text = Path(config_path).read_text(encoding="utf-8")
|
||||
config_data = tomlkit.parse(original_text)
|
||||
_prune_unknown_keys_by_schema(config_data, APIAdapterConfig)
|
||||
new_text = tomlkit.dumps(config_data)
|
||||
if new_text != original_text:
|
||||
Path(config_path).write_text(new_text, encoding="utf-8")
|
||||
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
|
||||
|
||||
config_dict = dict(config_data)
|
||||
config_dict = config_data.unwrap()
|
||||
|
||||
try:
|
||||
logger.debug("正在解析和验证API适配器配置文件...")
|
||||
|
||||
@@ -142,7 +142,7 @@ class ValidatedConfigBase(BaseModel):
|
||||
"""带验证的配置基类,继承自Pydantic BaseModel"""
|
||||
|
||||
model_config = {
|
||||
"extra": "allow", # 允许额外字段
|
||||
"extra": "forbid", # 禁止额外字段(防止跨版本遗留废弃配置项)
|
||||
"validate_assignment": True, # 验证赋值
|
||||
"arbitrary_types_allowed": True, # 允许任意类型
|
||||
"strict": True, # 如果设为 True 会完全禁用类型转换
|
||||
|
||||
@@ -13,6 +13,12 @@ from src.config.config_base import ValidatedConfigBase
|
||||
"""
|
||||
|
||||
|
||||
class InnerConfig(ValidatedConfigBase):
|
||||
"""配置文件元信息"""
|
||||
|
||||
version: str = Field(..., description="配置文件版本号(用于配置文件升级与兼容性检查)")
|
||||
|
||||
|
||||
class DatabaseConfig(ValidatedConfigBase):
|
||||
"""数据库配置类"""
|
||||
|
||||
@@ -588,6 +594,20 @@ class ResponseSplitterConfig(ValidatedConfigBase):
|
||||
enable_kaomoji_protection: bool = Field(default=False, description="启用颜文字保护")
|
||||
|
||||
|
||||
class LogConfig(ValidatedConfigBase):
|
||||
"""日志配置类"""
|
||||
|
||||
date_style: str = Field(default="m-d H:i:s", description="日期格式")
|
||||
log_level_style: str = Field(default="lite", description="日志级别样式")
|
||||
color_text: str = Field(default="full", description="日志文本颜色")
|
||||
log_level: str = Field(default="INFO", description="全局日志级别(向下兼容,优先级低于分别设置)")
|
||||
file_retention_days: int = Field(default=7, description="文件日志保留天数,0=禁用文件日志,-1=永不删除")
|
||||
console_log_level: str = Field(default="INFO", description="控制台日志级别")
|
||||
file_log_level: str = Field(default="DEBUG", description="文件日志级别")
|
||||
suppress_libraries: list[str] = Field(default_factory=list, description="完全屏蔽日志的第三方库列表")
|
||||
library_log_levels: dict[str, str] = Field(default_factory=dict, description="设置特定库的日志级别")
|
||||
|
||||
|
||||
class DebugConfig(ValidatedConfigBase):
|
||||
"""调试配置类"""
|
||||
|
||||
@@ -703,6 +723,7 @@ class WebSearchConfig(ValidatedConfigBase):
|
||||
enable_url_tool: bool = Field(default=True, description="启用URL工具")
|
||||
tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API密钥列表,支持轮询机制")
|
||||
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表,支持轮询机制")
|
||||
metaso_api_keys: list[str] = Field(default_factory=lambda: [], description="Metaso API密钥列表,支持轮询机制")
|
||||
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表")
|
||||
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表")
|
||||
serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表")
|
||||
@@ -988,6 +1009,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
description="开启后KFC将接管所有私聊消息;关闭后私聊消息将由AFC处理"
|
||||
)
|
||||
|
||||
# --- 工作模式 ---
|
||||
mode: Literal["unified", "split"] = Field(
|
||||
default="split",
|
||||
description='工作模式: "unified"(单次调用) 或 "split"(planner+replyer两次调用)',
|
||||
)
|
||||
|
||||
# --- 核心行为配置 ---
|
||||
max_wait_seconds_default: int = Field(
|
||||
default=300, ge=30, le=3600,
|
||||
@@ -998,6 +1025,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
description="是否在等待期间启用心理活动更新"
|
||||
)
|
||||
|
||||
# --- 自定义决策提示词 ---
|
||||
custom_decision_prompt: str = Field(
|
||||
default="",
|
||||
description="自定义KFC决策行为指导提示词(unified影响整体,split仅影响planner)",
|
||||
)
|
||||
|
||||
waiting: KokoroFlowChatterWaitingConfig = Field(
|
||||
default_factory=KokoroFlowChatterWaitingConfig,
|
||||
description="等待策略配置(默认等待时间、倍率等)",
|
||||
|
||||
Reference in New Issue
Block a user