fix: 优化嵌入库加载过程,添加进度条显示;修复首次导入知识时的错误提示
This commit is contained in:
@@ -204,6 +204,7 @@ def main():
|
||||
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
||||
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||
sys.exit(1)
|
||||
if "不存在" in str(e):
|
||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||
logger.info("Embedding库加载完成")
|
||||
# 初始化KG
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .llm_client import LLMClient
|
||||
@@ -194,11 +194,25 @@ class EmbeddingStore:
|
||||
"""从文件中加载"""
|
||||
if not os.path.exists(self.embedding_file_path):
|
||||
raise Exception(f"文件{self.embedding_file_path}不存在")
|
||||
|
||||
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
||||
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
||||
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
|
||||
total = len(data_frame)
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task("加载嵌入库", total=total)
|
||||
for _, row in data_frame.iterrows():
|
||||
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
|
||||
progress.update(task, advance=1)
|
||||
logger.info(f"{self.namespace}嵌入库加载成功")
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user