feat: refactor the information-extraction module, drop the LLMClient dependency in favor of LLMRequest, and streamline data loading and processing

墨梓柒
2025-07-15 16:54:25 +08:00
parent 2b76dc2e21
commit f15e074cca
8 changed files with 119 additions and 177 deletions
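In short: the script previously built one LLMClient per configured provider and picked one by name at call time; it now creates two task-scoped LLMRequest objects once at module level. A minimal before/after sketch distilled from the diff below (only constructor arguments actually visible in the diff are shown):

# Before: one client per provider, selected by name on every call
llm_client_list = {
    key: LLMClient(
        global_config["llm_providers"][key]["base_url"],
        global_config["llm_providers"][key]["api_key"],
    )
    for key in global_config["llm_providers"]
}

# After: two module-level requesters, one per extraction task
lpmm_entity_extract_llm = LLMRequest(
    model=global_config.model.lpmm_entity_extract,
    request_type="lpmm.entity_extract",
)
lpmm_rdf_build_llm = LLMRequest(
    model=global_config.model.lpmm_rdf_build,
    request_type="lpmm.rdf_build",
)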


@@ -13,11 +13,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress  # switched to the rich progress bar
from src.common.logger import get_logger
from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import _entity_extract, info_extract_from_str
from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.raw_processing import load_raw_data
from rich.progress import (
    BarColumn,
    TimeElapsedColumn,
@@ -27,16 +26,17 @@ from rich.progress import (
    SpinnerColumn,
    TextColumn,
)
from raw_data_preprocessor import process_multi_files, load_raw_data
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("LPMM知识库-信息提取")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
    ROOT_PATH, "data", "imported_lpmm_data"
)
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
# Create a thread-safe lock to guard file operations and shared data
file_lock = Lock()
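For context, file_lock and shutdown_event follow the standard threading pattern: workers serialize disk writes through the lock and poll the event to bail out early. A self-contained sketch of that pattern (save_result and its payload are hypothetical, not from this repo):

import json
from threading import Event, Lock

file_lock = Lock()
shutdown_event = Event()

def save_result(path, payload):
    # Skip work once shutdown has been requested (hypothetical early exit)
    if shutdown_event.is_set():
        return
    # Only one thread may write the shared file at a time
    with file_lock:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)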
@@ -45,6 +45,14 @@ open_ie_doc_lock = Lock()
# Create an event flag used to signal program termination
shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest(
    model=global_config.model.lpmm_entity_extract,
    request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
    model=global_config.model.lpmm_rdf_build,
    request_type="lpmm.rdf_build"
)
def ensure_dirs():
"""确保临时目录和输出目录存在"""
@@ -54,12 +62,9 @@ def ensure_dirs():
    if not os.path.exists(OPENIE_OUTPUT_DIR):
        os.makedirs(OPENIE_OUTPUT_DIR)
        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
    if not os.path.exists(IMPORTED_DATA_PATH):
        os.makedirs(IMPORTED_DATA_PATH)
        logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
def process_single_text(pg_hash, raw_data, llm_client_list):
def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
@@ -77,8 +82,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
        os.remove(temp_file_path)
    entity_list, rdf_triple_list = info_extract_from_str(
        llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
        llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
        lpmm_entity_extract_llm,
        lpmm_rdf_build_llm,
        raw_data,
    )
    if entity_list is None or rdf_triple_list is None:
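The failure contract is unchanged: info_extract_from_str reports failure by returning None values instead of raising, so callers must check both results. A sketch of a caller honoring that contract (the retry loop is illustrative, not part of this commit):

def extract_with_retry(raw_data, attempts=3):
    # Hypothetical helper: retry extraction a few times before giving up
    for _ in range(attempts):
        entity_list, rdf_triple_list = info_extract_from_str(
            lpmm_entity_extract_llm,
            lpmm_rdf_build_llm,
            raw_data,
        )
        if entity_list is not None and rdf_triple_list is not None:
            return entity_list, rdf_triple_list
    return None, None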
@@ -130,50 +135,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
    ensure_dirs()  # make sure the directories exist
    logger.info("--------进行信息提取--------\n")
logger.info("创建LLM客户端")
llm_client_list = {
key: LLMClient(
global_config["llm_providers"][key]["base_url"],
global_config["llm_providers"][key]["api_key"],
)
for key in global_config["llm_providers"]
}
    # Check the openie output directory
    if not os.path.exists(OPENIE_OUTPUT_DIR):
        os.makedirs(OPENIE_OUTPUT_DIR)
        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
    # Make sure the TEMP_DIR directory exists
    if not os.path.exists(TEMP_DIR):
        os.makedirs(TEMP_DIR)
        logger.info(f"已创建缓存目录: {TEMP_DIR}")
    # Iterate over all JSON files under IMPORTED_DATA_PATH
    imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
    if not imported_files:
        logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
        sys.exit(1)
    all_sha256_list = []
    all_raw_datas = []
    for imported_file in imported_files:
        logger.info(f"正在处理文件: {imported_file}")
        try:
            sha256_list, raw_datas = load_raw_data(imported_file)
        except Exception as e:
            logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
            continue
        all_sha256_list.extend(sha256_list)
        all_raw_datas.extend(raw_datas)
    # Load the raw data
    logger.info("正在加载原始数据")
    all_sha256_list, all_raw_datas = load_raw_data()
    failed_sha256 = []
    open_ie_doc = []
    workers = global_config["info_extraction"]["workers"]
    workers = global_config.lpmm_knowledge.info_extraction_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_hash = {
            executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
            executor.submit(process_single_text, pg_hash, raw_data): pg_hash
            for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
        }
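The executor block keys each future by its source hash so that failures can be traced back to the offending paragraph. A runnable sketch of the same future_to_hash pattern (the worker is a stand-in for process_single_text):

from concurrent.futures import ThreadPoolExecutor, as_completed

def process_stub(pg_hash, raw_data):
    # Stand-in worker; the real one runs info_extract_from_str
    return pg_hash, len(raw_data)

hashes = ["a" * 64, "b" * 64]
texts = ["first paragraph", "second paragraph"]
failed_sha256 = []

with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_hash = {
        executor.submit(process_stub, h, t): h
        for h, t in zip(hashes, texts, strict=False)
    }
    for future in as_completed(future_to_hash):
        pg_hash = future_to_hash[future]
        try:
            future.result()
        except Exception:
            failed_sha256.append(pg_hash)  # attribute the failure to its hash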