fix:修复LPMM学习问题
This commit is contained in:
@@ -15,6 +15,7 @@ from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# 添加项目根目录到 sys.path
|
||||
@@ -23,6 +24,45 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
logger = get_logger("OpenIE导入")
|
||||
|
||||
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
||||
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env", override=True)
|
||||
print("成功加载环境变量配置")
|
||||
else:
|
||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||
|
||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||
def scan_provider(env_config: dict):
|
||||
provider = {}
|
||||
|
||||
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
||||
# 避免 GPG_KEY 这样的变量干扰检查
|
||||
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
||||
|
||||
# 遍历 env_config 的所有键
|
||||
for key in env_config:
|
||||
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
||||
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
||||
# 提取 provider 名称
|
||||
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
||||
|
||||
# 初始化 provider 的字典(如果尚未初始化)
|
||||
if provider_name not in provider:
|
||||
provider[provider_name] = {"url": None, "key": None}
|
||||
|
||||
# 根据键的类型填充 url 或 key
|
||||
if key.endswith("_BASE_URL"):
|
||||
provider[provider_name]["url"] = env_config[key]
|
||||
elif key.endswith("_KEY"):
|
||||
provider[provider_name]["key"] = env_config[key]
|
||||
|
||||
# 检查每个 provider 是否同时存在 url 和 key
|
||||
for provider_name, config in provider.items():
|
||||
if config["url"] is None or config["key"] is None:
|
||||
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
||||
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||
|
||||
def ensure_openie_dir():
|
||||
"""确保OpenIE数据目录存在"""
|
||||
@@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||
|
||||
def main(): # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
env_config = {key: os.getenv(key) for key in os.environ}
|
||||
scan_provider(env_config)
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||
|
||||
Reference in New Issue
Block a user