This commit is contained in:
SengokuCola
2025-03-27 00:23:13 +08:00
parent 5886b1c849
commit 8da4729c17
5 changed files with 49 additions and 24 deletions

View File

@@ -1,8 +1,7 @@
import tomli
import sys
import re
from pathlib import Path
from typing import Dict, Any, List, Set, Tuple
from typing import Dict, Any, List, Tuple
def load_toml_file(file_path: str) -> Dict[str, Any]:
"""加载TOML文件"""
@@ -184,10 +183,15 @@ def check_model_configurations(config: Dict[str, Any], env_vars: Dict[str, str])
provider = model_config["provider"].upper()
# 检查拼写错误
for known_provider, correct_provider in reverse_mapping.items():
for known_provider, _correct_provider in reverse_mapping.items():
# 使用模糊匹配检测拼写错误
if provider != known_provider and _similar_strings(provider, known_provider) and provider not in reverse_mapping:
errors.append(f"[model.{model_name}]的provider '{model_config['provider']}' 可能拼写错误,应为 '{known_provider}'")
if (provider != known_provider and
_similar_strings(provider, known_provider) and
provider not in reverse_mapping):
errors.append(
f"[model.{model_name}]的provider '{model_config['provider']}' "
f"可能拼写错误,应为 '{known_provider}'"
)
break
return errors
@@ -223,7 +227,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
# 检查配置文件中使用的所有提供商
used_providers = set()
for model_category, model_config in config["model"].items():
for _model_category, model_config in config["model"].items():
if "provider" in model_config:
provider = model_config["provider"]
used_providers.add(provider)
@@ -247,7 +251,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
# 特别检查常见的拼写错误
for provider in used_providers:
if provider.upper() == "SILICONFOLW":
errors.append(f"提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")
errors.append("提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")
return errors
@@ -272,7 +276,7 @@ def check_groups_configuration(config: Dict[str, Any]) -> List[str]:
"main": "groups.talk_allowed中存在默认示例值'123',请修改为真实的群号",
"details": [
f" 当前值: {groups['talk_allowed']}",
f" '123'为示例值,需要替换为真实群号"
" '123'为示例值,需要替换为真实群号"
]
})
@@ -371,7 +375,8 @@ def check_memory_config(config: Dict[str, Any]) -> List[str]:
if "memory_compress_rate" in memory and (memory["memory_compress_rate"] <= 0 or memory["memory_compress_rate"] > 1):
errors.append(f"memory.memory_compress_rate值无效: {memory['memory_compress_rate']}, 应在0-1之间")
if "memory_forget_percentage" in memory and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1):
if ("memory_forget_percentage" in memory
and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1)):
errors.append(f"memory.memory_forget_percentage值无效: {memory['memory_forget_percentage']}, 应在0-1之间")
return errors
@@ -393,7 +398,10 @@ def check_personality_config(config: Dict[str, Any]) -> List[str]:
else:
# 检查数组长度
if len(personality["prompt_personality"]) < 1:
errors.append(f"personality.prompt_personality数组长度不足当前长度: {len(personality['prompt_personality'])}, 需要至少1项")
errors.append(
f"personality.prompt_personality至少需要1项"
f"当前长度: {len(personality['prompt_personality'])}"
)
else:
# 模板默认值
template_values = [
@@ -452,10 +460,13 @@ def check_bot_config(config: Dict[str, Any]) -> List[str]:
def format_results(all_errors):
"""格式化检查结果"""
sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors
sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors # noqa: E501, F821
bot_errors, bot_infos = bot_results
if not any([sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
if not any([
sections_errors, prob_sum_errors,
prob_range_errors, model_errors, api_errors, groups_errors,
kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
result = "✅ 配置文件检查通过,未发现问题。"
# 添加机器人信息
@@ -574,7 +585,10 @@ def main():
bot_results = check_bot_config(config)
# 格式化并打印结果
all_errors = (sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
all_errors = (
sections_errors, prob_sum_errors,
prob_range_errors, model_errors, api_errors, groups_errors,
kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
result = format_results(all_errors)
print("📋 机器人配置检查结果:")
print(result)
@@ -586,7 +600,9 @@ def main():
bot_errors, _ = bot_results
# 计算普通错误列表的长度
for errors in [sections_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
for errors in [
sections_errors, model_errors, api_errors,
groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
total_errors += len(errors)
# 计算元组列表的长度(概率相关错误)