From f5894e01539520744017c1499b8c01714d23feb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Mon, 5 May 2025 22:04:50 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=E5=B5=8C=E5=85=A5?= =?UTF-8?q?=E5=BA=93=E5=8A=A0=E8=BD=BD=E8=BF=87=E7=A8=8B=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E8=BF=9B=E5=BA=A6=E6=9D=A1=E6=98=BE=E7=A4=BA=EF=BC=9B?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=A6=96=E6=AC=A1=E5=AF=BC=E5=85=A5=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E6=97=B6=E7=9A=84=E9=94=99=E8=AF=AF=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 3 ++- src/plugins/knowledge/src/embedding_store.py | 22 ++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index f81cbb5c7..472667c14 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -204,7 +204,8 @@ def main(): logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") sys.exit(1) - logger.error("如果你是第一次导入知识,请忽略此错误") + if "不存在" in str(e): + logger.error("如果你是第一次导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG kg_manager = KGManager() diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 5ee92a869..7d012b19b 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -6,7 +6,7 @@ from typing import Dict, List, Tuple import numpy as np import pandas as pd -import tqdm +# import tqdm import faiss from .llm_client import LLMClient @@ -194,11 +194,25 @@ class EmbeddingStore: """从文件中加载""" if not os.path.exists(self.embedding_file_path): raise Exception(f"文件{self.embedding_file_path}不存在") - logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") - for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)): - self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) + total = len(data_frame) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("加载嵌入库", total=total) + for _, row in data_frame.iterrows(): + self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) + progress.update(task, advance=1) logger.info(f"{self.namespace}嵌入库加载成功") try: