重构配置文件,新增API提供商和模型信息类,优化配置加载逻辑

This commit is contained in:
墨梓柒
2025-07-23 00:30:05 +08:00
parent c4e76b45dc
commit d27d175f54
3 changed files with 439 additions and 3 deletions

View File

@@ -5,7 +5,6 @@ from packaging.version import Version
NEWEST_VER = "0.1.0" # 当前支持的最新版本
@dataclass
class APIProvider:
name: str = "" # API提供商名称
@@ -62,6 +61,7 @@ class ModelUsageArgConfig:
) # 任务使用的模型列表
@dataclass
class ModuleConfig:
INNER_VERSION: Version | None = None # 配置文件版本
@@ -73,4 +73,4 @@ class ModuleConfig:
models: Dict[str, ModelInfo] = field(default_factory=lambda: {}) # 模型列表
task_model_arg_map: Dict[str, ModelUsageArgConfig] = field(
default_factory=lambda: {}
)
)

View File

@@ -7,8 +7,13 @@ from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from packaging import version
from packaging.specifiers import SpecifierSet
from packaging.version import Version, InvalidVersion
from typing import Any, Dict, List
from src.common.logger import get_logger
from src.common.message import api
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
@@ -35,6 +40,17 @@ from src.config.official_configs import (
CustomPromptConfig,
)
from .api_ada_configs import (
ModelUsageArgConfigItem,
ModelUsageArgConfig,
APIProvider,
ModelInfo,
NEWEST_VER,
ModuleConfig,
)
install(extra_lines=3)
@@ -51,6 +67,256 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
MMC_VERSION = "0.9.0-snapshot.2"
def _get_config_version(toml: Dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
"""
if "inner" in toml and "version" in toml["inner"]:
config_version: str = toml["inner"]["version"]
else:
config_version = "0.0.0" # 默认版本
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
f"请检查配置文件,当前 version 键: {config_version}\n"
f"错误信息: {e}"
)
raise InvalidVersion(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
) from e
return ver
def _request_conf(parent: Dict, config: ModuleConfig):
request_conf_config = parent.get("request_conf")
config.req_conf.max_retry = request_conf_config.get(
"max_retry", config.req_conf.max_retry
)
config.req_conf.timeout = request_conf_config.get(
"timeout", config.req_conf.timeout
)
config.req_conf.retry_interval = request_conf_config.get(
"retry_interval", config.req_conf.retry_interval
)
config.req_conf.default_temperature = request_conf_config.get(
"default_temperature", config.req_conf.default_temperature
)
config.req_conf.default_max_tokens = request_conf_config.get(
"default_max_tokens", config.req_conf.default_max_tokens
)
def _api_providers(parent: Dict, config: ModuleConfig):
api_providers_config = parent.get("api_providers")
for provider in api_providers_config:
name = provider.get("name", None)
base_url = provider.get("base_url", None)
api_key = provider.get("api_key", None)
client_type = provider.get("client_type", "openai")
if name in config.api_providers: # 查重
logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
if name and base_url:
config.api_providers[name] = APIProvider(
name=name,
base_url=base_url,
api_key=api_key,
client_type=client_type,
)
else:
logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
def _models(parent: Dict, config: ModuleConfig):
models_config = parent.get("models")
for model in models_config:
model_identifier = model.get("model_identifier", None)
name = model.get("name", model_identifier)
api_provider = model.get("api_provider", None)
price_in = model.get("price_in", 0.0)
price_out = model.get("price_out", 0.0)
force_stream_mode = model.get("force_stream_mode", False)
if name in config.models: # 查重
logger.error(f"重复的模型名称: {name},请检查配置文件。")
raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
if model_identifier and api_provider:
# 检查API提供商是否存在
if api_provider not in config.api_providers:
logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
raise ValueError(
f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
)
config.models[name] = ModelInfo(
name=name,
model_identifier=model_identifier,
api_provider=api_provider,
price_in=price_in,
price_out=price_out,
force_stream_mode=force_stream_mode,
)
else:
logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
def _task_model_usage(parent: Dict, config: ModuleConfig):
model_usage_configs = parent.get("task_model_usage")
config.task_model_arg_map = {}
for task_name, item in model_usage_configs.items():
if task_name in config.task_model_arg_map:
logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
usage = []
if isinstance(item, Dict):
if "model" in item:
usage.append(
ModelUsageArgConfigItem(
name=item["model"],
temperature=item.get("temperature", None),
max_tokens=item.get("max_tokens", None),
max_retry=item.get("max_retry", None),
)
)
else:
logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, List):
for model in item:
if isinstance(model, Dict):
usage.append(
ModelUsageArgConfigItem(
name=model["model"],
temperature=model.get("temperature", None),
max_tokens=model.get("max_tokens", None),
max_retry=model.get("max_retry", None),
)
)
elif isinstance(model, str):
usage.append(
ModelUsageArgConfigItem(
name=model,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
else:
logger.error(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, str):
usage.append(
ModelUsageArgConfigItem(
name=item,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
config.task_model_arg_map[task_name] = ModelUsageArgConfig(
name=task_name,
usage=usage,
)
def api_ada_load_config(config_path: str) -> ModuleConfig:
"""从TOML配置文件加载配置"""
config = ModuleConfig()
include_configs: Dict[str, Dict[str, Any]] = {
"request_conf": {
"func": _request_conf,
"support": ">=0.0.0",
"necessary": False,
},
"api_providers": {"func": _api_providers, "support": ">=0.0.0"},
"models": {"func": _models, "support": ">=0.0.0"},
"task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
}
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomlkit.load(f)
except tomlkit.TOMLDecodeError as e:
logger.critical(
f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
)
exit(1)
# 获取配置文件版本
config.INNER_VERSION = _get_config_version(toml_dict)
# 检查版本
if config.INNER_VERSION > Version(NEWEST_VER):
logger.warning(
f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
)
# 解析配置文件
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifier_set: SpecifierSet = SpecifierSet(
include_configs[key]["support"]
)
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifier_set:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
# 调用闭包函数处理配置
(include_configs[key]["func"])(toml_dict, config)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
raise InvalidVersion(
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif (
"necessary" in include_configs[key]
and include_configs[key].get("necessary") is False
):
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.success(f"成功加载配置文件: {config_path}")
return config
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
@@ -300,6 +566,174 @@ def update_config():
quit()
def update_model_config():
"""更新model_config.toml配置文件"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
# 定义文件路径
template_path = os.path.join(TEMPLATE_DIR, "model_config_template.toml")
old_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
new_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
compare_path = os.path.join(compare_dir, "model_config_template.toml")
# 创建compare目录如果不存在
os.makedirs(compare_dir, exist_ok=True)
# 处理compare下的模板文件
def get_version_from_toml(toml_path):
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
template_version = get_version_from_toml(template_path)
compare_version = get_version_from_toml(compare_path)
def version_tuple(v):
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f)
else:
compare_config = None
# 读取当前模板
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config is not None:
# 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config)
if logs:
logger.info("检测到model_config模板默认值变动如下")
for log in logs:
logger.info(log)
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
for path, old_default, new_default in changes:
old_value = get_value_by_path(old_config, path)
if old_value == old_default:
set_value_by_path(old_config, path, new_default)
logger.info(
f"已自动将model_config配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
else:
logger.info("未检测到model_config模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
else:
old_config = None
# 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path)
logger.info(f"已将model_config模板文件复制到: {compare_path}")
else:
if version_tuple(template_version) > version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"model_config模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的model_config模板版本不低于当前模板无需替换: {compare_path}")
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info("model_config.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新model_config配置文件请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
return
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None:
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# new_config 已经读取
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到model_config配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(
f"\n----------------------------------------\n检测到model_config版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else:
logger.info("已有model_config配置文件未检测到版本号可能是旧版本。将进行更新")
# 创建old目录如果不存在
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"model_config_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧model_config配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新model_config配置文件: {new_config_path}")
# 输出新增和删减项及注释
if old_config:
logger.info("model_config配置项变动如下\n----------------------------------------")
logs = compare_dicts(new_config, old_config)
if logs:
for log in logs:
logger.info(log)
else:
logger.info("无新增或删减项")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并model_config新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("model_config配置文件更新完成建议检查新配置文件中的内容以免丢失重要信息")
@dataclass
class Config(ConfigBase):
"""总配置类"""
@@ -359,7 +793,9 @@ def get_config_dir() -> str:
# 获取配置文件路径
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!")

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from src.config.config_base import ConfigBase
from packaging.version import Version
"""
须知:
@@ -605,7 +606,6 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""