diff --git a/scripts/import_openie.py b/scripts/import_openie.py index f81cbb5c7..472667c14 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -204,7 +204,8 @@ def main(): logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") sys.exit(1) - logger.error("如果你是第一次导入知识,请忽略此错误") + if "不存在" in str(e): + logger.error("如果你是第一次导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG kg_manager = KGManager() diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 5ee92a869..7d012b19b 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -6,7 +6,7 @@ from typing import Dict, List, Tuple import numpy as np import pandas as pd -import tqdm +# import tqdm import faiss from .llm_client import LLMClient @@ -194,11 +194,25 @@ class EmbeddingStore: """从文件中加载""" if not os.path.exists(self.embedding_file_path): raise Exception(f"文件{self.embedding_file_path}不存在") - logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") - for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)): - self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) + total = len(data_frame) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("加载嵌入库", total=total) + for _, row in data_frame.iterrows(): + self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) + progress.update(task, advance=1) logger.info(f"{self.namespace}嵌入库加载成功") try: