diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 63a4d9852..1177650d4 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -24,46 +24,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") -ENV_FILE = os.path.join(ROOT_PATH, ".env") - -if os.path.exists(".env"): - load_dotenv(".env", override=True) - print("成功加载环境变量配置") -else: - print("未找到.env文件,请确保程序所需的环境变量被正确设置") - raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") - -env_mask = {key: os.getenv(key) for key in os.environ} -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") - def ensure_openie_dir(): """确保OpenIE数据目录存在""" if not os.path.exists(OPENIE_DIR): @@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k def main(): # sourcery skip: dict-comprehension # 新增确认提示 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index cb545a44d..47ad55a8b 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -27,7 +27,6 @@ from rich.progress import ( from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from dotenv import load_dotenv logger = get_logger("LPMM知识库-信息提取") @@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") -ENV_FILE = os.path.join(ROOT_PATH, ".env") - -if os.path.exists(".env"): - load_dotenv(".env", override=True) - print("成功加载环境变量配置") -else: - print("未找到.env文件,请确保程序所需的环境变量被正确设置") - raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") - -env_mask = {key: os.getenv(key) for key in os.environ} -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_dirs(): """确保临时目录和输出目录存在""" @@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) ensure_dirs() # 确保目录存在 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) # 新增用户确认提示 print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")