fix: 修复了自定义API提供商无法识别的问题。增加新的env文件配置项,增加功能:可以自己在WebUI中添加提供商。增加检测文件是否存在功能
This commit is contained in:
205
webui.py
205
webui.py
@@ -1,6 +1,5 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import toml
|
import toml
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
import shutil
|
import shutil
|
||||||
@@ -12,12 +11,24 @@ logger = get_module_logger("webui")
|
|||||||
|
|
||||||
is_share = False
|
is_share = False
|
||||||
debug = True
|
debug = True
|
||||||
|
# 检查配置文件是否存在
|
||||||
|
if not os.path.exists("config/bot_config.toml"):
|
||||||
|
logger.error("配置文件 bot_config.toml 不存在,请检查配置文件路径")
|
||||||
|
raise FileNotFoundError("配置文件 bot_config.toml 不存在,请检查配置文件路径")
|
||||||
|
|
||||||
|
if not os.path.exists(".env.prod"):
|
||||||
|
logger.error("环境配置文件 .env.prod 不存在,请检查配置文件路径")
|
||||||
|
raise FileNotFoundError("环境配置文件 .env.prod 不存在,请检查配置文件路径")
|
||||||
|
|
||||||
config_data = toml.load("config/bot_config.toml")
|
config_data = toml.load("config/bot_config.toml")
|
||||||
|
|
||||||
CONFIG_VERSION = config_data["inner"]["version"]
|
CONFIG_VERSION = config_data["inner"]["version"]
|
||||||
PARSED_CONFIG_VERSION = version.parse(CONFIG_VERSION)
|
PARSED_CONFIG_VERSION = version.parse(CONFIG_VERSION)
|
||||||
HAVE_ONLINE_STATUS_VERSION = version.parse("0.0.9")
|
HAVE_ONLINE_STATUS_VERSION = version.parse("0.0.9")
|
||||||
|
|
||||||
|
#添加WebUI配置文件版本
|
||||||
|
WEBUI_VERSION = version.parse("0.0.7")
|
||||||
|
|
||||||
# ==============================================
|
# ==============================================
|
||||||
# env环境配置文件读取部分
|
# env环境配置文件读取部分
|
||||||
def parse_env_config(config_file):
|
def parse_env_config(config_file):
|
||||||
@@ -92,12 +103,50 @@ else:
|
|||||||
logger.info("VOLCENGINE_KEY 不存在,已创建并使用默认值")
|
logger.info("VOLCENGINE_KEY 不存在,已创建并使用默认值")
|
||||||
env_config_data["env_VOLCENGINE_KEY"] = "volc_key"
|
env_config_data["env_VOLCENGINE_KEY"] = "volc_key"
|
||||||
save_to_env_file(env_config_data, env_config_file)
|
save_to_env_file(env_config_data, env_config_file)
|
||||||
MODEL_PROVIDER_LIST = [
|
|
||||||
"VOLCENGINE",
|
def parse_model_providers(env_vars):
|
||||||
"CHAT_ANY_WHERE",
|
"""
|
||||||
"SILICONFLOW",
|
从环境变量中解析模型提供商列表
|
||||||
"DEEP_SEEK"
|
参数:
|
||||||
]
|
env_vars: 包含环境变量的字典
|
||||||
|
返回:
|
||||||
|
list: 模型提供商列表
|
||||||
|
"""
|
||||||
|
providers = []
|
||||||
|
for key in env_vars.keys():
|
||||||
|
if key.startswith("env_") and key.endswith("_BASE_URL"):
|
||||||
|
# 提取中间部分作为提供商名称
|
||||||
|
provider = key[4:-9] # 移除"env_"前缀和"_BASE_URL"后缀
|
||||||
|
providers.append(provider)
|
||||||
|
return providers
|
||||||
|
|
||||||
|
def add_new_provider(provider_name, current_providers):
|
||||||
|
"""
|
||||||
|
添加新的提供商到列表中
|
||||||
|
参数:
|
||||||
|
provider_name: 新的提供商名称
|
||||||
|
current_providers: 当前的提供商列表
|
||||||
|
返回:
|
||||||
|
tuple: (更新后的提供商列表, 更新后的下拉列表选项)
|
||||||
|
"""
|
||||||
|
if not provider_name or provider_name in current_providers:
|
||||||
|
return current_providers, gr.update(choices=current_providers)
|
||||||
|
|
||||||
|
# 添加新的提供商到环境变量中
|
||||||
|
env_config_data[f"env_{provider_name}_BASE_URL"] = ""
|
||||||
|
env_config_data[f"env_{provider_name}_KEY"] = ""
|
||||||
|
|
||||||
|
# 更新提供商列表
|
||||||
|
updated_providers = current_providers + [provider_name]
|
||||||
|
|
||||||
|
# 保存到环境文件
|
||||||
|
save_to_env_file(env_config_data)
|
||||||
|
|
||||||
|
return updated_providers, gr.update(choices=updated_providers)
|
||||||
|
|
||||||
|
# 从环境变量中解析并更新提供商列表
|
||||||
|
MODEL_PROVIDER_LIST = parse_model_providers(env_config_data)
|
||||||
|
|
||||||
# env读取保存结束
|
# env读取保存结束
|
||||||
# ==============================================
|
# ==============================================
|
||||||
|
|
||||||
@@ -224,7 +273,7 @@ def format_list_to_str(lst):
|
|||||||
|
|
||||||
|
|
||||||
# env保存函数
|
# env保存函数
|
||||||
def save_trigger(server_address, server_port, final_result_list,t_mongodb_host,t_mongodb_port,t_mongodb_database_name,t_chatanywhere_base_url,t_chatanywhere_key,t_siliconflow_base_url,t_siliconflow_key,t_deepseek_base_url,t_deepseek_key,t_volcengine_base_url,t_volcengine_key):
|
def save_trigger(server_address, server_port, final_result_list, t_mongodb_host, t_mongodb_port, t_mongodb_database_name, t_console_log_level, t_file_log_level, t_default_console_log_level, t_default_file_log_level, t_api_provider, t_api_base_url, t_api_key):
|
||||||
final_result_lists = format_list_to_str(final_result_list)
|
final_result_lists = format_list_to_str(final_result_list)
|
||||||
env_config_data["env_HOST"] = server_address
|
env_config_data["env_HOST"] = server_address
|
||||||
env_config_data["env_PORT"] = server_port
|
env_config_data["env_PORT"] = server_port
|
||||||
@@ -232,18 +281,32 @@ def save_trigger(server_address, server_port, final_result_list,t_mongodb_host,t
|
|||||||
env_config_data["env_MONGODB_HOST"] = t_mongodb_host
|
env_config_data["env_MONGODB_HOST"] = t_mongodb_host
|
||||||
env_config_data["env_MONGODB_PORT"] = t_mongodb_port
|
env_config_data["env_MONGODB_PORT"] = t_mongodb_port
|
||||||
env_config_data["env_DATABASE_NAME"] = t_mongodb_database_name
|
env_config_data["env_DATABASE_NAME"] = t_mongodb_database_name
|
||||||
env_config_data["env_CHAT_ANY_WHERE_BASE_URL"] = t_chatanywhere_base_url
|
|
||||||
env_config_data["env_CHAT_ANY_WHERE_KEY"] = t_chatanywhere_key
|
# 保存日志配置
|
||||||
env_config_data["env_SILICONFLOW_BASE_URL"] = t_siliconflow_base_url
|
env_config_data["env_CONSOLE_LOG_LEVEL"] = t_console_log_level
|
||||||
env_config_data["env_SILICONFLOW_KEY"] = t_siliconflow_key
|
env_config_data["env_FILE_LOG_LEVEL"] = t_file_log_level
|
||||||
env_config_data["env_DEEP_SEEK_BASE_URL"] = t_deepseek_base_url
|
env_config_data["env_DEFAULT_CONSOLE_LOG_LEVEL"] = t_default_console_log_level
|
||||||
env_config_data["env_DEEP_SEEK_KEY"] = t_deepseek_key
|
env_config_data["env_DEFAULT_FILE_LOG_LEVEL"] = t_default_file_log_level
|
||||||
env_config_data["env_VOLCENGINE_BASE_URL"] = t_volcengine_base_url
|
|
||||||
env_config_data["env_VOLCENGINE_KEY"] = t_volcengine_key
|
# 保存选中的API提供商的配置
|
||||||
|
env_config_data[f"env_{t_api_provider}_BASE_URL"] = t_api_base_url
|
||||||
|
env_config_data[f"env_{t_api_provider}_KEY"] = t_api_key
|
||||||
|
|
||||||
save_to_env_file(env_config_data)
|
save_to_env_file(env_config_data)
|
||||||
logger.success("配置已保存到 .env.prod 文件中")
|
logger.success("配置已保存到 .env.prod 文件中")
|
||||||
return "配置已保存"
|
return "配置已保存"
|
||||||
|
|
||||||
|
def update_api_inputs(provider):
|
||||||
|
"""
|
||||||
|
根据选择的提供商更新Base URL和API Key输入框的值
|
||||||
|
"""
|
||||||
|
base_url = env_config_data.get(f"env_{provider}_BASE_URL", "")
|
||||||
|
api_key = env_config_data.get(f"env_{provider}_KEY", "")
|
||||||
|
return base_url, api_key
|
||||||
|
|
||||||
|
# 绑定下拉列表的change事件
|
||||||
|
|
||||||
|
|
||||||
# ==============================================
|
# ==============================================
|
||||||
|
|
||||||
|
|
||||||
@@ -455,7 +518,9 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
|
|||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
value="## 全球在线MaiMBot数量: " + str(online_maimbot_data['online_clients'])
|
value="## 全球在线MaiMBot数量: " + str(online_maimbot_data['online_clients'])
|
||||||
)
|
)
|
||||||
|
gr.Markdown(
|
||||||
|
value="## 当前WebUI版本: " + str(WEBUI_VERSION)
|
||||||
|
)
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
value="### 配置文件版本:" + config_data["inner"]["version"]
|
value="### 配置文件版本:" + config_data["inner"]["version"]
|
||||||
)
|
)
|
||||||
@@ -546,81 +611,99 @@ with gr.Blocks(title="MaimBot配置文件编辑") as app:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'''ChatAnyWhere的baseURL和APIkey\n
|
'''日志设置\n
|
||||||
|
配置日志输出级别\n
|
||||||
改完了记得保存!!!
|
改完了记得保存!!!
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
chatanywhere_base_url = gr.Textbox(
|
console_log_level = gr.Dropdown(
|
||||||
label="ChatAnyWhere的BaseURL",
|
choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS"],
|
||||||
value=env_config_data["env_CHAT_ANY_WHERE_BASE_URL"],
|
label="控制台日志级别",
|
||||||
|
value=env_config_data.get("env_CONSOLE_LOG_LEVEL", "INFO"),
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
chatanywhere_key = gr.Textbox(
|
file_log_level = gr.Dropdown(
|
||||||
label="ChatAnyWhere的key",
|
choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS"],
|
||||||
value=env_config_data["env_CHAT_ANY_WHERE_KEY"],
|
label="文件日志级别",
|
||||||
|
value=env_config_data.get("env_FILE_LOG_LEVEL", "DEBUG"),
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
default_console_log_level = gr.Dropdown(
|
||||||
|
choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS", "NONE"],
|
||||||
|
label="默认控制台日志级别",
|
||||||
|
value=env_config_data.get("env_DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
default_file_log_level = gr.Dropdown(
|
||||||
|
choices=["INFO", "DEBUG", "WARNING", "ERROR", "SUCCESS", "NONE"],
|
||||||
|
label="默认文件日志级别",
|
||||||
|
value=env_config_data.get("env_DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'''SiliconFlow的baseURL和APIkey\n
|
'''API设置\n
|
||||||
|
选择API提供商并配置相应的BaseURL和Key\n
|
||||||
改完了记得保存!!!
|
改完了记得保存!!!
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
siliconflow_base_url = gr.Textbox(
|
with gr.Column(scale=3):
|
||||||
label="SiliconFlow的BaseURL",
|
new_provider_input = gr.Textbox(
|
||||||
value=env_config_data["env_SILICONFLOW_BASE_URL"],
|
label="添加新提供商",
|
||||||
|
placeholder="输入新提供商名称"
|
||||||
|
)
|
||||||
|
add_provider_btn = gr.Button("添加提供商", scale=1)
|
||||||
|
with gr.Row():
|
||||||
|
api_provider = gr.Dropdown(
|
||||||
|
choices=MODEL_PROVIDER_LIST,
|
||||||
|
label="选择API提供商",
|
||||||
|
value=MODEL_PROVIDER_LIST[0] if MODEL_PROVIDER_LIST else None
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
api_base_url = gr.Textbox(
|
||||||
|
label="Base URL",
|
||||||
|
value=env_config_data.get(f"env_{MODEL_PROVIDER_LIST[0]}_BASE_URL", "") if MODEL_PROVIDER_LIST else "",
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
siliconflow_key = gr.Textbox(
|
api_key = gr.Textbox(
|
||||||
label="SiliconFlow的key",
|
label="API Key",
|
||||||
value=env_config_data["env_SILICONFLOW_KEY"],
|
value=env_config_data.get(f"env_{MODEL_PROVIDER_LIST[0]}_KEY", "") if MODEL_PROVIDER_LIST else "",
|
||||||
interactive=True
|
interactive=True
|
||||||
)
|
)
|
||||||
with gr.Row():
|
api_provider.change(
|
||||||
gr.Markdown(
|
update_api_inputs,
|
||||||
'''DeepSeek的baseURL和APIkey\n
|
inputs=[api_provider],
|
||||||
改完了记得保存!!!
|
outputs=[api_base_url, api_key]
|
||||||
'''
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
deepseek_base_url = gr.Textbox(
|
|
||||||
label="DeepSeek的BaseURL",
|
|
||||||
value=env_config_data["env_DEEP_SEEK_BASE_URL"],
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
deepseek_key = gr.Textbox(
|
|
||||||
label="DeepSeek的key",
|
|
||||||
value=env_config_data["env_DEEP_SEEK_KEY"],
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
volcengine_base_url = gr.Textbox(
|
|
||||||
label="VolcEngine的BaseURL",
|
|
||||||
value=env_config_data["env_VOLCENGINE_BASE_URL"],
|
|
||||||
interactive=True
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
volcengine_key = gr.Textbox(
|
|
||||||
label="VolcEngine的key",
|
|
||||||
value=env_config_data["env_VOLCENGINE_KEY"],
|
|
||||||
interactive=True
|
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_env_btn = gr.Button("保存环境配置",variant="primary")
|
save_env_btn = gr.Button("保存环境配置",variant="primary")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_env_btn.click(
|
save_env_btn.click(
|
||||||
save_trigger,
|
save_trigger,
|
||||||
inputs=[server_address,server_port,final_result,mongodb_host,mongodb_port,mongodb_database_name,chatanywhere_base_url,chatanywhere_key,siliconflow_base_url,siliconflow_key,deepseek_base_url,deepseek_key,volcengine_base_url,volcengine_key],
|
inputs=[server_address, server_port, final_result, mongodb_host, mongodb_port, mongodb_database_name, console_log_level, file_log_level, default_console_log_level, default_file_log_level, api_provider, api_base_url, api_key],
|
||||||
outputs=[gr.Textbox(
|
outputs=[gr.Textbox(
|
||||||
label="保存结果",
|
label="保存结果",
|
||||||
interactive=False
|
interactive=False
|
||||||
)]
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 绑定添加提供商按钮的点击事件
|
||||||
|
add_provider_btn.click(
|
||||||
|
add_new_provider,
|
||||||
|
inputs=[new_provider_input, gr.State(value=MODEL_PROVIDER_LIST)],
|
||||||
|
outputs=[gr.State(value=MODEL_PROVIDER_LIST), api_provider]
|
||||||
|
).then(
|
||||||
|
lambda x: (env_config_data.get(f"env_{x}_BASE_URL", ""), env_config_data.get(f"env_{x}_KEY", "")),
|
||||||
|
inputs=[api_provider],
|
||||||
|
outputs=[api_base_url, api_key]
|
||||||
|
)
|
||||||
with gr.TabItem("1-Bot基础设置"):
|
with gr.TabItem("1-Bot基础设置"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
|
|||||||
Reference in New Issue
Block a user