fix: 优化嵌入库加载过程,添加进度条显示;修复首次导入知识时的错误提示

This commit is contained in:
墨梓柒
2025-05-05 22:04:50 +08:00
parent 0147f49ee9
commit f5894e0153
2 changed files with 20 additions and 5 deletions

View File

@@ -204,7 +204,8 @@ def main():
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
sys.exit(1) sys.exit(1)
logger.error("如果你是第一次导入知识,请忽略此错误") if "不存在" in str(e):
logger.error("如果你是第一次导入知识,请忽略此错误")
logger.info("Embedding库加载完成") logger.info("Embedding库加载完成")
# 初始化KG # 初始化KG
kg_manager = KGManager() kg_manager = KGManager()

View File

@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import tqdm # import tqdm
import faiss import faiss
from .llm_client import LLMClient from .llm_client import LLMClient
@@ -194,11 +194,25 @@ class EmbeddingStore:
"""从文件中加载""" """从文件中加载"""
if not os.path.exists(self.embedding_file_path): if not os.path.exists(self.embedding_file_path):
raise Exception(f"文件{self.embedding_file_path}不存在") raise Exception(f"文件{self.embedding_file_path}不存在")
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)): total = len(data_frame)
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("加载嵌入库", total=total)
for _, row in data_frame.iterrows():
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
progress.update(task, advance=1)
logger.info(f"{self.namespace}嵌入库加载成功") logger.info(f"{self.namespace}嵌入库加载成功")
try: try: