fix: 修复了自定义API提供商无法识别的问题。增加新的env文件配置项,增加功能:可以自己在WebUI中添加提供商。增加检测文件是否存在功能

This commit is contained in:
DrSmoothl
2025-03-18 23:17:16 +08:00
parent 79e6aa358a
commit 4d1e5395d6

205
webui.py
View File

@@ -1,6 +1,5 @@
import gradio as gr import gradio as gr
import os import os
import sys
import toml import toml
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
import shutil import shutil
@@ -12,12 +11,24 @@ logger = get_module_logger("webui")
is_share = False is_share = False
debug = True debug = True
# 检查配置文件是否存在
if not os.path.exists("config/bot_config.toml"):
logger.error("配置文件 bot_config.toml 不存在,请检查配置文件路径")
raise FileNotFoundError("配置文件 bot_config.toml 不存在,请检查配置文件路径")
if not os.path.exists(".env.prod"):
logger.error("环境配置文件 .env.prod 不存在,请检查配置文件路径")
raise FileNotFoundError("环境配置文件 .env.prod 不存在,请检查配置文件路径")
config_data = toml.load("config/bot_config.toml") config_data = toml.load("config/bot_config.toml")
CONFIG_VERSION = config_data["inner"]["version"] CONFIG_VERSION = config_data["inner"]["version"]
PARSED_CONFIG_VERSION = version.parse(CONFIG_VERSION) PARSED_CONFIG_VERSION = version.parse(CONFIG_VERSION)
HAVE_ONLINE_STATUS_VERSION = version.parse("0.0.9") HAVE_ONLINE_STATUS_VERSION = version.parse("0.0.9")
#添加WebUI配置文件版本
WEBUI_VERSION = version.parse("0.0.7")
# ============================================== # ==============================================
# env环境配置文件读取部分 # env环境配置文件读取部分
def parse_env_config(config_file): def parse_env_config(config_file):
@@ -92,12 +103,50 @@ else:
logger.info("VOLCENGINE_KEY 不存在,已创建并使用默认值") logger.info("VOLCENGINE_KEY 不存在,已创建并使用默认值")
env_config_data["env_VOLCENGINE_KEY"] = "volc_key" env_config_data["env_VOLCENGINE_KEY"] = "volc_key"
save_to_env_file(env_config_data, env_config_file) save_to_env_file(env_config_data, env_config_file)
MODEL_PROVIDER_LIST = [
"VOLCENGINE", def parse_model_providers(env_vars):
"CHAT_ANY_WHERE", """
"SILICONFLOW", 从环境变量中解析模型提供商列表
"DEEP_SEEK" 参数:
] env_vars: 包含环境变量的字典
返回:
list: 模型提供商列表
"""
providers = []
for key in env_vars.keys():
if key.startswith("env_") and key.endswith("_BASE_URL"):
# 提取中间部分作为提供商名称
provider = key[4:-9] # 移除"env_"前缀和"_BASE_URL"后缀
providers.append(provider)
return providers
def add_new_provider(provider_name, current_providers):
"""
添加新的提供商到列表中
参数:
provider_name: 新的提供商名称
current_providers: 当前的提供商列表
返回:
tuple: (更新后的提供商列表, 更新后的下拉列表选项)
"""
if not provider_name or provider_name in current_providers:
return current_providers, gr.update(choices=current_providers)
# 添加新的提供商到环境变量中
env_config_data[f"env_{provider_name}_BASE_URL"] = ""
env_config_data[f"env_{provider_name}_KEY"] = ""
# 更新提供商列表
updated_providers = current_providers + [provider_name]
# 保存到环境文件
save_to_env_file(env_config_data)
return updated_providers, gr.update(choices=updated_providers)
# 从环境变量中解析并更新提供商列表
MODEL_PROVIDER_LIST = parse_model_providers(env_config_data)
# env读取保存结束 # env读取保存结束
# ============================================== # ==============================================
@@ -224,7 +273,7 @@ def format_list_to_str(lst):
# env保存函数 # env保存函数
def save_trigger(server_address, server_port, final_result_list,t_mongodb_host,t_mongodb_port,t_mongodb_database_name,t_chatanywhere_base_url,t_chatanywhere_key,t_siliconflow_base_url,t_siliconflow_key,t_deepseek_base_url,t_deepseek_key,t_volcengine_base_url,t_volcengine_key): def save_trigger(server_address, server_port, final_result_list, t_mongodb_host, t_mongodb_port, t_mongodb_database_name, t_console_log_level, t_file_log_level, t_default_console_log_level, t_default_file_log_level, t_api_provider, t_api_base_url, t_api_key):
final_result_lists = format_list_to_str(final_result_list) final_result_lists = format_list_to_str(final_result_list)
env_config_data["env_HOST"] = server_address env_config_data["env_HOST"] = server_address
env_config_data["env_PORT"] = server_port env_config_data["env_PORT"] = server_port
@@ -232,18 +281,32 @@ def save_trigger(server_address, server_port, final_result_list,t_mongodb_host,t
env_config_data["env_MONGODB_HOST"] = t_mongodb_host env_config_data["env_MONGODB_HOST"] = t_mongodb_host
env_config_data["env_MONGODB_PORT"] = t_mongodb_port env_config_data["env_MONGODB_PORT"] = t_mongodb_port
env_config_data["env_DATABASE_NAME"] = t_mongodb_database_name env_config_data["env_DATABASE_NAME"] = t_mongodb_database_name
env_config_data["env_CHAT_ANY_WHERE_BASE_URL"] = t_chatanywhere_base_url
env_config_data["env_CHAT_ANY_WHERE_KEY"] = t_chatanywhere_key # 保存日志配置
env_config_data["env_SILICONFLOW_BASE_URL"] = t_siliconflow_base_url env_config_data["env_CONSOLE_LOG_LEVEL"] = t_console_log_level
env_config_data["env_SILICONFLOW_KEY"] = t_siliconflow_key env_config_data["env_FILE_LOG_LEVEL"] = t_file_log_level
env_config_data["env_DEEP_SEEK_BASE_URL"] = t_deepseek_base_url env_config_data["env_DEFAULT_CONSOLE_LOG_LEVEL"] = t_default_console_log_level
env_config_data["env_DEEP_SEEK_KEY"] = t_deepseek_key env_config_data["env_DEFAULT_FILE_LOG_LEVEL"] = t_default_file_log_level
env_config_data["env_VOLCENGINE_BASE_URL"] = t_volcengine_base_url
env_config_data["env_VOLCENGINE_KEY"] = t_volcengine_key # 保存选中的API提供商的配置
env_config_data[f"env_{t_api_provider}_BASE_URL"] = t_api_base_url
env_config_data[f"env_{t_api_provider}_KEY"] = t_api_key
save_to_env_file(env_config_data) save_to_env_file(env_config_data)
logger.success("配置已保存到 .env.prod 文件中") logger.success("配置已保存到 .env.prod 文件中")
return "配置已保存" return "配置已保存"
def update_api_inputs(provider):
"""
根据选择的提供商更新Base URL和API Key输入框的值
"""
base_url = env_config_data.get(f"env_{provider}_BASE_URL", "")
api_key = env_config_data.get(f"env_{provider}_KEY", "")
return base_url, api_key
# 绑定下拉列表的change事件
# ============================================== # ==============================================
@@ -455,7 +518,9 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
gr.Markdown( gr.Markdown(
value="## 全球在线MaiMBot数量: " + str(online_maimbot_data['online_clients']) value="## 全球在线MaiMBot数量: " + str(online_maimbot_data['online_clients'])
) )
gr.Markdown(
value="## 当前WebUI版本: " + str(WEBUI_VERSION)
)
gr.Markdown( gr.Markdown(
value="### 配置文件版本:" + config_data["inner"]["version"] value="### 配置文件版本:" + config_data["inner"]["version"]
) )
@@ -546,81 +611,99 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
) )
with gr.Row(): with gr.Row():
gr.Markdown( gr.Markdown(
'''ChatAnyWhere的baseURL和APIkey\n '''日志设置\n
配置日志输出级别\n
改完了记得保存!!! 改完了记得保存!!!
''' '''
) )
with gr.Row(): with gr.Row():
chatanywhere_base_url = gr.Textbox( console_log_level = gr.Dropdown(
label="ChatAnyWhere的BaseURL", choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS"],
value=env_config_data["env_CHAT_ANY_WHERE_BASE_URL"], label="控制台日志级别",
value=env_config_data.get("env_CONSOLE_LOG_LEVEL", "INFO"),
interactive=True interactive=True
) )
with gr.Row(): with gr.Row():
chatanywhere_key = gr.Textbox( file_log_level = gr.Dropdown(
label="ChatAnyWhere的key", choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS"],
value=env_config_data["env_CHAT_ANY_WHERE_KEY"], label="文件日志级别",
value=env_config_data.get("env_FILE_LOG_LEVEL", "DEBUG"),
interactive=True
)
with gr.Row():
default_console_log_level = gr.Dropdown(
choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS", "NONE"],
label="默认控制台日志级别",
value=env_config_data.get("env_DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
interactive=True
)
with gr.Row():
default_file_log_level = gr.Dropdown(
choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS", "NONE"],
label="默认文件日志级别",
value=env_config_data.get("env_DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
interactive=True interactive=True
) )
with gr.Row(): with gr.Row():
gr.Markdown( gr.Markdown(
'''SiliconFlow的baseURL和APIkey\n '''API设置\n
选择API提供商并配置相应的BaseURL和Key\n
改完了记得保存!!! 改完了记得保存!!!
''' '''
) )
with gr.Row(): with gr.Row():
siliconflow_base_url = gr.Textbox( with gr.Column(scale=3):
label="SiliconFlow的BaseURL", new_provider_input = gr.Textbox(
value=env_config_data["env_SILICONFLOW_BASE_URL"], label="添加新提供商",
placeholder="输入新提供商名称"
)
add_provider_btn = gr.Button("添加提供商", scale=1)
with gr.Row():
api_provider = gr.Dropdown(
choices=MODEL_PROVIDER_LIST,
label="选择API提供商",
value=MODEL_PROVIDER_LIST[0] if MODEL_PROVIDER_LIST else None
)
with gr.Row():
api_base_url = gr.Textbox(
label="Base URL",
value=env_config_data.get(f"env_{MODEL_PROVIDER_LIST[0]}_BASE_URL", "") if MODEL_PROVIDER_LIST else "",
interactive=True interactive=True
) )
with gr.Row(): with gr.Row():
siliconflow_key = gr.Textbox( api_key = gr.Textbox(
label="SiliconFlow的key", label="API Key",
value=env_config_data["env_SILICONFLOW_KEY"], value=env_config_data.get(f"env_{MODEL_PROVIDER_LIST[0]}_KEY", "") if MODEL_PROVIDER_LIST else "",
interactive=True interactive=True
) )
with gr.Row(): api_provider.change(
gr.Markdown( update_api_inputs,
'''DeepSeek的baseURL和APIkey\n inputs=[api_provider],
改完了记得保存!!! outputs=[api_base_url, api_key]
'''
)
with gr.Row():
deepseek_base_url = gr.Textbox(
label="DeepSeek的BaseURL",
value=env_config_data["env_DEEP_SEEK_BASE_URL"],
interactive=True
)
with gr.Row():
deepseek_key = gr.Textbox(
label="DeepSeek的key",
value=env_config_data["env_DEEP_SEEK_KEY"],
interactive=True
)
with gr.Row():
volcengine_base_url = gr.Textbox(
label="VolcEngine的BaseURL",
value=env_config_data["env_VOLCENGINE_BASE_URL"],
interactive=True
)
with gr.Row():
volcengine_key = gr.Textbox(
label="VolcEngine的key",
value=env_config_data["env_VOLCENGINE_KEY"],
interactive=True
) )
with gr.Row(): with gr.Row():
save_env_btn = gr.Button("保存环境配置",variant="primary") save_env_btn = gr.Button("保存环境配置",variant="primary")
with gr.Row(): with gr.Row():
save_env_btn.click( save_env_btn.click(
save_trigger, save_trigger,
inputs=[server_address,server_port,final_result,mongodb_host,mongodb_port,mongodb_database_name,chatanywhere_base_url,chatanywhere_key,siliconflow_base_url,siliconflow_key,deepseek_base_url,deepseek_key,volcengine_base_url,volcengine_key], inputs=[server_address, server_port, final_result, mongodb_host, mongodb_port, mongodb_database_name, console_log_level, file_log_level, default_console_log_level, default_file_log_level, api_provider, api_base_url, api_key],
outputs=[gr.Textbox( outputs=[gr.Textbox(
label="保存结果", label="保存结果",
interactive=False interactive=False
)] )]
) )
# 绑定添加提供商按钮的点击事件
add_provider_btn.click(
add_new_provider,
inputs=[new_provider_input, gr.State(value=MODEL_PROVIDER_LIST)],
outputs=[gr.State(value=MODEL_PROVIDER_LIST), api_provider]
).then(
lambda x: (env_config_data.get(f"env_{x}_BASE_URL", ""), env_config_data.get(f"env_{x}_KEY", "")),
inputs=[api_provider],
outputs=[api_base_url, api_key]
)
with gr.TabItem("1-Bot基础设置"): with gr.TabItem("1-Bot基础设置"):
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):