fix: 优化嵌入库加载过程，添加进度条显示；修复首次导入知识时的错误提示
This commit is contained in:
@@ -204,7 +204,8 @@ def main():
|
|||||||
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
||||||
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
if "不存在" in str(e):
|
||||||
|
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||||
logger.info("Embedding库加载完成")
|
logger.info("Embedding库加载完成")
|
||||||
# 初始化KG
|
# 初始化KG
|
||||||
kg_manager = KGManager()
|
kg_manager = KGManager()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tqdm
|
# import tqdm
|
||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
@@ -194,11 +194,25 @@ class EmbeddingStore:
|
|||||||
"""从文件中加载"""
|
"""从文件中加载"""
|
||||||
if not os.path.exists(self.embedding_file_path):
|
if not os.path.exists(self.embedding_file_path):
|
||||||
raise Exception(f"文件{self.embedding_file_path}不存在")
|
raise Exception(f"文件{self.embedding_file_path}不存在")
|
||||||
|
|
||||||
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
||||||
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
||||||
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
|
total = len(data_frame)
|
||||||
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
"•",
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
"<",
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
transient=False,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("加载嵌入库", total=total)
|
||||||
|
for _, row in data_frame.iterrows():
|
||||||
|
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
|
||||||
|
progress.update(task, advance=1)
|
||||||
logger.info(f"{self.namespace}嵌入库加载成功")
|
logger.info(f"{self.namespace}嵌入库加载成功")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user