diff --git a/src/config/config.py b/src/config/config.py index 659c49dac..d14b89583 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -398,37 +398,74 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): return logs, changes -def update_config(): +def _get_version_from_toml(toml_path): + """从TOML文件中获取版本号""" + if not os.path.exists(toml_path): + return None + with open(toml_path, "r", encoding="utf-8") as f: + doc = tomlkit.load(f) + if "inner" in doc and "version" in doc["inner"]: # type: ignore + return doc["inner"]["version"] # type: ignore + return None + + +def _version_tuple(v): + """将版本字符串转换为元组以便比较""" + if v is None: + return (0,) + return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) + + +def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ + for key, value in source.items(): + # 跳过version字段的更新 + if key == "version": + continue + if key in target: + target_value = target[key] + if isinstance(value, dict) and isinstance(target_value, (dict, Table)): + _update_dict(target_value, value) + else: + try: + # 对数组类型进行特殊处理 + if isinstance(value, list): + # 如果是空数组,确保它保持为空数组 + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() + else: + # 其他类型使用item方法创建新值 + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + # 如果转换失败,直接赋值 + target[key] = value + + +def _update_config_generic(config_name: str, template_name: str, should_quit_on_new: bool = True): + """ + 通用的配置文件更新函数 + + Args: + config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config' + template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template' + should_quit_on_new: 创建新配置文件后是否退出程序 + """ # 获取根目录路径 old_config_dir = os.path.join(CONFIG_DIR, "old") compare_dir = os.path.join(TEMPLATE_DIR, "compare") # 定义文件路径 - template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml") - old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") - new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") - compare_path = os.path.join(compare_dir, "bot_config_template.toml") + template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml") + old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") + new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") + compare_path = os.path.join(compare_dir, f"{template_name}.toml") # 创建compare目录(如果不存在) os.makedirs(compare_dir, exist_ok=True) - # 处理compare下的模板文件 - def get_version_from_toml(toml_path): - if not os.path.exists(toml_path): - return None - with open(toml_path, "r", encoding="utf-8") as f: - doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: # type: ignore - return doc["inner"]["version"] # type: ignore - return None - - template_version = get_version_from_toml(template_path) - compare_version = get_version_from_toml(compare_path) - - def version_tuple(v): - if v is None: - return (0,) - return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) + template_version = _get_version_from_toml(template_path) + compare_version = _get_version_from_toml(compare_path) # 先读取 compare 下的模板(如果有),用于默认值变动检测 if os.path.exists(compare_path): @@ -448,7 +485,7 @@ def update_config(): old_config = tomlkit.load(f) logs, changes = compare_default_values(new_config, compare_config) if logs: - logger.info("检测到模板默认值变动如下:") + logger.info(f"检测到{config_name}模板默认值变动如下:") for log in logs: logger.info(log) # 检查旧配置是否等于旧默认值,如果是则更新为新默认值 @@ -457,10 +494,10 @@ def update_config(): if old_value == old_default: set_value_by_path(old_config, path, new_default) logger.info( - f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" + f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" ) else: - logger.info("未检测到模板默认值变动") + logger.info(f"未检测到{config_name}模板默认值变动") # 保存旧配置的变更(后续合并逻辑会用到 old_config) else: old_config = None @@ -468,22 +505,25 @@ def update_config(): # 检查 compare 下没有模板,或新模板版本更高,则复制 if not os.path.exists(compare_path): shutil.copy2(template_path, compare_path) - logger.info(f"已将模板文件复制到: {compare_path}") + logger.info(f"已将{config_name}模板文件复制到: {compare_path}") else: - if version_tuple(template_version) > version_tuple(compare_version): + if _version_tuple(template_version) > _version_tuple(compare_version): shutil.copy2(template_path, compare_path) - logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}") + logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}") else: - logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}") + logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}") # 检查配置文件是否存在 if not os.path.exists(old_config_path): - logger.info("配置文件不存在,从模板创建新配置") + logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置") os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 shutil.copy2(template_path, old_config_path) # 复制模板文件 - logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") - # 如果是新创建的配置文件,直接返回 - quit() + logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}") + # 如果是新创建的配置文件,根据参数决定是否退出 + if should_quit_on_new: + quit() + else: + return # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) if old_config is None: @@ -491,38 +531,36 @@ def update_config(): old_config = tomlkit.load(f) # new_config 已经读取 - # 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用 - # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: old_version = old_config["inner"].get("version") # type: ignore new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: - logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") + logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新") return else: logger.info( - f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" + f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" ) else: - logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") + logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新") # 创建old目录(如果不存在) os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml") + old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml") # 移动旧配置文件到old目录 shutil.move(old_config_path, old_backup_path) - logger.info(f"已备份旧配置文件到: {old_backup_path}") + logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}") # 复制模板文件到配置目录 shutil.copy2(template_path, new_config_path) - logger.info(f"已创建新配置文件: {new_config_path}") + logger.info(f"已创建新{config_name}配置文件: {new_config_path}") # 输出新增和删减项及注释 if old_config: - logger.info("配置项变动如下:\n----------------------------------------") + logger.info(f"{config_name}配置项变动如下:\n----------------------------------------") logs = compare_dicts(new_config, old_config) if logs: for log in logs: @@ -530,208 +568,24 @@ def update_config(): else: logger.info("无新增或删减项") - def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): - """ - 将source字典的值更新到target字典中(如果target中存在相同的键) - """ - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - target_value = target[key] - if isinstance(value, dict) and isinstance(target_value, (dict, Table)): - update_dict(target_value, value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - target[key] = tomlkit.array(str(value)) if value else tomlkit.array() - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - # 将旧配置的值更新到新配置中 - logger.info("开始合并新旧配置...") - update_dict(new_config, old_config) + logger.info(f"开始合并{config_name}新旧配置...") + _update_dict(new_config, old_config) # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) - logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") - quit() + logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + + +def update_config(): + """更新bot_config.toml配置文件""" + _update_config_generic("bot_config", "bot_config_template", should_quit_on_new=True) def update_model_config(): """更新model_config.toml配置文件""" - # 获取根目录路径 - old_config_dir = os.path.join(CONFIG_DIR, "old") - compare_dir = os.path.join(TEMPLATE_DIR, "compare") - - # 定义文件路径 - template_path = os.path.join(TEMPLATE_DIR, "model_config_template.toml") - old_config_path = os.path.join(CONFIG_DIR, "model_config.toml") - new_config_path = os.path.join(CONFIG_DIR, "model_config.toml") - compare_path = os.path.join(compare_dir, "model_config_template.toml") - - # 创建compare目录(如果不存在) - os.makedirs(compare_dir, exist_ok=True) - - # 处理compare下的模板文件 - def get_version_from_toml(toml_path): - if not os.path.exists(toml_path): - return None - with open(toml_path, "r", encoding="utf-8") as f: - doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: # type: ignore - return doc["inner"]["version"] # type: ignore - return None - - template_version = get_version_from_toml(template_path) - compare_version = get_version_from_toml(compare_path) - - def version_tuple(v): - if v is None: - return (0,) - return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) - - # 先读取 compare 下的模板(如果有),用于默认值变动检测 - if os.path.exists(compare_path): - with open(compare_path, "r", encoding="utf-8") as f: - compare_config = tomlkit.load(f) - else: - compare_config = None - - # 读取当前模板 - with open(template_path, "r", encoding="utf-8") as f: - new_config = tomlkit.load(f) - - # 检查默认值变化并处理(只有 compare_config 存在时才做) - if compare_config is not None: - # 读取旧配置 - with open(old_config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - logs, changes = compare_default_values(new_config, compare_config) - if logs: - logger.info("检测到model_config模板默认值变动如下:") - for log in logs: - logger.info(log) - # 检查旧配置是否等于旧默认值,如果是则更新为新默认值 - for path, old_default, new_default in changes: - old_value = get_value_by_path(old_config, path) - if old_value == old_default: - set_value_by_path(old_config, path, new_default) - logger.info( - f"已自动将model_config配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" - ) - else: - logger.info("未检测到model_config模板默认值变动") - # 保存旧配置的变更(后续合并逻辑会用到 old_config) - else: - old_config = None - - # 检查 compare 下没有模板,或新模板版本更高,则复制 - if not os.path.exists(compare_path): - shutil.copy2(template_path, compare_path) - logger.info(f"已将model_config模板文件复制到: {compare_path}") - else: - if version_tuple(template_version) > version_tuple(compare_version): - shutil.copy2(template_path, compare_path) - logger.info(f"model_config模板版本较新,已替换compare下的模板: {compare_path}") - else: - logger.debug(f"compare下的model_config模板版本不低于当前模板,无需替换: {compare_path}") - - # 检查配置文件是否存在 - if not os.path.exists(old_config_path): - logger.info("model_config.toml配置文件不存在,从模板创建新配置") - os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 - shutil.copy2(template_path, old_config_path) # 复制模板文件 - logger.info(f"已创建新model_config配置文件,请填写后重新运行: {old_config_path}") - # 如果是新创建的配置文件,直接返回 - return - - # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) - if old_config is None: - with open(old_config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - # new_config 已经读取 - - # 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用 - - # 检查version是否相同 - if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") # type: ignore - new_version = new_config["inner"].get("version") # type: ignore - if old_version and new_version and old_version == new_version: - logger.info(f"检测到model_config配置文件版本号相同 (v{old_version}),跳过更新") - return - else: - logger.info( - f"\n----------------------------------------\n检测到model_config版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" - ) - else: - logger.info("已有model_config配置文件未检测到版本号,可能是旧版本。将进行更新") - - # 创建old目录(如果不存在) - os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = os.path.join(old_config_dir, f"model_config_{timestamp}.toml") - - # 移动旧配置文件到old目录 - shutil.move(old_config_path, old_backup_path) - logger.info(f"已备份旧model_config配置文件到: {old_backup_path}") - - # 复制模板文件到配置目录 - shutil.copy2(template_path, new_config_path) - logger.info(f"已创建新model_config配置文件: {new_config_path}") - - # 输出新增和删减项及注释 - if old_config: - logger.info("model_config配置项变动如下:\n----------------------------------------") - logs = compare_dicts(new_config, old_config) - if logs: - for log in logs: - logger.info(log) - else: - logger.info("无新增或删减项") - - def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): - """ - 将source字典的值更新到target字典中(如果target中存在相同的键) - """ - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - target_value = target[key] - if isinstance(value, dict) and isinstance(target_value, (dict, Table)): - update_dict(target_value, value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - target[key] = tomlkit.array(str(value)) if value else tomlkit.array() - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - - # 将旧配置的值更新到新配置中 - logger.info("开始合并model_config新旧配置...") - update_dict(new_config, old_config) - - # 保存更新后的配置(保留注释和格式) - with open(new_config_path, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(new_config)) - logger.info("model_config配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + _update_config_generic("model_config", "model_config_template", should_quit_on_new=False) @dataclass