SengokuCola
2025-05-02 19:19:33 +08:00
12 changed files with 324 additions and 122 deletions


@@ -4,11 +4,13 @@ import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import glob
import datetime
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# Add the project root directory to sys.path
import tqdm
from rich.progress import Progress  # switched to a rich progress bar
from src.common.logger import get_module_logger
from src.plugins.knowledge.src.lpmmconfig import global_config
@@ -16,10 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
from src.plugins.knowledge.src.llm_client import LLMClient
from src.plugins.knowledge.src.open_ie import OpenIE
from src.plugins.knowledge.src.raw_processing import load_raw_data
from rich.progress import (
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
logger = get_module_logger("LPMM Knowledge Base - Info Extraction")
TEMP_DIR = "./temp"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
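# Use the paths from the persistence config when set; otherwise fall back to defaults under the project root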
IMPORTED_DATA_PATH = (
global_config["persistence"]["raw_data_path"]
if global_config["persistence"]["raw_data_path"]
else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
)
OPENIE_OUTPUT_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
# Create a thread-safe lock to protect file operations and shared data
file_lock = Lock()
@@ -70,8 +93,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
# If saving failed, make sure no corrupted file is left behind
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
# Set shutdown_event to terminate the program
shutdown_event.set()
sys.exit(0)
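# On failure, return the hash so the caller can record it in failed_sha256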
return None, pg_hash
return doc_item, None
@@ -79,7 +101,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
def signal_handler(_signum, _frame):
"""处理Ctrl+C信号"""
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
shutdown_event.set()
sys.exit(0)
def main():
@@ -110,33 +132,61 @@ def main():
global_config["llm_providers"][key]["api_key"],
)
logger.info("正在加载原始数据")
sha256_list, raw_datas = load_raw_data()
logger.info("原始数据加载完成\n")
# 检查 openie 输出目录
if not os.path.exists(OPENIE_OUTPUT_DIR):
os.makedirs(OPENIE_OUTPUT_DIR)
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
# 创建临时目录
if not os.path.exists(f"{TEMP_DIR}"):
os.makedirs(f"{TEMP_DIR}")
# Ensure the TEMP_DIR directory exists
if not os.path.exists(TEMP_DIR):
os.makedirs(TEMP_DIR)
logger.info(f"已创建缓存目录: {TEMP_DIR}")
# 遍历IMPORTED_DATA_PATH下所有json文件
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
if not imported_files:
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
sys.exit(1)
all_sha256_list = []
all_raw_datas = []
for imported_file in imported_files:
logger.info(f"正在处理文件: {imported_file}")
try:
sha256_list, raw_datas = load_raw_data(imported_file)
except Exception as e:
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
continue
all_sha256_list.extend(sha256_list)
all_raw_datas.extend(raw_datas)
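# Aggregate results and failed hashes across all imported files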
failed_sha256 = []
open_ie_doc = []
# Create a thread pool; the max worker count comes from the config
workers = global_config["info_extraction"]["workers"]
with ThreadPoolExecutor(max_workers=workers) as executor:
# Submit all tasks to the thread pool
future_to_hash = {
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
for pg_hash, raw_data in zip(sha256_list, raw_datas)
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
}
# Show progress with tqdm
with tqdm.tqdm(total=len(future_to_hash), postfix="Extracting:") as pbar:
# Handle completed tasks
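# Progress bar layout: spinner, description, bar, percentage, M/N count, elapsed < remaining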
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("Extracting:", total=len(future_to_hash))
try:
for future in as_completed(future_to_hash):
if shutdown_event.is_set():
# Cancel all unfinished tasks
for f in future_to_hash:
if not f.done():
f.cancel()
@@ -149,26 +199,38 @@ def main():
elif doc_item:
with open_ie_doc_lock:
open_ie_doc.append(doc_item)
pbar.update(1)
progress.update(task, advance=1)
except KeyboardInterrupt:
# If KeyboardInterrupt is caught here, signal_handler probably did not fire as expected
logger.info("\nReceived interrupt signal, shutting down gracefully...")
shutdown_event.set()
# Cancel all unfinished tasks
for f in future_to_hash:
if not f.done():
f.cancel()
# Save the information-extraction results
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4),
round(sum_phrase_words / num_phrases, 4),
)
OpenIE.save(openie_obj)
# Merge the extraction results from all files and save them
if open_ie_doc:
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
)
# Output filename format: MM-DD-HH-SS-openie.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-openie.json")
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
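# Serialize via to_dict() when available, otherwise fall back to the instance __dict__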
with open(output_path, "w", encoding="utf-8") as f:
json.dump(
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
f,
ensure_ascii=False,
indent=4,
)
logger.info(f"信息提取结果已保存到: {output_path}")
else:
logger.warning("No information-extraction results to save")
logger.info("-------- Information extraction complete --------")
logger.info(f"SHA256 of text chunks that failed extraction: {failed_sha256}")