In short: the knowledge base
@@ -192,11 +192,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         return None, pg_hash
 
 
-def extract_info_sync(pg_hash, paragraph, llm_api):
-    return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
-
-
-def extract_information(paragraphs_dict, model_set):
+async def extract_information(paragraphs_dict, model_set):
     logger.info("--- Step 2: starting information extraction ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
@@ -204,10 +200,11 @@ def extract_information(paragraphs_dict, model_set):
     llm_api = LLMRequest(model_set=model_set)
     failed_hashes, open_ie_docs = [], []
 
-    with ThreadPoolExecutor(max_workers=5) as executor:
-        f_to_hash = {
-            executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
-        }
+    tasks = [
+        extract_info_async(p_hash, p, llm_api)
+        for p_hash, p in paragraphs_dict.items()
+    ]
+
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
@@ -219,17 +216,19 @@ def extract_information(paragraphs_dict, model_set):
         "<",
         TimeRemainingColumn(),
     ) as progress:
-        task = progress.add_task("[cyan]Extracting information...", total=len(paragraphs_dict))
-        for future in as_completed(f_to_hash):
-            doc_item, failed_hash = future.result()
+        prog_task = progress.add_task("[cyan]Extracting information...", total=len(tasks))
+        for future in asyncio.as_completed(tasks):
+            doc_item, failed_hash = await future
             if failed_hash:
                 failed_hashes.append(failed_hash)
             elif doc_item:
                 open_ie_docs.append(doc_item)
-            progress.update(task, advance=1)
+            progress.update(prog_task, advance=1)
 
     if open_ie_docs:
-        all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
+        all_entities = [
+            e for doc in open_ie_docs for e in doc["extracted_entities"]
+        ]
         num_entities = len(all_entities)
         avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
         avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
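The pattern this hunk adopts, in isolation: build plain coroutine objects up front, then drain them with asyncio.as_completed, which yields awaitables in completion order so a progress counter can advance as each one finishes. A minimal self-contained sketch; names like extract_one are illustrative, not from this repo:

```python
import asyncio
import random


async def extract_one(key: str) -> tuple[str | None, str | None]:
    # Simulate an LLM call; return (result, failed_key)
    await asyncio.sleep(random.random())
    if random.random() < 0.2:
        return None, key  # signal failure for this key
    return f"doc-{key}", None


async def run_all(keys: list[str]) -> tuple[list[str], list[str]]:
    tasks = [extract_one(k) for k in keys]  # coroutine objects, not yet scheduled
    docs, failed = [], []
    # as_completed yields awaitables in the order they finish, which is
    # what lets a progress bar tick per item instead of at the very end
    for fut in asyncio.as_completed(tasks):
        doc, failed_key = await fut
        if failed_key:
            failed.append(failed_key)
        elif doc:
            docs.append(doc)
    return docs, failed


if __name__ == "__main__":
    print(asyncio.run(run_all([str(i) for i in range(10)])))
```

Unlike asyncio.gather, this keeps per-item results flowing as they complete rather than delivering everything at once.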
@@ -314,7 +313,7 @@ async def import_data(openie_obj: OpenIE | None = None):
     logger.info("--- Data import complete ---")
 
 
-def import_from_specific_file():
+async def import_from_specific_file():
     """Import data from a user-specified openie.json file"""
     file_path = input("Enter the full path to the openie.json file: ").strip()
 
@@ -329,7 +328,7 @@ def import_from_specific_file():
     try:
         logger.info(f"Loading OpenIE data from {file_path}...")
         openie_obj = OpenIE.load()
-        asyncio.run(import_data(openie_obj=openie_obj))
+        await import_data(openie_obj=openie_obj)
     except Exception as e:
         logger.error(f"Error while importing data from the specified file: {e}")
 
@@ -337,14 +336,20 @@ def import_from_specific_file():
 # --- Main function ---
 
 
-def main():
+async def async_main():
     # Use os.path.relpath to build friendly paths relative to the project root
-    raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
-    openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
+    raw_data_relpath = os.path.relpath(
+        RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")
+    )
+    openie_output_relpath = os.path.relpath(
+        OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")
+    )
 
     print("=== LPMM Knowledge Base Learning Tool ===")
     print(f"1. [Preprocess data] -> read .txt files (source: ./{raw_data_relpath}/)")
-    print(f"2. [Extract information] -> extract info and save as .json (output: ./{openie_output_relpath}/)")
+    print(
+        f"2. [Extract information] -> extract info and save as .json (output: ./{openie_output_relpath}/)"
+    )
     print("3. [Import data] -> automatically import the latest knowledge from the openie folder")
     print("4. [Full pipeline] -> run 1 -> 2 -> 3 in order")
     print("5. [Targeted import] -> import knowledge from a specific openie.json file")
@@ -358,16 +363,20 @@ def main():
     elif choice == "2":
         paragraphs = preprocess_raw_data()
         if paragraphs:
-            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
+            await extract_information(
+                paragraphs, model_config.model_task_config.lpmm_qa
+            )
     elif choice == "3":
-        asyncio.run(import_data())
+        await import_data()
     elif choice == "4":
         paragraphs = preprocess_raw_data()
         if paragraphs:
-            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
-            asyncio.run(import_data())
+            await extract_information(
+                paragraphs, model_config.model_task_config.lpmm_qa
+            )
+            await import_data()
     elif choice == "5":
-        import_from_specific_file()
+        await import_from_specific_file()
     elif choice == "6":
         clear_cache()
     elif choice == "0":
@@ -377,4 +386,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    asyncio.run(async_main())
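The entry-point change follows the usual sync-to-async migration rule: exactly one asyncio.run() at the outermost layer, with every nested step awaited instead of wrapped in its own asyncio.run() (which raises RuntimeError when a loop is already running). A minimal sketch of that shape, with placeholder step coroutines:

```python
import asyncio


async def step_extract() -> None:
    await asyncio.sleep(0.1)  # stand-in for the real extraction coroutine


async def step_import() -> None:
    await asyncio.sleep(0.1)  # stand-in for the real import coroutine


async def async_main(choice: str) -> None:
    # Each branch awaits directly; no branch calls asyncio.run() itself,
    # which would fail inside the already-running loop.
    if choice == "2":
        await step_extract()
    elif choice == "3":
        await step_import()
    elif choice == "4":
        await step_extract()
        await step_import()


if __name__ == "__main__":
    asyncio.run(async_main("4"))  # exactly one event loop for the whole program
```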
@@ -124,124 +124,60 @@ class EmbeddingStore:
         self.faiss_index = None
         self.idx2hash = None
 
-    @staticmethod
-    def _get_embedding(s: str) -> list[float]:
-        """Get a string's embedding vector, fully synchronously, to avoid event-loop problems"""
-        # Create a new event loop and close it as soon as we are done
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        try:
-            # Create a new LLMRequest instance
-            from src.config.config import model_config
-            from src.llm_models.utils_model import LLMRequest
-
-            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
-
-            # Run the async method on the new event loop
-            embedding, _ = loop.run_until_complete(llm.get_embedding(s))
-
-            if embedding and len(embedding) > 0:
-                return embedding
-            else:
-                logger.error(f"Failed to get embedding: {s}")
-                return []
-
-        except Exception as e:
-            logger.error(f"Exception while getting embedding: {s}, error: {e}")
-            return []
-        finally:
-            # Make sure the event loop is closed properly
-            try:
-                loop.close()
-            except Exception:
-                ...
-
     @staticmethod
     def _get_embeddings_batch_threaded(
-        strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
+        strs: list[str],
+        main_loop: asyncio.AbstractEventLoop,
+        chunk_size: int = 10,
+        max_workers: int = 10,
+        progress_callback=None,
     ) -> list[tuple[str, list[float]]]:
-        """Fetch embedding vectors in batches using multiple threads
-
-        Args:
-            strs: list of strings to embed
-            chunk_size: size of the chunk each thread processes
-            max_workers: maximum number of threads
-            progress_callback: progress callback taking one argument, the number of completed items
-
-        Returns:
-            list of (original string, embedding vector) tuples, in the same order as the input
-        """
+        """Fetch embedding vectors in batches using multiple threads, running the async calls on the main event loop via run_coroutine_threadsafe"""
         if not strs:
             return []
 
-        # Chunking
-        chunks = []
-        for i in range(0, len(strs), chunk_size):
-            chunk = strs[i : i + chunk_size]
-            chunks.append((i, chunk))  # keep the start index to preserve order
-
-        # Store results in a dict keyed by index to guarantee order
-        results = {}
-
-        def process_chunk(chunk_data):
-            """Process a single chunk"""
-            start_idx, chunk_strs = chunk_data
-            chunk_results = []
-
-            # Create an independent LLMRequest instance for each thread
-            from src.config.config import model_config
-            from src.llm_models.utils_model import LLMRequest
-
-            try:
-                # Create a thread-local LLM instance
-                llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
-
-                for i, s in enumerate(chunk_strs):
-                    try:
-                        # Create an independent event loop in this thread
-                        loop = asyncio.new_event_loop()
-                        asyncio.set_event_loop(loop)
-                        try:
-                            embedding = loop.run_until_complete(llm.get_embedding(s))
-                        finally:
-                            loop.close()
-
-                        if embedding and len(embedding) > 0:
-                            chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] is the actual vector
-                        else:
-                            logger.error(f"Failed to get embedding: {s}")
-                            chunk_results.append((start_idx + i, s, []))
-
-                        # Update progress as soon as each embedding completes
-                        if progress_callback:
-                            progress_callback(1)
-
-                    except Exception as e:
-                        logger.error(f"Exception while getting embedding: {s}, error: {e}")
-                        chunk_results.append((start_idx + i, s, []))
-                        # Update progress even on failure
-                        if progress_callback:
-                            progress_callback(1)
-
-            except Exception as e:
-                logger.error(f"Failed to create the LLM instance: {e}")
-                # If creating the LLM instance failed, return empty results
-                for i, s in enumerate(chunk_strs):
-                    chunk_results.append((start_idx + i, s, []))
-                    # Update progress even on failure
-                    if progress_callback:
-                        progress_callback(1)
-
-            return chunk_results
-
-        # Process with a thread pool
+        # Import required modules
+        from src.config.config import model_config
+        from src.llm_models.utils_model import LLMRequest
+
+        # Create the LLMRequest instance on the main thread (the one running the main event loop)
+        # so that it is bound to the correct event loop
+        llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+
+        # Chunking
+        chunks = [(i, strs[i : i + chunk_size]) for i in range(0, len(strs), chunk_size)]
+        results = {}
+
+        def process_chunk(chunk_data):
+            """Runs in a worker thread"""
+            start_idx, chunk_strs = chunk_data
+            chunk_results = []
+
+            for i, s in enumerate(chunk_strs):
+                embedding = []
+                try:
+                    # Submit the async get_embedding call to the main event loop
+                    future = asyncio.run_coroutine_threadsafe(llm.get_embedding(s), main_loop)
+                    # Wait for the result synchronously, with a longer timeout
+                    embedding_result, _ = future.result(timeout=60)
+
+                    if embedding_result and len(embedding_result) > 0:
+                        embedding = embedding_result
+                    else:
+                        logger.error(f"Failed to get embedding (empty result): {s}")
+                except Exception as e:
+                    logger.error(f"Exception while getting embedding in worker thread: {s}, error: {type(e).__name__}: {e}")
+                finally:
+                    chunk_results.append((start_idx + i, s, embedding))
+                    if progress_callback:
+                        progress_callback(1)
+
+            return chunk_results
+
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all tasks
             future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
 
-            # Collect results (progress is already updated live in process_chunk)
             for future in as_completed(future_to_chunk):
                 try:
                     chunk_results = future.result()
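The key change in this hunk: worker threads no longer spin up private event loops per call; they hand each coroutine back to the one main loop via asyncio.run_coroutine_threadsafe and block on the returned concurrent.futures.Future. A standalone sketch of the pattern, under the assumption that the caller keeps the main loop free to serve the submitted coroutines (here via run_in_executor); all names are illustrative:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


async def embed(text: str) -> list[float]:
    await asyncio.sleep(0.05)  # stand-in for the real async embedding call
    return [float(len(text))]


def worker(texts: list[str], main_loop: asyncio.AbstractEventLoop) -> list[list[float]]:
    out = []
    for t in texts:
        # Schedule the coroutine on the main loop from this worker thread;
        # this returns a concurrent.futures.Future we can block on with a timeout.
        fut = asyncio.run_coroutine_threadsafe(embed(t), main_loop)
        try:
            out.append(fut.result(timeout=60))
        except Exception:
            out.append([])  # failures degrade to an empty vector
    return out


async def main() -> None:
    loop = asyncio.get_running_loop()
    texts = [f"text-{i}" for i in range(20)]
    chunks = [texts[i : i + 5] for i in range(0, len(texts), 5)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        # run_in_executor keeps the loop thread free, so the coroutines
        # the workers submit can actually run
        results = await asyncio.gather(
            *[loop.run_in_executor(pool, worker, c, loop) for c in chunks]
        )
    print(sum(len(r) for r in results))


if __name__ == "__main__":
    asyncio.run(main())
```

The design caveat: if the loop thread itself blocks while the workers wait on their futures, the submitted coroutines can never execute, so whatever calls the blocking batch helper must not starve the loop.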
@@ -249,22 +185,14 @@ class EmbeddingStore:
                         results[idx] = (s, embedding)
                 except Exception as e:
                     chunk = future_to_chunk[future]
-                    logger.error(f"Exception while processing a chunk: {chunk}, error: {e}")
-                    # Add empty results for the failed chunk
+                    logger.error(f"Fatal exception while processing a chunk: {chunk}, error: {e}")
                     start_idx, chunk_strs = chunk
-                    for i, s in enumerate(chunk_strs):
-                        results[start_idx + i] = (s, [])
+                    for i, s_item in enumerate(chunk_strs):
+                        if (start_idx + i) not in results:
+                            results[start_idx + i] = (s_item, [])
 
         # Return results in the original order
-        ordered_results = []
-        for i in range(len(strs)):
-            if i in results:
-                ordered_results.append(results[i])
-            else:
-                # Guard against omissions
-                ordered_results.append((strs[i], []))
-
-        return ordered_results
+        return [results.get(i, (strs[i], [])) for i in range(len(strs))]
 
     @staticmethod
     def get_test_file_path():
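The collection code keys results by original index and rebuilds the output with results.get, so an item that never reports degrades to (string, []) instead of shifting later items out of position. A compact sketch of the same pattern, with a hypothetical helper name:

```python
def collect_in_order(inputs: list[str], pairs: list[tuple[int, list[float]]]) -> list[tuple[str, list[float]]]:
    # Workers report (index, vector) pairs in arbitrary order; a dict keyed
    # by index plus a final results.get() pass restores input order.
    results: dict[int, tuple[str, list[float]]] = {}
    for idx, vec in pairs:
        results[idx] = (inputs[idx], vec)
    # Any index never reported falls back to an empty vector rather than
    # being dropped, keeping output aligned with input.
    return [results.get(i, (inputs[i], [])) for i in range(len(inputs))]


if __name__ == "__main__":
    data = ["a", "b", "c"]
    print(collect_in_order(data, [(2, [3.0]), (0, [1.0])]))  # "b" falls back to []
```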
@@ -274,9 +202,17 @@ class EmbeddingStore:
         """Save the test strings' embeddings locally (multithreaded)"""
         logger.info("Saving embedding vectors for the test strings...")
 
+        # Get the currently running event loop
+        try:
+            main_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            logger.error("Could not get a running event loop. Make sure this method is called from an async context.")
+            return
+
         # Fetch the test strings' embeddings in a multithreaded batch
         embedding_results = self._get_embeddings_batch_threaded(
             EMBEDDING_TEST_STRINGS,
+            main_loop,
             chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
             max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
         )
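The guard added here relies on asyncio.get_running_loop() raising RuntimeError when no loop is running in the current thread, which turns "am I in an async context?" into an explicit, testable check. A tiny sketch:

```python
import asyncio


def capture_loop() -> asyncio.AbstractEventLoop | None:
    # get_running_loop() only succeeds inside a running event loop; outside
    # one it raises RuntimeError rather than creating a loop implicitly.
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        return None


async def demo() -> None:
    assert capture_loop() is not None  # inside asyncio.run -> loop available


if __name__ == "__main__":
    assert capture_loop() is None  # no loop running in this thread yet
    asyncio.run(demo())
    print("ok")
```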
@@ -288,8 +224,6 @@ class EmbeddingStore:
                 test_vectors[str(idx)] = embedding
             else:
                 logger.error(f"Failed to get embedding for test string: {s}")
-                # Fall back to the original single-threaded method
-                test_vectors[str(idx)] = self._get_embedding(s)
 
         with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
             f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
@@ -321,9 +255,17 @@ class EmbeddingStore:
 
         logger.info("Checking embedding-model consistency...")
 
+        # Get the currently running event loop
+        try:
+            main_loop = asyncio.get_running_loop()
+        except RuntimeError:
+            logger.error("Could not get a running event loop. Make sure this method is called from an async context.")
+            return False
+
         # Fetch the current model's embeddings in a multithreaded batch
         embedding_results = self._get_embeddings_batch_threaded(
             EMBEDDING_TEST_STRINGS,
+            main_loop,
             chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
             max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
         )
@@ -383,11 +325,20 @@ class EmbeddingStore:
             progress.update(task, advance=already_processed)
 
             if new_strs:
+                try:
+                    main_loop = asyncio.get_running_loop()
+                except RuntimeError:
+                    logger.error("Could not get a running event loop. Make sure this method is called from an async context.")
+                    # Advance the progress bar past the unprocessed items
+                    progress.update(task, advance=len(new_strs))
+                    return
+
                 # Use the instance's configured parameters; tune chunk size and thread count adaptively
                 optimal_chunk_size = max(
                     MIN_CHUNK_SIZE,
                     min(
-                        self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
+                        self.chunk_size,
+                        len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size,
                     ),
                 )
                 optimal_max_workers = min(
@@ -404,12 +355,13 @@ class EmbeddingStore:
                 # Fetch embeddings in batch, updating progress live
                 embedding_results = self._get_embeddings_batch_threaded(
                     new_strs,
+                    main_loop,
                     chunk_size=optimal_chunk_size,
                     max_workers=optimal_max_workers,
                     progress_callback=update_progress,
                 )
 
-                # Store the results (no need to update progress here; the callback already did)
+                # Store the results
                 for s, embedding in embedding_results:
                     item_hash = self.namespace + "-" + get_sha256(s)
                     if embedding:  # only store embeddings that were fetched successfully