From 10b014820409a648a0399e6ba4a5a2ed9f7471b0 Mon Sep 17 00:00:00 2001
From: tt-P607 <68868379+tt-P607@users.noreply.github.com>
Date: Sat, 18 Oct 2025 17:09:18 +0800
Subject: [PATCH] Revert "总之就是知识库"
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 0383a999fb4f9187d26b48e04bbb27e0d2592218.
---
 scripts/lpmm_learning_tool.py         |  91 ++++++-------
 src/chat/knowledge/embedding_store.py | 176 ++++++++++++++++----------
 2 files changed, 153 insertions(+), 114 deletions(-)

diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py
index 77f1a8fd3..884b4e9e3 100644
--- a/scripts/lpmm_learning_tool.py
+++ b/scripts/lpmm_learning_tool.py
@@ -191,7 +191,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         return None, pg_hash
 
 
-async def extract_information(paragraphs_dict, model_set):
+def extract_info_sync(pg_hash, paragraph, llm_api):
+    return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
+
+
+def extract_information(paragraphs_dict, model_set):
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
@@ -199,35 +203,32 @@ async def extract_information(paragraphs_dict, model_set):
     llm_api = LLMRequest(model_set=model_set)
     failed_hashes, open_ie_docs = [], []
 
-    tasks = [
-        extract_info_async(p_hash, p, llm_api)
-        for p_hash, p in paragraphs_dict.items()
-    ]
-
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        BarColumn(),
-        TaskProgressColumn(),
-        MofNCompleteColumn(),
-        "•",
-        TimeElapsedColumn(),
-        "<",
-        TimeRemainingColumn(),
-    ) as progress:
-        prog_task = progress.add_task("[cyan]正在提取信息...", total=len(tasks))
-        for future in asyncio.as_completed(tasks):
-            doc_item, failed_hash = await future
-            if failed_hash:
-                failed_hashes.append(failed_hash)
-            elif doc_item:
-                open_ie_docs.append(doc_item)
-            progress.update(prog_task, advance=1)
+    with ThreadPoolExecutor(max_workers=5) as executor:
+        f_to_hash = {
+            executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
+        }
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+        ) as progress:
+            task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
+            for future in as_completed(f_to_hash):
+                doc_item, failed_hash = future.result()
+                if failed_hash:
+                    failed_hashes.append(failed_hash)
+                elif doc_item:
+                    open_ie_docs.append(doc_item)
+                progress.update(task, advance=1)
 
     if open_ie_docs:
-        all_entities = [
-            e for doc in open_ie_docs for e in doc["extracted_entities"]
-        ]
+        all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
         num_entities = len(all_entities)
         avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
         avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
@@ -312,7 +313,7 @@ async def import_data(openie_obj: OpenIE | None = None):
     logger.info("--- 数据导入完成 ---")
 
 
-async def import_from_specific_file():
+def import_from_specific_file():
    """从用户指定的 openie.json 文件导入数据"""
     file_path = input("请输入 openie.json 文件的完整路径: ").strip()
 
@@ -327,7 +328,7 @@ async def import_from_specific_file():
     try:
         logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
         openie_obj = OpenIE.load()
-        await import_data(openie_obj=openie_obj)
+        asyncio.run(import_data(openie_obj=openie_obj))
     except Exception as e:
         logger.error(f"从指定文件导入数据时发生错误: {e}")
 
@@ -335,20 +336,14 @@ async def import_from_specific_file():
 
 
 # --- 主函数 ---
-async def async_main():
+def main():
     # 使用 os.path.relpath 创建相对于项目根目录的友好路径
-    raw_data_relpath = os.path.relpath(
-        RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")
-    )
-    openie_output_relpath = os.path.relpath(
-        OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")
-    )
+    raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
+    openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
 
     print("=== LPMM 知识库学习工具 ===")
     print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
-    print(
-        f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)"
-    )
+    print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
     print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
     print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
     print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
@@ -362,20 +357,16 @@ async def async_main():
     elif choice == "2":
         paragraphs = preprocess_raw_data()
         if paragraphs:
-            await extract_information(
-                paragraphs, model_config.model_task_config.lpmm_qa
-            )
+            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
     elif choice == "3":
-        await import_data()
+        asyncio.run(import_data())
     elif choice == "4":
         paragraphs = preprocess_raw_data()
         if paragraphs:
-            await extract_information(
-                paragraphs, model_config.model_task_config.lpmm_qa
-            )
-            await import_data()
+            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
+            asyncio.run(import_data())
     elif choice == "5":
-        await import_from_specific_file()
+        import_from_specific_file()
     elif choice == "6":
         clear_cache()
     elif choice == "0":
@@ -385,4 +376,4 @@ async def async_main():
 
 
 if __name__ == "__main__":
-    asyncio.run(async_main())
+    main()
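The reverted extract_information drives its coroutines from a plain thread pool: each submitted task calls extract_info_sync, which wraps the async extract_info_async in asyncio.run(), so every worker thread gets a short-lived private event loop instead of sharing one. A minimal self-contained sketch of that pattern follows; fetch_async, fetch_sync and fetch_all are illustrative names, not code from this repository:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor, as_completed

    async def fetch_async(key: str) -> str:
        # Stand-in for an async LLM call such as extract_info_async().
        await asyncio.sleep(0.01)
        return key.upper()

    def fetch_sync(key: str) -> str:
        # asyncio.run() creates, runs and closes a fresh event loop,
        # so it is safe to call from a worker thread that has no loop.
        return asyncio.run(fetch_async(key))

    def fetch_all(keys: list[str]) -> dict[str, str]:
        results = {}
        with ThreadPoolExecutor(max_workers=5) as executor:
            future_to_key = {executor.submit(fetch_sync, k): k for k in keys}
            for future in as_completed(future_to_key):
                results[future_to_key[future]] = future.result()
        return results

The async version being reverted keeps every task on one loop via asyncio.as_completed; the thread-pool version trades that single-loop efficiency for isolation, paying one loop setup and teardown per call.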
数据...") openie_obj = OpenIE.load() - await import_data(openie_obj=openie_obj) + asyncio.run(import_data(openie_obj=openie_obj)) except Exception as e: logger.error(f"从指定文件导入数据时发生错误: {e}") @@ -335,20 +336,14 @@ async def import_from_specific_file(): # --- 主函数 --- -async def async_main(): +def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 - raw_data_relpath = os.path.relpath( - RAW_DATA_PATH, os.path.join(ROOT_PATH, "..") - ) - openie_output_relpath = os.path.relpath( - OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..") - ) + raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) + openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) print("=== LPMM 知识库学习工具 ===") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") - print( - f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)" - ) + print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") @@ -362,20 +357,16 @@ async def async_main(): elif choice == "2": paragraphs = preprocess_raw_data() if paragraphs: - await extract_information( - paragraphs, model_config.model_task_config.lpmm_qa - ) + extract_information(paragraphs, model_config.model_task_config.lpmm_qa) elif choice == "3": - await import_data() + asyncio.run(import_data()) elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: - await extract_information( - paragraphs, model_config.model_task_config.lpmm_qa - ) - await import_data() + extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + asyncio.run(import_data()) elif choice == "5": - await import_from_specific_file() + import_from_specific_file() elif choice == "6": clear_cache() elif choice == "0": @@ -385,4 +376,4 @@ async def async_main(): if __name__ == "__main__": - asyncio.run(async_main()) + main() diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 8eab15834..2c1056bb1 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -124,60 +124,124 @@ class EmbeddingStore: self.faiss_index = None self.idx2hash = None + @staticmethod + def _get_embedding(s: str) -> list[float]: + """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" + # 创建新的事件循环并在完成后立即关闭 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # 创建新的LLMRequest实例 + from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest + + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + + # 使用新的事件循环运行异步方法 + embedding, _ = loop.run_until_complete(llm.get_embedding(s)) + + if embedding and len(embedding) > 0: + return embedding + else: + logger.error(f"获取嵌入失败: {s}") + return [] + + except Exception as e: + logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") + return [] + finally: + # 确保事件循环被正确关闭 + try: + loop.close() + except Exception: + ... 
@@ -185,14 +249,22 @@ class EmbeddingStore:
                     results[idx] = (s, embedding)
                 except Exception as e:
                     chunk = future_to_chunk[future]
-                    logger.error(f"处理数据块时发生严重异常: {chunk}, 错误: {e}")
+                    logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
+                    # 为失败的块添加空结果
                     start_idx, chunk_strs = chunk
-                    for i, s_item in enumerate(chunk_strs):
-                        if (start_idx + i) not in results:
-                            results[start_idx + i] = (s_item, [])
+                    for i, s in enumerate(chunk_strs):
+                        results[start_idx + i] = (s, [])
 
         # 按原始顺序返回结果
-        return [results.get(i, (strs[i], [])) for i in range(len(strs))]
+        ordered_results = []
+        for i in range(len(strs)):
+            if i in results:
+                ordered_results.append(results[i])
+            else:
+                # 防止遗漏
+                ordered_results.append((strs[i], []))
+
+        return ordered_results
 
     @staticmethod
     def get_test_file_path():
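The bookkeeping in this hunk is what keeps the output aligned with the input even though chunks complete out of order: every chunk carries its starting index, results land in a dict keyed by global position, and a final pass rebuilds a list in range(len(strs)) order with an empty-vector fallback for holes. The same idea in a self-contained sketch, where process() is a placeholder for the real per-item work:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def process(item: str) -> str:
        return item.upper()  # placeholder for the real per-item work

    def batched_ordered(items: list[str], chunk_size: int = 10) -> list[str]:
        # Pair each chunk with its starting offset so order can be restored.
        chunks = [(i, items[i : i + chunk_size]) for i in range(0, len(items), chunk_size)]
        results: dict[int, str] = {}

        def run_chunk(chunk: tuple[int, list[str]]) -> list[tuple[int, str]]:
            start, strs = chunk
            return [(start + i, process(s)) for i, s in enumerate(strs)]

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(run_chunk, c) for c in chunks]
            for future in as_completed(futures):
                for idx, value in future.result():
                    results[idx] = value

        # Rebuild in input order; missing entries get a hole-filler.
        return [results.get(i, "") for i in range(len(items))]

Both sides of the diff implement this; the reverted version merely spells the final loop out instead of using results.get().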
@@ -202,17 +274,9 @@ class EmbeddingStore:
         """保存测试字符串的嵌入到本地(使用多线程优化)"""
         logger.info("开始保存测试字符串的嵌入向量...")
 
-        # 获取当前正在运行的事件循环
-        try:
-            main_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
-            return
-
         # 使用多线程批量获取测试字符串的嵌入
         embedding_results = self._get_embeddings_batch_threaded(
             EMBEDDING_TEST_STRINGS,
-            main_loop,
             chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
             max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
         )
@@ -224,6 +288,8 @@ class EmbeddingStore:
                 test_vectors[str(idx)] = embedding
             else:
                 logger.error(f"获取测试字符串嵌入失败: {s}")
+                # 使用原始单线程方法作为后备
+                test_vectors[str(idx)] = self._get_embedding(s)
 
         with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
             f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
@@ -255,17 +321,9 @@ class EmbeddingStore:
 
         logger.info("开始检验嵌入模型一致性...")
 
-        # 获取当前正在运行的事件循环
-        try:
-            main_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
-            return False
-
         # 使用多线程批量获取当前模型的嵌入
         embedding_results = self._get_embeddings_batch_threaded(
             EMBEDDING_TEST_STRINGS,
-            main_loop,
             chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
             max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
         )
@@ -325,20 +383,11 @@ class EmbeddingStore:
             progress.update(task, advance=already_processed)
 
         if new_strs:
-            try:
-                main_loop = asyncio.get_running_loop()
-            except RuntimeError:
-                logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
-                # 更新进度条以反映未处理的项目
-                progress.update(task, advance=len(new_strs))
-                return
-
             # 使用实例配置的参数,智能调整分块和线程数
             optimal_chunk_size = max(
                 MIN_CHUNK_SIZE,
                 min(
-                    self.chunk_size,
-                    len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size,
+                    self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
                 ),
             )
             optimal_max_workers = min(
@@ -355,13 +404,12 @@ class EmbeddingStore:
             # 批量获取嵌入,并实时更新进度
             embedding_results = self._get_embeddings_batch_threaded(
                 new_strs,
-                main_loop,
                 chunk_size=optimal_chunk_size,
                 max_workers=optimal_max_workers,
                 progress_callback=update_progress,
             )
 
-            # 存入结果
+            # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
             for s, embedding in embedding_results:
                 item_hash = self.namespace + "-" + get_sha256(s)
                 if embedding:  # 只有成功获取到嵌入才存入
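One detail in the final hunks is worth a worked example: the optimal_chunk_size expression clamps the per-thread chunk between a floor (MIN_CHUNK_SIZE) and the configured self.chunk_size, shrinking it when there are too few items to keep every worker busy. The sketch below reproduces that arithmetic; MIN_CHUNK_SIZE = 5 and the worker clamp are assumptions for illustration, since the patch is cut off before the repository's actual optimal_max_workers arguments:

    MIN_CHUNK_SIZE = 5  # assumed floor; the repository defines its own constant

    def plan_batches(total: int, chunk_size: int = 10, max_workers: int = 10) -> tuple[int, int]:
        # Shrink chunks so all workers get work, but never below the floor.
        optimal_chunk_size = max(
            MIN_CHUNK_SIZE,
            min(chunk_size, total // max_workers if max_workers > 0 else chunk_size),
        )
        # Illustrative clamp: no more workers than there are chunks.
        n_chunks = -(-total // optimal_chunk_size)  # ceiling division
        return optimal_chunk_size, min(max_workers, n_chunks)

    # Example: 37 items with the defaults gives 37 // 10 = 3, raised to the
    # floor of 5, and ceil(37 / 5) = 8 chunks, so plan_batches(37) == (5, 8).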