# Mofox-Core/scripts/info_extraction.py
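"""Information extraction script for the LPMM knowledge base.

Loads raw passages via raw_data_preprocessor, runs LLM-based entity and
RDF-triple extraction over a thread pool, caches per-passage results under
temp/, and merges everything into a timestamped OpenIE JSON file under
data/openie/.
"""
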
import orjson
import os
import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import datetime

# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from rich.progress import Progress  # switched to a rich progress bar
from src.common.logger import get_logger

# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE
from rich.progress import (
    BarColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
    MofNCompleteColumn,
    SpinnerColumn,
    TextColumn,
)
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

logger = get_logger("LPMM Knowledge Base - Info Extraction")

ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")

def ensure_dirs():
    """Ensure the temp, output, and raw-data directories exist."""
    if not os.path.exists(TEMP_DIR):
        os.makedirs(TEMP_DIR)
        logger.info(f"Created temp directory: {TEMP_DIR}")
    if not os.path.exists(OPENIE_OUTPUT_DIR):
        os.makedirs(OPENIE_OUTPUT_DIR)
        logger.info(f"Created output directory: {OPENIE_OUTPUT_DIR}")
    if not os.path.exists(RAW_DATA_PATH):
        os.makedirs(RAW_DATA_PATH)
        logger.info(f"Created raw-data directory: {RAW_DATA_PATH}")

# Thread-safe locks protecting file operations and shared data
file_lock = Lock()
open_ie_doc_lock = Lock()
# Event flag used to signal program shutdown
shutdown_event = Event()

lpmm_entity_extract_llm = LLMRequest(
    model_set=model_config.model_task_config.lpmm_entity_extract,
    request_type="lpmm.entity_extract",
)
lpmm_rdf_build_llm = LLMRequest(
    model_set=model_config.model_task_config.lpmm_rdf_build,
    request_type="lpmm.rdf_build",
)
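
# Note: info_extract_from_str (src/chat/knowledge/ie_process.py) runs two LLM
# passes per passage: lpmm_entity_extract_llm extracts the entity list, then
# lpmm_rdf_build_llm builds RDF triples from the passage. The two LLMRequest
# instances above configure one model task for each pass.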

def process_single_text(pg_hash, raw_data):
    """Process a single passage; runs inside the thread pool."""
    temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
    # Check for and load a cached result while holding the file lock
    with file_lock:
        if os.path.exists(temp_file_path):
            try:
                # A cached extraction result exists for this passage
                logger.info(f"Found cached extraction result: {pg_hash}")
                with open(temp_file_path, "r", encoding="utf-8") as f:
                    return orjson.loads(f.read()), None
            except orjson.JSONDecodeError:
                # The cached JSON is corrupted: delete it and re-process
                logger.warning(f"Cache file corrupted, re-processing: {pg_hash}")
                os.remove(temp_file_path)
    entity_list, rdf_triple_list = info_extract_from_str(
        lpmm_entity_extract_llm,
        lpmm_rdf_build_llm,
        raw_data,
    )
    if entity_list is None or rdf_triple_list is None:
        return None, pg_hash
    doc_item = {
        "idx": pg_hash,
        "passage": raw_data,
        "extracted_entities": entity_list,
        "extracted_triples": rdf_triple_list,
    }
    # Save the extraction result to the per-passage cache
    with file_lock:
        try:
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(doc_item, option=orjson.OPT_INDENT_2).decode("utf-8"))
        except Exception as e:
            logger.error(f"Failed to save cache file: {pg_hash}, error: {e}")
            # Don't leave a corrupted file behind after a failed save;
            # report the passage as failed instead of killing the worker thread
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
            return None, pg_hash
    return doc_item, None
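
# Illustrative layout of a cached doc_item file (triple shape assumed from the
# field names; actual values come from info_extract_from_str):
# {
#   "idx": "<passage sha256>",
#   "passage": "<raw text>",
#   "extracted_entities": ["<entity>", ...],
#   "extracted_triples": [["<subject>", "<predicate>", "<object>"], ...]
# }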

def signal_handler(_signum, _frame):
    """Handle Ctrl+C (SIGINT)."""
    logger.info("\nInterrupt signal received, shutting down gracefully...")
    sys.exit(0)

def main():  # sourcery skip: comprehension-to-generator, extract-method
    # Install the signal handler
    signal.signal(signal.SIGINT, signal_handler)
    ensure_dirs()  # Make sure the working directories exist
    # Ask for confirmation before spending API credit
    print("=== Important: please read the following carefully ===")
    print("Entity extraction consumes a significant amount of API credit and time; running it during idle hours is recommended.")
    print("For example, extracting a 6-million-character full storyline with DeepSeek V3 0324 costs about 40 CNY and takes roughly 3 hours.")
    print("A non-Pro model from SiliconFlow is recommended,")
    print("or a Pro model whose cost can be covered by free credit.")
    print("Make sure your account balance is sufficient and double-check before proceeding.")
    confirm = input("Continue? (y/n): ").strip().lower()
    if confirm != "y":
        logger.info("Operation cancelled by user")
        print("Operation cancelled")
        sys.exit(1)
    print("\n" + "=" * 40 + "\n")
    logger.info("-------- Running information extraction --------\n")
    # Load the raw data
    logger.info("Loading raw data")
    all_sha256_list, all_raw_datas = load_raw_data()
    failed_sha256 = []
    open_ie_doc = []
    workers = global_config.lpmm_knowledge.info_extraction_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_hash = {
            executor.submit(process_single_text, pg_hash, raw_data): pg_hash
            for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
        }
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            "•",
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
            transient=False,
        ) as progress:
            task = progress.add_task("Extracting:", total=len(future_to_hash))
            try:
                for future in as_completed(future_to_hash):
                    if shutdown_event.is_set():
                        # Cancel any futures that have not started yet
                        for f in future_to_hash:
                            if not f.done():
                                f.cancel()
                        break
                    doc_item, failed_hash = future.result()
                    if failed_hash:
                        failed_sha256.append(failed_hash)
                        logger.error(f"Extraction failed: {failed_hash}")
                    elif doc_item:
                        with open_ie_doc_lock:
                            open_ie_doc.append(doc_item)
                    progress.update(task, advance=1)
            except KeyboardInterrupt:
                logger.info("\nInterrupt signal received, shutting down gracefully...")
                shutdown_event.set()
                for f in future_to_hash:
                    if not f.done():
                        f.cancel()
    # Merge the extraction results from all passages and save them
    if open_ie_doc:
        sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
        openie_obj = OpenIE(
            open_ie_doc,
            round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
            round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
        )
        # Output filename format: MM-DD-HH-SS-openie.json
        now = datetime.datetime.now()
        filename = now.strftime("%m-%d-%H-%S-openie.json")
        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(
                orjson.dumps(
                    openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
                    option=orjson.OPT_INDENT_2,
                ).decode("utf-8")
            )
        logger.info(f"Information extraction results saved to: {output_path}")
    else:
        logger.warning("No information extraction results to save")
    logger.info("-------- Information extraction finished --------")
    logger.info(f"SHA256 hashes of failed passages: {failed_sha256}")


if __name__ == "__main__":
    main()