大修LLMReq

This commit is contained in:
UnCLAS-Prommer
2025-07-30 09:45:13 +08:00
parent 94db64c118
commit 3c40ceda4c
15 changed files with 2290 additions and 1995 deletions

View File

@@ -1,16 +1,14 @@
import os
import tomlkit
import shutil
import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
from packaging import version
from packaging.specifiers import SpecifierSet
from packaging.version import Version, InvalidVersion
from typing import Any, Dict, List
from typing import List, Optional
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
@@ -29,7 +27,6 @@ from src.config.official_configs import (
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
ModelConfig,
MessageReceiveConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
@@ -41,16 +38,12 @@ from src.config.official_configs import (
)
from .api_ada_configs import (
ModelUsageArgConfigItem,
ModelUsageArgConfig,
APIProvider,
ModelTaskConfig,
ModelInfo,
NEWEST_VER,
ModuleConfig,
APIProvider,
)
install(extra_lines=3)
@@ -64,275 +57,270 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.10.0-snapshot1"
MMC_VERSION = "0.10.0-snapshot.2"
# def _get_config_version(toml: Dict) -> Version:
# """提取配置文件的 SpecifierSet 版本数据
# Args:
# toml[dict]: 输入的配置文件字典
# Returns:
# Version
# """
# if "inner" in toml and "version" in toml["inner"]:
# config_version: str = toml["inner"]["version"]
# else:
# raise InvalidVersion("配置文件缺少版本信息,请检查配置文件。")
# try:
# return version.parse(config_version)
# except InvalidVersion as e:
# logger.error(
# "配置文件中 inner段 的 version 键是错误的版本描述\n"
# f"请检查配置文件,当前 version 键: {config_version}\n"
# f"错误信息: {e}"
# )
# raise e
def _get_config_version(toml: Dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
"""
if "inner" in toml and "version" in toml["inner"]:
config_version: str = toml["inner"]["version"]
else:
config_version = "0.0.0" # 默认版本
try:
ver = version.parse(config_version)
except InvalidVersion as e:
logger.error(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
f"请检查配置文件,当前 version 键: {config_version}\n"
f"错误信息: {e}"
)
raise InvalidVersion(
"配置文件中 inner段 的 version 键是错误的版本描述\n"
) from e
return ver
# def _request_conf(parent: Dict, config: ModuleConfig):
# request_conf_config = parent.get("request_conf")
# config.req_conf.max_retry = request_conf_config.get(
# "max_retry", config.req_conf.max_retry
# )
# config.req_conf.timeout = request_conf_config.get(
# "timeout", config.req_conf.timeout
# )
# config.req_conf.retry_interval = request_conf_config.get(
# "retry_interval", config.req_conf.retry_interval
# )
# config.req_conf.default_temperature = request_conf_config.get(
# "default_temperature", config.req_conf.default_temperature
# )
# config.req_conf.default_max_tokens = request_conf_config.get(
# "default_max_tokens", config.req_conf.default_max_tokens
# )
def _request_conf(parent: Dict, config: ModuleConfig):
request_conf_config = parent.get("request_conf")
config.req_conf.max_retry = request_conf_config.get(
"max_retry", config.req_conf.max_retry
)
config.req_conf.timeout = request_conf_config.get(
"timeout", config.req_conf.timeout
)
config.req_conf.retry_interval = request_conf_config.get(
"retry_interval", config.req_conf.retry_interval
)
config.req_conf.default_temperature = request_conf_config.get(
"default_temperature", config.req_conf.default_temperature
)
config.req_conf.default_max_tokens = request_conf_config.get(
"default_max_tokens", config.req_conf.default_max_tokens
)
# def _api_providers(parent: Dict, config: ModuleConfig):
# api_providers_config = parent.get("api_providers")
# for provider in api_providers_config:
# name = provider.get("name", None)
# base_url = provider.get("base_url", None)
# api_key = provider.get("api_key", None)
# api_keys = provider.get("api_keys", []) # 新增支持多个API Key
# client_type = provider.get("client_type", "openai")
# if name in config.api_providers: # 查重
# logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
# raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
# if name and base_url:
# # 处理API Key配置支持单个api_key或多个api_keys
# if api_keys:
# # 使用新格式api_keys列表
# logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
# elif api_key:
# # 向后兼容使用单个api_key
# api_keys = [api_key]
# logger.debug(f"API提供商 '{name}' 使用单个API Key向后兼容模式")
# else:
# logger.warning(f"API提供商 '{name}' 没有配置API Key某些功能可能不可用")
# config.api_providers[name] = APIProvider(
# name=name,
# base_url=base_url,
# api_key=api_key, # 保留向后兼容
# api_keys=api_keys, # 新格式
# client_type=client_type,
# )
# else:
# logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
# raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
def _api_providers(parent: Dict, config: ModuleConfig):
api_providers_config = parent.get("api_providers")
for provider in api_providers_config:
name = provider.get("name", None)
base_url = provider.get("base_url", None)
api_key = provider.get("api_key", None)
api_keys = provider.get("api_keys", []) # 新增支持多个API Key
client_type = provider.get("client_type", "openai")
# def _models(parent: Dict, config: ModuleConfig):
# models_config = parent.get("models")
# for model in models_config:
# model_identifier = model.get("model_identifier", None)
# name = model.get("name", model_identifier)
# api_provider = model.get("api_provider", None)
# price_in = model.get("price_in", 0.0)
# price_out = model.get("price_out", 0.0)
# force_stream_mode = model.get("force_stream_mode", False)
# task_type = model.get("task_type", "")
# capabilities = model.get("capabilities", [])
if name in config.api_providers: # 查重
logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
# if name in config.models: # 查重
# logger.error(f"重复的模型名称: {name},请检查配置文件。")
# raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
if name and base_url:
# 处理API Key配置支持单个api_key或多个api_keys
if api_keys:
# 使用新格式api_keys列表
logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
elif api_key:
# 向后兼容使用单个api_key
api_keys = [api_key]
logger.debug(f"API提供商 '{name}' 使用单个API Key向后兼容模式")
else:
logger.warning(f"API提供商 '{name}' 没有配置API Key某些功能可能不可用")
config.api_providers[name] = APIProvider(
name=name,
base_url=base_url,
api_key=api_key, # 保留向后兼容
api_keys=api_keys, # 新格式
client_type=client_type,
)
else:
logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
# if model_identifier and api_provider:
# # 检查API提供商是否存在
# if api_provider not in config.api_providers:
# logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
# raise ValueError(
# f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
# )
# config.models[name] = ModelInfo(
# name=name,
# model_identifier=model_identifier,
# api_provider=api_provider,
# price_in=price_in,
# price_out=price_out,
# force_stream_mode=force_stream_mode,
# task_type=task_type,
# capabilities=capabilities,
# )
# else:
# logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
# raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
def _models(parent: Dict, config: ModuleConfig):
models_config = parent.get("models")
for model in models_config:
model_identifier = model.get("model_identifier", None)
name = model.get("name", model_identifier)
api_provider = model.get("api_provider", None)
price_in = model.get("price_in", 0.0)
price_out = model.get("price_out", 0.0)
force_stream_mode = model.get("force_stream_mode", False)
task_type = model.get("task_type", "")
capabilities = model.get("capabilities", [])
# def _task_model_usage(parent: Dict, config: ModuleConfig):
# model_usage_configs = parent.get("task_model_usage")
# config.task_model_arg_map = {}
# for task_name, item in model_usage_configs.items():
# if task_name in config.task_model_arg_map:
# logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
# raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
if name in config.models: # 查重
logger.error(f"重复的模型名称: {name},请检查配置文件。")
raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
# usage = []
# if isinstance(item, Dict):
# if "model" in item:
# usage.append(
# ModelUsageArgConfigItem(
# name=item["model"],
# temperature=item.get("temperature", None),
# max_tokens=item.get("max_tokens", None),
# max_retry=item.get("max_retry", None),
# )
# )
# else:
# logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
# raise ValueError(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# elif isinstance(item, List):
# for model in item:
# if isinstance(model, Dict):
# usage.append(
# ModelUsageArgConfigItem(
# name=model["model"],
# temperature=model.get("temperature", None),
# max_tokens=model.get("max_tokens", None),
# max_retry=model.get("max_retry", None),
# )
# )
# elif isinstance(model, str):
# usage.append(
# ModelUsageArgConfigItem(
# name=model,
# temperature=None,
# max_tokens=None,
# max_retry=None,
# )
# )
# else:
# logger.error(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# raise ValueError(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# elif isinstance(item, str):
# usage.append(
# ModelUsageArgConfigItem(
# name=item,
# temperature=None,
# max_tokens=None,
# max_retry=None,
# )
# )
if model_identifier and api_provider:
# 检查API提供商是否存在
if api_provider not in config.api_providers:
logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
raise ValueError(
f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
)
config.models[name] = ModelInfo(
name=name,
model_identifier=model_identifier,
api_provider=api_provider,
price_in=price_in,
price_out=price_out,
force_stream_mode=force_stream_mode,
task_type=task_type,
capabilities=capabilities,
)
else:
logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
# config.task_model_arg_map[task_name] = ModelUsageArgConfig(
# name=task_name,
# usage=usage,
# )
def _task_model_usage(parent: Dict, config: ModuleConfig):
model_usage_configs = parent.get("task_model_usage")
config.task_model_arg_map = {}
for task_name, item in model_usage_configs.items():
if task_name in config.task_model_arg_map:
logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
# def api_ada_load_config(config_path: str) -> ModuleConfig:
# """从TOML配置文件加载配置"""
# config = ModuleConfig()
usage = []
if isinstance(item, Dict):
if "model" in item:
usage.append(
ModelUsageArgConfigItem(
name=item["model"],
temperature=item.get("temperature", None),
max_tokens=item.get("max_tokens", None),
max_retry=item.get("max_retry", None),
)
)
else:
logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, List):
for model in item:
if isinstance(model, Dict):
usage.append(
ModelUsageArgConfigItem(
name=model["model"],
temperature=model.get("temperature", None),
max_tokens=model.get("max_tokens", None),
max_retry=model.get("max_retry", None),
)
)
elif isinstance(model, str):
usage.append(
ModelUsageArgConfigItem(
name=model,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
else:
logger.error(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
raise ValueError(
f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
)
elif isinstance(item, str):
usage.append(
ModelUsageArgConfigItem(
name=item,
temperature=None,
max_tokens=None,
max_retry=None,
)
)
# include_configs: Dict[str, Dict[str, Any]] = {
# "request_conf": {
# "func": _request_conf,
# "support": ">=0.0.0",
# "necessary": False,
# },
# "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
# "models": {"func": _models, "support": ">=0.0.0"},
# "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
# }
config.task_model_arg_map[task_name] = ModelUsageArgConfig(
name=task_name,
usage=usage,
)
# if os.path.exists(config_path):
# with open(config_path, "rb") as f:
# try:
# toml_dict = tomlkit.load(f)
# except tomlkit.TOMLDecodeError as e:
# logger.critical(
# f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
# )
# exit(1)
# # 获取配置文件版本
# config.INNER_VERSION = _get_config_version(toml_dict)
def api_ada_load_config(config_path: str) -> ModuleConfig:
"""从TOML配置文件加载配置"""
config = ModuleConfig()
# # 检查版本
# if config.INNER_VERSION > Version(NEWEST_VER):
# logger.warning(
# f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
# )
include_configs: Dict[str, Dict[str, Any]] = {
"request_conf": {
"func": _request_conf,
"support": ">=0.0.0",
"necessary": False,
},
"api_providers": {"func": _api_providers, "support": ">=0.0.0"},
"models": {"func": _models, "support": ">=0.0.0"},
"task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
}
# # 解析配置文件
# # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
# for key in include_configs:
# if key in toml_dict:
# group_specifier_set: SpecifierSet = SpecifierSet(
# include_configs[key]["support"]
# )
if os.path.exists(config_path):
with open(config_path, "rb") as f:
try:
toml_dict = tomlkit.load(f)
except tomlkit.TOMLDecodeError as e:
logger.critical(
f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
)
exit(1)
# # 检查配置文件版本是否在支持范围内
# if config.INNER_VERSION in group_specifier_set:
# # 如果版本在支持范围内,检查是否存在通知
# if "notice" in include_configs[key]:
# logger.warning(include_configs[key]["notice"])
# # 调用闭包函数处理配置
# (include_configs[key]["func"])(toml_dict, config)
# else:
# # 如果版本不在支持范围内,崩溃并提示用户
# logger.error(
# f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
# )
# raise InvalidVersion(
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
# )
# 获取配置文件版本
config.INNER_VERSION = _get_config_version(toml_dict)
# # 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
# elif (
# "necessary" in include_configs[key]
# and include_configs[key].get("necessary") is False
# ):
# # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
# if key == "keywords_reaction":
# pass
# else:
# # 如果用户根本没有需要的配置项,提示缺少配置
# logger.error(f"配置文件中缺少必需的字段: '{key}'")
# raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
# 检查版本
if config.INNER_VERSION > Version(NEWEST_VER):
logger.warning(
f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
)
# logger.info(f"成功加载配置文件: {config_path}")
# 解析配置文件
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
for key in include_configs:
if key in toml_dict:
group_specifier_set: SpecifierSet = SpecifierSet(
include_configs[key]["support"]
)
# return config
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifier_set:
# 如果版本在支持范围内,检查是否存在通知
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
# 调用闭包函数处理配置
(include_configs[key]["func"])(toml_dict, config)
else:
# 如果版本不在支持范围内,崩溃并提示用户
logger.error(
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
raise InvalidVersion(
f"当前程序仅支持以下版本范围: {group_specifier_set}"
)
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif (
"necessary" in include_configs[key]
and include_configs[key].get("necessary") is False
):
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.info(f"成功加载配置文件: {config_path}")
return config
def get_key_comment(toml_table, key):
# 获取key的注释如果有
@@ -361,7 +349,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项
@@ -370,7 +358,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs
@@ -405,17 +393,13 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
else:
# 只要值发生变化就记录
if new[key] != old[key]:
logs.append(
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
elif new[key] != old[key]:
logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes
def _get_version_from_toml(toml_path):
def _get_version_from_toml(toml_path) -> Optional[str]:
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
@@ -459,14 +443,13 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
target[key] = value
def _update_config_generic(config_name: str, template_name: str, should_quit_on_new: bool = True):
def _update_config_generic(config_name: str, template_name: str):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名(不含扩展名),如 'bot_config''model_config'
template_name: 模板文件名(不含扩展名),如 'bot_config_template''model_config_template'
should_quit_on_new: 创建新配置文件后是否退出程序
"""
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
@@ -484,19 +467,30 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
template_version = _get_version_from_toml(template_path)
compare_version = _get_version_from_toml(compare_path)
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 新创建配置文件,退出
sys.exit(0)
compare_config = None
new_config = None
old_config = None
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f)
else:
compare_config = None
# 读取当前模板
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config is not None:
if compare_config:
# 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
@@ -515,32 +509,16 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
)
else:
logger.info(f"未检测到{config_name}模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
else:
old_config = None
# 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path)
logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
elif _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else:
if _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,根据参数决定是否退出
if should_quit_on_new:
quit()
else:
return
logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None:
@@ -578,8 +556,7 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
# 输出新增和删减项及注释
if old_config:
logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
logs = compare_dicts(new_config, old_config)
if logs:
if logs := compare_dicts(new_config, old_config):
for log in logs:
logger.info(log)
else:
@@ -597,12 +574,12 @@ def _update_config_generic(config_name: str, template_name: str, should_quit_on_
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template", should_quit_on_new=True)
_update_config_generic("bot_config", "bot_config_template")
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template", should_quit_on_new=False)
_update_config_generic("model_config", "model_config_template")
@dataclass
@@ -627,7 +604,6 @@ class Config(ConfigBase):
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
experimental: ExperimentalConfig
model: ModelConfig
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
@@ -635,11 +611,48 @@ class Config(ConfigBase):
custom_prompt: CustomPromptConfig
voice: VoiceConfig
@dataclass
class APIAdapterConfig(ConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo]
"""模型列表"""
model_task_config: ModelTaskConfig
"""模型任务配置"""
api_providers: List[APIProvider] = field(default_factory=list)
"""API提供商列表"""
def __post_init__(self):
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
raise ValueError("模型名称不能为空")
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
raise ValueError("API提供商名称不能为空")
if provider_name not in self.api_providers_dict:
raise KeyError(f"API提供商 '{provider_name}' 不存在")
return self.api_providers_dict[provider_name]
def load_config(config_path: str) -> Config:
"""
加载配置文件
:param config_path: 配置文件路径
:return: Config对象
Args:
config_path: 配置文件路径
Returns:
Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
@@ -653,12 +666,24 @@ def load_config(config_path: str) -> Config:
raise e
def get_config_dir() -> str:
def api_ada_load_config(config_path: str) -> APIAdapterConfig:
"""
获取配置目录
:return: 配置目录路径
加载API适配器配置文件
Args:
config_path: 配置文件路径
Returns:
APIAdapterConfig对象
"""
return CONFIG_DIR
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建APIAdapterConfig对象
try:
return APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.critical("API适配器配置文件解析失败")
raise e
# 获取配置文件路径
@@ -669,4 +694,4 @@ update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!")
logger.info("非常的新鲜,非常的美味!")