From 0060c8de198d2dae6edc6d069b723f4b08f71f55 Mon Sep 17 00:00:00 2001
From: minecraft1024a
Date: Sat, 27 Sep 2025 14:06:22 +0800
Subject: [PATCH] fix(tool): enhance error logging when information extraction
 fails
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

During information extraction, when the large language model (LLM) returns
malformed JSON, the previous log only recorded a generic parsing error and
did not show the raw response that caused the failure, which made debugging
difficult.

This update fixes that by additionally logging the LLM's raw output whenever
a JSON parsing exception is caught. This makes it easy to diagnose whether
the model output is unstable or the prompt needs adjustment, improving the
script's robustness and maintainability.

In addition, the code received some formatting adjustments to improve
readability.
---
 scripts/lpmm_learning_tool.py | 97 ++++++++++++++++++++++++-----------
 1 file changed, 67 insertions(+), 30 deletions(-)

diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py
index 941494bc0..3fe26eb93 100644
--- a/scripts/lpmm_learning_tool.py
+++ b/scripts/lpmm_learning_tool.py
@@ -38,11 +38,13 @@ file_lock = Lock()
 
 # --- 模块一:数据预处理 ---
 
+
 def process_text_file(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
         raw = f.read()
     return [p.strip() for p in raw.split("\n\n") if p.strip()]
 
+
 def preprocess_raw_data():
     logger.info("--- 步骤 1: 开始数据预处理 ---")
     os.makedirs(RAW_DATA_PATH, exist_ok=True)
@@ -50,7 +52,7 @@ def preprocess_raw_data():
     if not raw_files:
         logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件")
         return []
-    
+
     all_paragraphs = []
     for file in raw_files:
         logger.info(f"正在处理文件: {file.name}")
@@ -61,8 +63,10 @@
     logger.info("--- 数据预处理完成 ---")
     return unique_paragraphs
 
+
 # --- 模块二:信息提取 ---
 
+
 def get_extraction_prompt(paragraph: str) -> str:
     return f"""
 请从以下段落中提取关键信息。你需要提取两种类型的信息:
@@ -81,6 +85,7 @@ def get_extraction_prompt(paragraph: str) -> str:
 ---
 """
 
+
 async def extract_info_async(pg_hash, paragraph, llm_api):
     temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
     with file_lock:
@@ -92,11 +97,13 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
             os.remove(temp_file_path)
 
     prompt = get_extraction_prompt(paragraph)
+    content = None
     try:
         content, (_, _, _) = await llm_api.generate_response_async(prompt)
         extracted_data = orjson.loads(content)
         doc_item = {
-            "idx": pg_hash, "passage": paragraph,
+            "idx": pg_hash,
+            "passage": paragraph,
             "extracted_entities": extracted_data.get("entities", []),
             "extracted_triples": extracted_data.get("triples", []),
         }
@@ -106,27 +113,45 @@
         return doc_item, None
     except Exception as e:
         logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
+        if content:
+            logger.error(f"导致解析失败的原始输出: {content}")
         return None, pg_hash
 
+
 def extract_info_sync(pg_hash, paragraph, llm_api):
     return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
 
+
 def extract_information(paragraphs_dict, model_set):
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
-    
+
     llm_api = LLMRequest(model_set=model_set)
     failed_hashes, open_ie_docs = [], []
     with ThreadPoolExecutor(max_workers=5) as executor:
-        f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()}
-        with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress:
+        f_to_hash = {
+            executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
+        }
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+        ) as progress:
             task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
             for future in as_completed(f_to_hash):
                 doc_item, failed_hash = future.result()
-                if failed_hash: failed_hashes.append(failed_hash)
-                elif doc_item: open_ie_docs.append(doc_item)
+                if failed_hash:
+                    failed_hashes.append(failed_hash)
+                elif doc_item:
+                    open_ie_docs.append(doc_item)
                 progress.update(task, advance=1)
 
     if open_ie_docs:
@@ -135,19 +160,22 @@ def extract_information(paragraphs_dict, model_set):
         avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
         avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
         openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words)
-        
+
         now = datetime.datetime.now()
         filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json")
         output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
         with open(output_path, "wb") as f:
             f.write(orjson.dumps(openie_obj._to_dict()))
         logger.info(f"信息提取结果已保存到: {output_path}")
-    
-    if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
+
+    if failed_hashes:
+        logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
     logger.info("--- 信息提取完成 ---")
 
+
 # --- 模块三:数据导入 ---
 
+
 async def import_data(openie_obj: Optional[OpenIE] = None):
     """
     将OpenIE数据导入知识库(Embedding Store 和 KG)
@@ -159,15 +187,19 @@
 
     """
     logger.info("--- 步骤 3: 开始数据导入 ---")
     embed_manager, kg_manager = EmbeddingManager(), KGManager()
-    
+
     logger.info("正在加载现有的 Embedding 库...")
-    try: embed_manager.load_from_file()
-    except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。")
+    try:
+        embed_manager.load_from_file()
+    except Exception as e:
+        logger.warning(f"加载 Embedding 库失败: {e}。")
     logger.info("正在加载现有的 KG...")
-    try: kg_manager.load_from_file()
-    except Exception as e: logger.warning(f"加载 KG 失败: {e}。")
-    
+    try:
+        kg_manager.load_from_file()
+    except Exception as e:
+        logger.warning(f"加载 KG 失败: {e}。")
+
     try:
         if openie_obj:
             openie_data = openie_obj
@@ -180,7 +212,7 @@
 
         raw_paragraphs = openie_data.extract_raw_paragraph_dict()
         triple_list_data = openie_data.extract_triple_dict()
-        
+
         new_raw_paragraphs, new_triple_list_data = {}, {}
         stored_embeds = embed_manager.stored_pg_hashes
         stored_kgs = kg_manager.stored_paragraph_hashes
@@ -189,7 +221,7 @@
             if p_hash not in stored_embeds and p_hash not in stored_kgs:
                 new_raw_paragraphs[p_hash] = raw_p
                 new_triple_list_data[p_hash] = triple_list_data.get(p_hash, [])
-        
+
         if not new_raw_paragraphs:
             logger.info("没有新的段落需要处理。")
         else:
@@ -207,32 +239,35 @@
 
     logger.info("--- 数据导入完成 ---")
 
+
 def import_from_specific_file():
     """从用户指定的 openie.json 文件导入数据"""
     file_path = input("请输入 openie.json 文件的完整路径: ").strip()
-    
+
     if not os.path.exists(file_path):
         logger.error(f"文件路径不存在: {file_path}")
         return
-    
+
     if not file_path.endswith(".json"):
         logger.error("请输入一个有效的 .json 文件路径。")
         return
 
     try:
         logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
         openie_obj = OpenIE.load(filepath=file_path)
         asyncio.run(import_data(openie_obj=openie_obj))
     except Exception as e:
logger.error(f"从指定文件导入数据时发生错误: {e}") + # --- 主函数 --- + def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) - + print("=== LPMM 知识库学习工具 ===") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") @@ -243,24 +278,26 @@ def main(): print("-" * 30) choice = input("请输入你的选择 (0-5): ").strip() - if choice == '1': + if choice == "1": preprocess_raw_data() - elif choice == '2': + elif choice == "2": paragraphs = preprocess_raw_data() - if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) - elif choice == '3': + if paragraphs: + extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + elif choice == "3": asyncio.run(import_data()) - elif choice == '4': + elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) asyncio.run(import_data()) - elif choice == '5': + elif choice == "5": import_from_specific_file() - elif choice == '0': + elif choice == "0": sys.exit(0) else: print("无效输入,请重新运行脚本。") + if __name__ == "__main__": - main() \ No newline at end of file + main()