From eb16508fb56367473c638768e0f9acb4f31e0aa2 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sun, 9 Nov 2025 21:38:31 +0800 Subject: [PATCH] =?UTF-8?q?feat(extraction):=20=E4=BC=98=E5=8C=96=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E6=8F=90=E5=8F=96=E6=B5=81=E7=A8=8B=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=BC=82=E6=AD=A5=E5=B9=B6=E5=8F=91=E5=92=8C=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/lpmm_learning_tool.py | 130 ++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 47 deletions(-) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 35272c6b1..77cda009a 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -3,9 +3,7 @@ import datetime import os import shutil import sys -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from threading import Lock import aiofiles import orjson @@ -38,7 +36,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") -file_lock = Lock() # --- 缓存清理 --- @@ -155,26 +152,41 @@ def get_extraction_prompt(paragraph: str) -> str: async def extract_info_async(pg_hash, paragraph, llm_api): + """ + 异步提取单个段落的信息(带缓存支持) + + Args: + pg_hash: 段落哈希值 + paragraph: 段落文本 + llm_api: LLM请求实例 + + Returns: + tuple: (doc_item或None, failed_hash或None) + """ temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") - with file_lock: - if os.path.exists(temp_file_path): + + # 🔧 优化:使用异步文件检查,避免阻塞 + if os.path.exists(temp_file_path): + try: + async with aiofiles.open(temp_file_path, "rb") as f: + content = await f.read() + return orjson.loads(content), None + except orjson.JSONDecodeError: + # 缓存文件损坏,删除并重新生成 try: - async with 
aiofiles.open(temp_file_path, "rb") as f: - content = await f.read() - return orjson.loads(content), None - except orjson.JSONDecodeError: os.remove(temp_file_path) + except OSError: + pass prompt = get_extraction_prompt(paragraph) content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) - # 改进点:调用封装好的函数处理JSON解析和修复 + # 调用封装好的函数处理JSON解析和修复 extracted_data = _parse_and_repair_json(content) if extracted_data is None: - # 如果解析失败,抛出异常以触发统一的错误处理逻辑 raise ValueError("无法从LLM输出中解析有效的JSON数据") doc_item = { @@ -183,9 +195,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api): "extracted_entities": extracted_data.get("entities", []), "extracted_triples": extracted_data.get("triples", []), } - with file_lock: - async with aiofiles.open(temp_file_path, "wb") as f: - await f.write(orjson.dumps(doc_item)) + + # 保存到缓存(异步写入) + async with aiofiles.open(temp_file_path, "wb") as f: + await f.write(orjson.dumps(doc_item)) + return doc_item, None except Exception as e: logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") @@ -194,42 +208,61 @@ async def extract_info_async(pg_hash, paragraph, llm_api): return None, pg_hash -def extract_info_sync(pg_hash, paragraph, model_set): - llm_api = LLMRequest(model_set=model_set) - return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) - - -def extract_information(paragraphs_dict, model_set): +async def extract_information(paragraphs_dict, model_set): + """ + 🔧 优化:使用真正的异步并发代替多线程 + + 这样可以: + 1. 避免 event loop closed 错误 + 2. 更高效地利用 I/O 资源 + 3. 
与我们优化的 LLM 请求层无缝集成
+
+    Args:
+        paragraphs_dict: {hash: paragraph} 字典
+        model_set: 模型配置
+    """
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
     failed_hashes, open_ie_docs = [], []
+
+    # 🔧 关键修复:创建单个 LLM 请求实例,复用连接
+    llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
 
-    with ThreadPoolExecutor(max_workers=3) as executor:
-        f_to_hash = {
-            executor.submit(extract_info_sync, p_hash, p, model_set): p_hash
-            for p_hash, p in paragraphs_dict.items()
-        }
-        with Progress(
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            MofNCompleteColumn(),
-            "•",
-            TimeElapsedColumn(),
-            "<",
-            TimeRemainingColumn(),
-        ) as progress:
-            task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
-            for future in as_completed(f_to_hash):
-                doc_item, failed_hash = future.result()
-                if failed_hash:
-                    failed_hashes.append(failed_hash)
-                elif doc_item:
-                    open_ie_docs.append(doc_item)
-                progress.update(task, advance=1)
+    # 创建所有异步任务
+    tasks = [
+        extract_info_async(p_hash, paragraph, llm_api)
+        for p_hash, paragraph in paragraphs_dict.items()
+    ]
+
+    total = len(tasks)
+    completed = 0
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TaskProgressColumn(),
+        MofNCompleteColumn(),
+        "•",
+        TimeElapsedColumn(),
+        "<",
+        TimeRemainingColumn(),
+    ) as progress:
+        task = progress.add_task("[cyan]正在提取信息...", total=total)
+
+        # 🔧 优化:使用 asyncio.as_completed 并发执行任务,按完成顺序处理结果
+        # 单个任务失败不影响其他任务(失败以 failed_hash 形式返回,不抛出异常)
+        for coro in asyncio.as_completed(tasks):
+            doc_item, failed_hash = await coro
+            if failed_hash:
+                failed_hashes.append(failed_hash)
+            elif doc_item:
+                open_ie_docs.append(doc_item)
+
+            completed += 1
+            progress.update(task, advance=1)
 
     if open_ie_docs:
         all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
@@ -244,6 +277,7 @@ def 
extract_information(paragraphs_dict, model_set): with open(output_path, "wb") as f: f.write(orjson.dumps(openie_obj._to_dict())) logger.info(f"信息提取结果已保存到: {output_path}") + logger.info(f"成功提取 {len(open_ie_docs)} 个段落的信息") if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") @@ -354,20 +388,22 @@ def main(): print("6. [清理缓存] -> 删除所有已提取信息的缓存") print("0. [退出]") print("-" * 30) - choice = input("请输入你的选择 (0-5): ").strip() + choice = input("请输入你的选择 (0-6): ").strip() if choice == "1": preprocess_raw_data() elif choice == "2": paragraphs = preprocess_raw_data() if paragraphs: - extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + # 🔧 修复:使用 asyncio.run 调用异步函数 + asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa)) elif choice == "3": asyncio.run(import_data()) elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: - extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + # 🔧 修复:使用 asyncio.run 调用异步函数 + asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa)) asyncio.run(import_data()) elif choice == "5": import_from_specific_file()