fix(tool): 增强信息提取失败时的错误日志

在信息提取过程中,当大语言模型(LLM)返回的 JSON 格式不正确时,先前的日志只会记录一个通用的解析错误,而不会显示导致失败的原始响应内容,这使得调试变得困难。

此次更新通过在捕获到 JSON 解析异常时,额外记录 LLM 的原始输出内容来解决此问题。这有助于快速诊断并定位是模型输出不稳定还是提示词需要调整,从而提高了脚本的健壮性和可维护性。

此外,还对代码进行了一些格式化调整以提高可读性。
This commit is contained in:
minecraft1024a
2025-09-27 14:06:22 +08:00
parent f968d134c7
commit 93b0a6a862

View File

@@ -38,11 +38,13 @@ file_lock = Lock()
# --- 模块一:数据预处理 ---
def process_text_file(file_path):
with open(file_path, "r", encoding="utf-8") as f:
raw = f.read()
return [p.strip() for p in raw.split("\n\n") if p.strip()]
def preprocess_raw_data():
logger.info("--- 步骤 1: 开始数据预处理 ---")
os.makedirs(RAW_DATA_PATH, exist_ok=True)
@@ -61,8 +63,10 @@ def preprocess_raw_data():
logger.info("--- 数据预处理完成 ---")
return unique_paragraphs
# --- 模块二:信息提取 ---
def get_extraction_prompt(paragraph: str) -> str:
return f"""
请从以下段落中提取关键信息。你需要提取两种类型的信息:
@@ -81,6 +85,7 @@ def get_extraction_prompt(paragraph: str) -> str:
---
"""
async def extract_info_async(pg_hash, paragraph, llm_api):
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
with file_lock:
@@ -92,11 +97,13 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
os.remove(temp_file_path)
prompt = get_extraction_prompt(paragraph)
content = None
try:
content, (_, _, _) = await llm_api.generate_response_async(prompt)
extracted_data = orjson.loads(content)
doc_item = {
"idx": pg_hash, "passage": paragraph,
"idx": pg_hash,
"passage": paragraph,
"extracted_entities": extracted_data.get("entities", []),
"extracted_triples": extracted_data.get("triples", []),
}
@@ -106,11 +113,15 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
return doc_item, None
except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
if content:
logger.error(f"导致解析失败的原始输出: {content}")
return None, pg_hash
def extract_info_sync(pg_hash, paragraph, llm_api):
return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
def extract_information(paragraphs_dict, model_set):
logger.info("--- 步骤 2: 开始信息提取 ---")
os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
@@ -120,13 +131,27 @@ def extract_information(paragraphs_dict, model_set):
failed_hashes, open_ie_docs = [], []
with ThreadPoolExecutor(max_workers=5) as executor:
f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()}
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress:
f_to_hash = {
executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
}
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
) as progress:
task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
for future in as_completed(f_to_hash):
doc_item, failed_hash = future.result()
if failed_hash: failed_hashes.append(failed_hash)
elif doc_item: open_ie_docs.append(doc_item)
if failed_hash:
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)
progress.update(task, advance=1)
if open_ie_docs:
@@ -143,11 +168,14 @@ def extract_information(paragraphs_dict, model_set):
f.write(orjson.dumps(openie_obj._to_dict()))
logger.info(f"信息提取结果已保存到: {output_path}")
if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
if failed_hashes:
logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
logger.info("--- 信息提取完成 ---")
# --- 模块三:数据导入 ---
async def import_data(openie_obj: Optional[OpenIE] = None):
"""
将OpenIE数据导入知识库Embedding Store 和 KG
@@ -161,12 +189,16 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
embed_manager, kg_manager = EmbeddingManager(), KGManager()
logger.info("正在加载现有的 Embedding 库...")
try: embed_manager.load_from_file()
except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}")
try:
embed_manager.load_from_file()
except Exception as e:
logger.warning(f"加载 Embedding 库失败: {e}")
logger.info("正在加载现有的 KG...")
try: kg_manager.load_from_file()
except Exception as e: logger.warning(f"加载 KG 失败: {e}")
try:
kg_manager.load_from_file()
except Exception as e:
logger.warning(f"加载 KG 失败: {e}")
try:
if openie_obj:
@@ -207,6 +239,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
logger.info("--- 数据导入完成 ---")
def import_from_specific_file():
"""从用户指定的 openie.json 文件导入数据"""
file_path = input("请输入 openie.json 文件的完整路径: ").strip()
@@ -221,13 +254,15 @@ def import_from_specific_file():
try:
logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
openie_obj = OpenIE.load(filepath=file_path)
openie_obj = OpenIE.load()
asyncio.run(import_data(openie_obj=openie_obj))
except Exception as e:
logger.error(f"从指定文件导入数据时发生错误: {e}")
# --- 主函数 ---
def main():
# 使用 os.path.relpath 创建相对于项目根目录的友好路径
raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
@@ -243,24 +278,26 @@ def main():
print("-" * 30)
choice = input("请输入你的选择 (0-5): ").strip()
if choice == '1':
if choice == "1":
preprocess_raw_data()
elif choice == '2':
elif choice == "2":
paragraphs = preprocess_raw_data()
if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
elif choice == '3':
if paragraphs:
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
elif choice == "3":
asyncio.run(import_data())
elif choice == '4':
elif choice == "4":
paragraphs = preprocess_raw_data()
if paragraphs:
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
asyncio.run(import_data())
elif choice == '5':
elif choice == "5":
import_from_specific_file()
elif choice == '0':
elif choice == "0":
sys.exit(0)
else:
print("无效输入,请重新运行脚本。")
if __name__ == "__main__":
main()