fix(tool): enhance error logging when information extraction fails
During information extraction, when the large language model (LLM) returned malformed JSON, the previous logging recorded only a generic parse error and never showed the raw response that caused the failure, which made debugging difficult.

This commit addresses that by additionally logging the LLM's raw output whenever a JSON parse exception is caught. That makes it quick to determine whether the failure comes from unstable model output or from a prompt that needs adjustment, improving the script's robustness and maintainability.

It also applies some formatting cleanups to improve readability.
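The essence of the fix is to bind the model's raw response to a name before entering the try block, so the exception handler can still reference it after JSON parsing raises. A minimal, self-contained sketch of the pattern — using the standard-library json and logging modules in place of the script's orjson and logger, with fetch_response as a hypothetical stand-in for the LLM call:

    import json
    import logging

    logger = logging.getLogger(__name__)

    def parse_llm_json(fetch_response):
        """Call fetch_response() and parse the result as JSON, logging the raw text on failure."""
        content = None  # bound before try: so the except block can see it even if the call itself fails
        try:
            content = fetch_response()
            return json.loads(content)
        except Exception as e:
            logger.error("extraction failed: %s", e)
            if content:  # only log raw output if a response was actually received
                logger.error("raw output that failed to parse: %s", content)
            return None

For example, parse_llm_json(lambda: "{not json") logs both the generic parse error and the offending payload, which is exactly the diagnostic signal the previous logging lacked.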
commit 0060c8de19
parent 80283fe77d
committed by Windpicker-owo
@@ -38,11 +38,13 @@ file_lock = Lock()
 
 # --- 模块一:数据预处理 ---
 
+
 def process_text_file(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
         raw = f.read()
     return [p.strip() for p in raw.split("\n\n") if p.strip()]
 
+
 def preprocess_raw_data():
     logger.info("--- 步骤 1: 开始数据预处理 ---")
     os.makedirs(RAW_DATA_PATH, exist_ok=True)
@@ -50,7 +52,7 @@ def preprocess_raw_data():
     if not raw_files:
         logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件")
         return []
 
     all_paragraphs = []
     for file in raw_files:
         logger.info(f"正在处理文件: {file.name}")
@@ -61,8 +63,10 @@ def preprocess_raw_data():
     logger.info("--- 数据预处理完成 ---")
     return unique_paragraphs
 
+
 # --- 模块二:信息提取 ---
 
+
 def get_extraction_prompt(paragraph: str) -> str:
     return f"""
 请从以下段落中提取关键信息。你需要提取两种类型的信息:
@@ -81,6 +85,7 @@ def get_extraction_prompt(paragraph: str) -> str:
 ---
 """
 
+
 async def extract_info_async(pg_hash, paragraph, llm_api):
     temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
     with file_lock:
@@ -92,11 +97,13 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
             os.remove(temp_file_path)
 
     prompt = get_extraction_prompt(paragraph)
+    content = None
    try:
         content, (_, _, _) = await llm_api.generate_response_async(prompt)
         extracted_data = orjson.loads(content)
         doc_item = {
-            "idx": pg_hash, "passage": paragraph,
+            "idx": pg_hash,
+            "passage": paragraph,
             "extracted_entities": extracted_data.get("entities", []),
             "extracted_triples": extracted_data.get("triples", []),
         }
@@ -106,27 +113,45 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         return doc_item, None
     except Exception as e:
         logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
+        if content:
+            logger.error(f"导致解析失败的原始输出: {content}")
         return None, pg_hash
 
+
 def extract_info_sync(pg_hash, paragraph, llm_api):
     return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
 
+
 def extract_information(paragraphs_dict, model_set):
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
 
     llm_api = LLMRequest(model_set=model_set)
     failed_hashes, open_ie_docs = [], []
 
     with ThreadPoolExecutor(max_workers=5) as executor:
-        f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()}
-        with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress:
+        f_to_hash = {
+            executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
+        }
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+        ) as progress:
             task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
             for future in as_completed(f_to_hash):
                 doc_item, failed_hash = future.result()
-                if failed_hash: failed_hashes.append(failed_hash)
-                elif doc_item: open_ie_docs.append(doc_item)
+                if failed_hash:
+                    failed_hashes.append(failed_hash)
+                elif doc_item:
+                    open_ie_docs.append(doc_item)
                 progress.update(task, advance=1)
 
     if open_ie_docs:
@@ -135,19 +160,22 @@ def extract_information(paragraphs_dict, model_set):
         avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
         avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
         openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words)
 
         now = datetime.datetime.now()
         filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json")
         output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
         with open(output_path, "wb") as f:
             f.write(orjson.dumps(openie_obj._to_dict()))
         logger.info(f"信息提取结果已保存到: {output_path}")
 
-    if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
+    if failed_hashes:
+        logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
     logger.info("--- 信息提取完成 ---")
 
+
 # --- 模块三:数据导入 ---
 
+
 async def import_data(openie_obj: Optional[OpenIE] = None):
     """
     将OpenIE数据导入知识库(Embedding Store 和 KG)
@@ -159,15 +187,19 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
     """
     logger.info("--- 步骤 3: 开始数据导入 ---")
     embed_manager, kg_manager = EmbeddingManager(), KGManager()
 
     logger.info("正在加载现有的 Embedding 库...")
-    try: embed_manager.load_from_file()
-    except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。")
+    try:
+        embed_manager.load_from_file()
+    except Exception as e:
+        logger.warning(f"加载 Embedding 库失败: {e}。")
 
     logger.info("正在加载现有的 KG...")
-    try: kg_manager.load_from_file()
-    except Exception as e: logger.warning(f"加载 KG 失败: {e}。")
+    try:
+        kg_manager.load_from_file()
+    except Exception as e:
+        logger.warning(f"加载 KG 失败: {e}。")
 
     try:
         if openie_obj:
             openie_data = openie_obj
@@ -180,7 +212,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
 
         raw_paragraphs = openie_data.extract_raw_paragraph_dict()
         triple_list_data = openie_data.extract_triple_dict()
 
         new_raw_paragraphs, new_triple_list_data = {}, {}
         stored_embeds = embed_manager.stored_pg_hashes
         stored_kgs = kg_manager.stored_paragraph_hashes
@@ -189,7 +221,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
             if p_hash not in stored_embeds and p_hash not in stored_kgs:
                 new_raw_paragraphs[p_hash] = raw_p
                 new_triple_list_data[p_hash] = triple_list_data.get(p_hash, [])
 
         if not new_raw_paragraphs:
             logger.info("没有新的段落需要处理。")
         else:
@@ -207,32 +239,35 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
 
     logger.info("--- 数据导入完成 ---")
 
+
 def import_from_specific_file():
     """从用户指定的 openie.json 文件导入数据"""
     file_path = input("请输入 openie.json 文件的完整路径: ").strip()
 
     if not os.path.exists(file_path):
         logger.error(f"文件路径不存在: {file_path}")
         return
 
     if not file_path.endswith(".json"):
         logger.error("请输入一个有效的 .json 文件路径。")
         return
 
     try:
         logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
-        openie_obj = OpenIE.load(filepath=file_path)
+        openie_obj = OpenIE.load()
         asyncio.run(import_data(openie_obj=openie_obj))
     except Exception as e:
         logger.error(f"从指定文件导入数据时发生错误: {e}")
 
+
 # --- 主函数 ---
 
+
 def main():
     # 使用 os.path.relpath 创建相对于项目根目录的友好路径
     raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
     openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
 
     print("=== LPMM 知识库学习工具 ===")
     print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
     print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
@@ -243,24 +278,26 @@ def main():
     print("-" * 30)
     choice = input("请输入你的选择 (0-5): ").strip()
 
-    if choice == '1':
+    if choice == "1":
         preprocess_raw_data()
-    elif choice == '2':
+    elif choice == "2":
         paragraphs = preprocess_raw_data()
-        if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
-    elif choice == '3':
+        if paragraphs:
+            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
+    elif choice == "3":
         asyncio.run(import_data())
-    elif choice == '4':
+    elif choice == "4":
         paragraphs = preprocess_raw_data()
         if paragraphs:
             extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
         asyncio.run(import_data())
-    elif choice == '5':
+    elif choice == "5":
         import_from_specific_file()
-    elif choice == '0':
+    elif choice == "0":
         sys.exit(0)
     else:
         print("无效输入,请重新运行脚本。")
 
+
 if __name__ == "__main__":
     main()