feat(extraction): optimize the information-extraction pipeline with async concurrency and cache management

Author: Windpicker-owo
Date:   2025-11-09 21:38:31 +08:00
parent 8d172867dc
commit 62414e865c


@@ -3,9 +3,7 @@ import datetime
 import os
 import shutil
 import sys
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
-from threading import Lock

 import aiofiles
 import orjson
@@ -38,7 +36,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
-file_lock = Lock()

 # --- Cache cleanup ---
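With one cache file per paragraph hash and all I/O now running on a single asyncio event loop, the module-level threading.Lock becomes redundant. If cross-task exclusion were ever needed again, asyncio.Lock would be the non-blocking equivalent; a minimal sketch under that assumption (hypothetical names, not part of this commit):

    import asyncio
    import aiofiles

    # Hypothetical sketch: asyncio.Lock serializes writers cooperatively,
    # yielding the event loop while waiting, unlike threading.Lock.
    cache_lock = asyncio.Lock()

    async def write_shared_cache(path: str, payload: bytes) -> None:
        async with cache_lock:  # only one coroutine writes at a time
            async with aiofiles.open(path, "wb") as f:
                await f.write(payload)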
@@ -155,26 +152,41 @@ def get_extraction_prompt(paragraph: str) -> str:
 async def extract_info_async(pg_hash, paragraph, llm_api):
+    """
+    Asynchronously extract information from a single paragraph (with cache support).
+
+    Args:
+        pg_hash: hash of the paragraph
+        paragraph: paragraph text
+        llm_api: LLM request instance
+
+    Returns:
+        tuple: (doc_item or None, failed_hash or None)
+    """
     temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
-    with file_lock:
-        if os.path.exists(temp_file_path):
-            try:
-                async with aiofiles.open(temp_file_path, "rb") as f:
-                    content = await f.read()
-                return orjson.loads(content), None
-            except orjson.JSONDecodeError:
-                os.remove(temp_file_path)
+
+    # Optimization: read the cache file asynchronously instead of under a
+    # thread lock, so the event loop is never blocked
+    if os.path.exists(temp_file_path):
+        try:
+            async with aiofiles.open(temp_file_path, "rb") as f:
+                content = await f.read()
+            return orjson.loads(content), None
+        except orjson.JSONDecodeError:
+            # Cache file is corrupt: delete it and regenerate
+            try:
+                os.remove(temp_file_path)
+            except OSError:
+                pass

     prompt = get_extraction_prompt(paragraph)
     content = None
     try:
         content, (_, _, _) = await llm_api.generate_response_async(prompt)
-        # Improvement: call the shared helper to parse and repair the JSON
+        # Call the shared helper to parse and repair the JSON
         extracted_data = _parse_and_repair_json(content)
         if extracted_data is None:
+            # If parsing fails, raise to trigger the unified error handling below
             raise ValueError("无法从LLM输出中解析有效的JSON数据")
         doc_item = {
@@ -183,9 +195,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
"extracted_entities": extracted_data.get("entities", []), "extracted_entities": extracted_data.get("entities", []),
"extracted_triples": extracted_data.get("triples", []), "extracted_triples": extracted_data.get("triples", []),
} }
with file_lock:
async with aiofiles.open(temp_file_path, "wb") as f: # 保存到缓存(异步写入)
await f.write(orjson.dumps(doc_item)) async with aiofiles.open(temp_file_path, "wb") as f:
await f.write(orjson.dumps(doc_item))
return doc_item, None return doc_item, None
except Exception as e: except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
@@ -194,42 +208,61 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         return None, pg_hash

-def extract_info_sync(pg_hash, paragraph, model_set):
-    llm_api = LLMRequest(model_set=model_set)
-    return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
-
-def extract_information(paragraphs_dict, model_set):
+async def extract_information(paragraphs_dict, model_set):
+    """
+    Optimization: use true async concurrency instead of a thread pool.
+    This:
+    1. avoids "event loop closed" errors
+    2. uses I/O resources more efficiently
+    3. integrates seamlessly with the optimized LLM request layer
+
+    Args:
+        paragraphs_dict: {hash: paragraph} dict
+        model_set: model configuration
+    """
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
     failed_hashes, open_ie_docs = [], []

-    with ThreadPoolExecutor(max_workers=3) as executor:
-        f_to_hash = {
-            executor.submit(extract_info_sync, p_hash, p, model_set): p_hash
-            for p_hash, p in paragraphs_dict.items()
-        }
-        with Progress(
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            MofNCompleteColumn(),
-            "",
-            TimeElapsedColumn(),
-            "<",
-            TimeRemainingColumn(),
-        ) as progress:
-            task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
-            for future in as_completed(f_to_hash):
-                doc_item, failed_hash = future.result()
-                if failed_hash:
-                    failed_hashes.append(failed_hash)
-                elif doc_item:
-                    open_ie_docs.append(doc_item)
-                progress.update(task, advance=1)
+    # Key fix: create a single LLM request instance so the underlying
+    # connection is reused across all requests
+    llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
+
+    # Create all async tasks up front
+    tasks = [
+        extract_info_async(p_hash, paragraph, llm_api)
+        for p_hash, paragraph in paragraphs_dict.items()
+    ]
+
+    total = len(tasks)
+    completed = 0
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TaskProgressColumn(),
+        MofNCompleteColumn(),
+        "",
+        TimeElapsedColumn(),
+        "<",
+        TimeRemainingColumn(),
+    ) as progress:
+        task = progress.add_task("[cyan]正在提取信息...", total=total)
+
+        # Run everything concurrently via asyncio.as_completed; each task
+        # reports failure through its (None, pg_hash) return value, so one
+        # failure does not affect the other tasks
+        for coro in asyncio.as_completed(tasks):
+            doc_item, failed_hash = await coro
+            if failed_hash:
+                failed_hashes.append(failed_hash)
+            elif doc_item:
+                open_ie_docs.append(doc_item)
+            completed += 1
+            progress.update(task, advance=1)

     if open_ie_docs:
         all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
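The new loop pairs asyncio.as_completed with a rich progress bar: results are consumed in completion order, so the bar advances as soon as any task finishes rather than in submission order. A self-contained sketch of the same pattern (fake_extract stands in for extract_info_async; all names here are illustrative):

    import asyncio
    import random
    from rich.progress import Progress, SpinnerColumn, BarColumn, MofNCompleteColumn

    async def fake_extract(key: str):
        """Stand-in worker returning (doc_item, failed_hash) like the real one."""
        await asyncio.sleep(random.random())
        return ({"hash": key}, None) if random.random() > 0.1 else (None, key)

    async def run_all(keys):
        tasks = [fake_extract(k) for k in keys]
        done, failed = [], []
        with Progress(SpinnerColumn(), BarColumn(), MofNCompleteColumn()) as progress:
            bar = progress.add_task("extracting", total=len(tasks))
            # as_completed yields awaitables in finish order, so slow tasks
            # do not hold up progress reporting for fast ones
            for coro in asyncio.as_completed(tasks):
                item, failed_key = await coro
                if failed_key:
                    failed.append(failed_key)
                else:
                    done.append(item)
                progress.update(bar, advance=1)
        return done, failed

    # asyncio.run(run_all([f"p{i}" for i in range(20)]))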
@@ -244,6 +277,7 @@ def extract_information(paragraphs_dict, model_set):
     with open(output_path, "wb") as f:
         f.write(orjson.dumps(openie_obj._to_dict()))
     logger.info(f"信息提取结果已保存到: {output_path}")
+    logger.info(f"成功提取 {len(open_ie_docs)} 个段落的信息")

     if failed_hashes:
         logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
@@ -354,20 +388,22 @@ def main():
print("6. [清理缓存] -> 删除所有已提取信息的缓存") print("6. [清理缓存] -> 删除所有已提取信息的缓存")
print("0. [退出]") print("0. [退出]")
print("-" * 30) print("-" * 30)
choice = input("请输入你的选择 (0-5): ").strip() choice = input("请输入你的选择 (0-6): ").strip()
if choice == "1": if choice == "1":
preprocess_raw_data() preprocess_raw_data()
elif choice == "2": elif choice == "2":
paragraphs = preprocess_raw_data() paragraphs = preprocess_raw_data()
if paragraphs: if paragraphs:
extract_information(paragraphs, model_config.model_task_config.lpmm_qa) # 🔧 修复:使用 asyncio.run 调用异步函数
asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa))
elif choice == "3": elif choice == "3":
asyncio.run(import_data()) asyncio.run(import_data())
elif choice == "4": elif choice == "4":
paragraphs = preprocess_raw_data() paragraphs = preprocess_raw_data()
if paragraphs: if paragraphs:
extract_information(paragraphs, model_config.model_task_config.lpmm_qa) # 🔧 修复:使用 asyncio.run 调用异步函数
asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa))
asyncio.run(import_data()) asyncio.run(import_data())
elif choice == "5": elif choice == "5":
import_from_specific_file() import_from_specific_file()
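Each menu branch now bridges the synchronous CLI into async code with asyncio.run, which creates a fresh event loop, drives the coroutine to completion, and closes the loop; this is why the "event loop closed" failures the commit mentions go away. A minimal sketch of the pattern (illustrative names, not this repository's API):

    import asyncio

    async def extract_all(paragraphs: dict) -> int:
        await asyncio.sleep(0)  # placeholder for the real async extraction
        return len(paragraphs)

    def main() -> None:
        # Calling extract_all(...) bare would only create a coroutine object;
        # asyncio.run actually executes it on a fresh event loop and then
        # closes that loop cleanly.
        count = asyncio.run(extract_all({"h1": "paragraph one"}))
        print(f"extracted {count} paragraphs")

    if __name__ == "__main__":
        main()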