From e863ca71fd077c805cc60194b8093e1a3cbefc3d Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 11 Oct 2025 14:18:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=80=BB=E4=B9=8B=E5=B0=B1=E6=98=AF=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/lpmm_learning_tool.py | 91 +++++++------ src/chat/knowledge/embedding_store.py | 176 ++++++++++---------------- 2 files changed, 114 insertions(+), 153 deletions(-) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 58aa91c64..1f4e93eca 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -192,11 +192,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api): return None, pg_hash -def extract_info_sync(pg_hash, paragraph, llm_api): - return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) - - -def extract_information(paragraphs_dict, model_set): +async def extract_information(paragraphs_dict, model_set): logger.info("--- 步骤 2: 开始信息提取 ---") os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True) @@ -204,32 +200,35 @@ def extract_information(paragraphs_dict, model_set): llm_api = LLMRequest(model_set=model_set) failed_hashes, open_ie_docs = [], [] - with ThreadPoolExecutor(max_workers=5) as executor: - f_to_hash = { - executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items() - } - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - MofNCompleteColumn(), - "•", - TimeElapsedColumn(), - "<", - TimeRemainingColumn(), - ) as progress: - task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict)) - for future in as_completed(f_to_hash): - doc_item, failed_hash = future.result() - if failed_hash: - failed_hashes.append(failed_hash) - elif doc_item: - open_ie_docs.append(doc_item) - progress.update(task, advance=1) + tasks = [ + extract_info_async(p_hash, p, llm_api) + for p_hash, p in paragraphs_dict.items() + ] + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + ) as progress: + prog_task = progress.add_task("[cyan]正在提取信息...", total=len(tasks)) + for future in asyncio.as_completed(tasks): + doc_item, failed_hash = await future + if failed_hash: + failed_hashes.append(failed_hash) + elif doc_item: + open_ie_docs.append(doc_item) + progress.update(prog_task, advance=1) if open_ie_docs: - all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]] + all_entities = [ + e for doc in open_ie_docs for e in doc["extracted_entities"] + ] num_entities = len(all_entities) avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0 avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0 @@ -314,7 +313,7 @@ async def import_data(openie_obj: OpenIE | None = None): logger.info("--- 数据导入完成 ---") -def import_from_specific_file(): +async def import_from_specific_file(): """从用户指定的 openie.json 文件导入数据""" file_path = input("请输入 openie.json 文件的完整路径: ").strip() @@ -329,7 +328,7 @@ def import_from_specific_file(): try: logger.info(f"正在从 {file_path} 加载 OpenIE 数据...") openie_obj = OpenIE.load() - asyncio.run(import_data(openie_obj=openie_obj)) + await import_data(openie_obj=openie_obj) except Exception as e: logger.error(f"从指定文件导入数据时发生错误: {e}") @@ -337,14 +336,20 @@ def import_from_specific_file(): # --- 主函数 --- -def main(): +async def async_main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 - raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) - openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) + raw_data_relpath = os.path.relpath( + RAW_DATA_PATH, os.path.join(ROOT_PATH, "..") + ) + openie_output_relpath = os.path.relpath( + OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..") + ) print("=== LPMM 知识库学习工具 ===") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") - print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") + print( + f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)" + ) print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") @@ -358,16 +363,20 @@ def main(): elif choice == "2": paragraphs = preprocess_raw_data() if paragraphs: - extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + await extract_information( + paragraphs, model_config.model_task_config.lpmm_qa + ) elif choice == "3": - asyncio.run(import_data()) + await import_data() elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: - extract_information(paragraphs, model_config.model_task_config.lpmm_qa) - asyncio.run(import_data()) + await extract_information( + paragraphs, model_config.model_task_config.lpmm_qa + ) + await import_data() elif choice == "5": - import_from_specific_file() + await import_from_specific_file() elif choice == "6": clear_cache() elif choice == "0": @@ -377,4 +386,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(async_main()) diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 2c1056bb1..8eab15834 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -124,124 +124,60 @@ class EmbeddingStore: self.faiss_index = None self.idx2hash = None - @staticmethod - def _get_embedding(s: str) -> list[float]: - """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" - # 创建新的事件循环并在完成后立即关闭 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - # 创建新的LLMRequest实例 - from src.config.config import model_config - from src.llm_models.utils_model import LLMRequest - - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") - - # 使用新的事件循环运行异步方法 - embedding, _ = loop.run_until_complete(llm.get_embedding(s)) - - if embedding and len(embedding) > 0: - return embedding - else: - logger.error(f"获取嵌入失败: {s}") - return [] - - except Exception as e: - logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") - return [] - finally: - # 确保事件循环被正确关闭 - try: - loop.close() - except Exception: - ... - @staticmethod def _get_embeddings_batch_threaded( - strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + strs: list[str], + main_loop: asyncio.AbstractEventLoop, + chunk_size: int = 10, + max_workers: int = 10, + progress_callback=None, ) -> list[tuple[str, list[float]]]: - """使用多线程批量获取嵌入向量 - - Args: - strs: 要获取嵌入的字符串列表 - chunk_size: 每个线程处理的数据块大小 - max_workers: 最大线程数 - progress_callback: 进度回调函数,接收一个参数表示完成的数量 - - Returns: - 包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致 - """ + """使用多线程批量获取嵌入向量, 并通过 run_coroutine_threadsafe 在主事件循环中运行异步任务""" if not strs: return [] - # 分块 - chunks = [] - for i in range(0, len(strs), chunk_size): - chunk = strs[i : i + chunk_size] - chunks.append((i, chunk)) # 保存起始索引以维持顺序 + # 导入必要的模块 + from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest - # 结果存储,使用字典按索引存储以保证顺序 + # 在主线程(即主事件循环所在的线程)中创建LLMRequest实例 + # 这样可以确保它绑定到正确的事件循环 + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + + # 分块 + chunks = [(i, strs[i : i + chunk_size]) for i in range(0, len(strs), chunk_size)] results = {} def process_chunk(chunk_data): - """处理单个数据块的函数""" + """在工作线程中运行的函数""" start_idx, chunk_strs = chunk_data chunk_results = [] - # 为每个线程创建独立的LLMRequest实例 - from src.config.config import model_config - from src.llm_models.utils_model import LLMRequest + for i, s in enumerate(chunk_strs): + embedding = [] + try: + # 将异步的 get_embedding 调用提交到主事件循环 + future = asyncio.run_coroutine_threadsafe(llm.get_embedding(s), main_loop) + # 同步等待结果,延长超时时间 + embedding_result, _ = future.result(timeout=60) - try: - # 创建线程专用的LLM实例 - llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + if embedding_result and len(embedding_result) > 0: + embedding = embedding_result + else: + logger.error(f"获取嵌入失败(返回为空): {s}") - for i, s in enumerate(chunk_strs): - try: - # 在线程中创建独立的事件循环 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - embedding = loop.run_until_complete(llm.get_embedding(s)) - finally: - loop.close() - - if embedding and len(embedding) > 0: - chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 - else: - logger.error(f"获取嵌入失败: {s}") - chunk_results.append((start_idx + i, s, [])) - - # 每完成一个嵌入立即更新进度 - if progress_callback: - progress_callback(1) - - except Exception as e: - logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") - chunk_results.append((start_idx + i, s, [])) - - # 即使失败也要更新进度 - if progress_callback: - progress_callback(1) - - except Exception as e: - logger.error(f"创建LLM实例失败: {e}") - # 如果创建LLM实例失败,返回空结果 - for i, s in enumerate(chunk_strs): - chunk_results.append((start_idx + i, s, [])) - # 即使失败也要更新进度 + except Exception as e: + logger.error(f"在线程中获取嵌入时发生异常: {s}, 错误: {type(e).__name__}: {e}") + finally: + chunk_results.append((start_idx + i, s, embedding)) if progress_callback: progress_callback(1) return chunk_results - # 使用线程池处理 with ThreadPoolExecutor(max_workers=max_workers) as executor: - # 提交所有任务 future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} - # 收集结果(进度已在process_chunk中实时更新) for future in as_completed(future_to_chunk): try: chunk_results = future.result() @@ -249,22 +185,14 @@ class EmbeddingStore: results[idx] = (s, embedding) except Exception as e: chunk = future_to_chunk[future] - logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}") - # 为失败的块添加空结果 + logger.error(f"处理数据块时发生严重异常: {chunk}, 错误: {e}") start_idx, chunk_strs = chunk - for i, s in enumerate(chunk_strs): - results[start_idx + i] = (s, []) + for i, s_item in enumerate(chunk_strs): + if (start_idx + i) not in results: + results[start_idx + i] = (s_item, []) # 按原始顺序返回结果 - ordered_results = [] - for i in range(len(strs)): - if i in results: - ordered_results.append(results[i]) - else: - # 防止遗漏 - ordered_results.append((strs[i], [])) - - return ordered_results + return [results.get(i, (strs[i], [])) for i in range(len(strs))] @staticmethod def get_test_file_path(): @@ -274,9 +202,17 @@ class EmbeddingStore: """保存测试字符串的嵌入到本地(使用多线程优化)""" logger.info("开始保存测试字符串的嵌入向量...") + # 获取当前正在运行的事件循环 + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。") + return + # 使用多线程批量获取测试字符串的嵌入 embedding_results = self._get_embeddings_batch_threaded( EMBEDDING_TEST_STRINGS, + main_loop, chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)), ) @@ -288,8 +224,6 @@ class EmbeddingStore: test_vectors[str(idx)] = embedding else: logger.error(f"获取测试字符串嵌入失败: {s}") - # 使用原始单线程方法作为后备 - test_vectors[str(idx)] = self._get_embedding(s) with open(self.get_test_file_path(), "w", encoding="utf-8") as f: f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8")) @@ -321,9 +255,17 @@ class EmbeddingStore: logger.info("开始检验嵌入模型一致性...") + # 获取当前正在运行的事件循环 + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。") + return False + # 使用多线程批量获取当前模型的嵌入 embedding_results = self._get_embeddings_batch_threaded( EMBEDDING_TEST_STRINGS, + main_loop, chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)), ) @@ -383,11 +325,20 @@ class EmbeddingStore: progress.update(task, advance=already_processed) if new_strs: + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。") + # 更新进度条以反映未处理的项目 + progress.update(task, advance=len(new_strs)) + return + # 使用实例配置的参数,智能调整分块和线程数 optimal_chunk_size = max( MIN_CHUNK_SIZE, min( - self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size + self.chunk_size, + len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size, ), ) optimal_max_workers = min( @@ -404,12 +355,13 @@ class EmbeddingStore: # 批量获取嵌入,并实时更新进度 embedding_results = self._get_embeddings_batch_threaded( new_strs, + main_loop, chunk_size=optimal_chunk_size, max_workers=optimal_max_workers, progress_callback=update_progress, ) - # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了) + # 存入结果 for s, embedding in embedding_results: item_hash = self.namespace + "-" + get_sha256(s) if embedding: # 只有成功获取到嵌入才存入