feat: 重构信息提取模块,移除LLMClient依赖,改为使用LLMRequest,优化数据加载和处理逻辑
This commit is contained in:
@@ -13,11 +13,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from rich.progress import Progress # 替换为 rich 进度条
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||
# from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import _entity_extract, info_extract_from_str
|
||||
from src.chat.knowledge.llm_client import LLMClient
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.raw_processing import load_raw_data
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
@@ -27,16 +26,17 @@ from rich.progress import (
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from raw_data_preprocessor import process_multi_files, load_raw_data
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("LPMM知识库-信息提取")
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
|
||||
ROOT_PATH, "data", "imported_lpmm_data"
|
||||
)
|
||||
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
@@ -45,6 +45,14 @@ open_ie_doc_lock = Lock()
|
||||
# 创建一个事件标志,用于控制程序终止
|
||||
shutdown_event = Event()
|
||||
|
||||
lpmm_entity_extract_llm = LLMRequest(
|
||||
model=global_config.model.lpmm_entity_extract,
|
||||
request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(
|
||||
model=global_config.model.lpmm_rdf_build,
|
||||
request_type="lpmm.rdf_build"
|
||||
)
|
||||
|
||||
def ensure_dirs():
|
||||
"""确保临时目录和输出目录存在"""
|
||||
@@ -54,12 +62,9 @@ def ensure_dirs():
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||
if not os.path.exists(IMPORTED_DATA_PATH):
|
||||
os.makedirs(IMPORTED_DATA_PATH)
|
||||
logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
|
||||
|
||||
|
||||
def process_single_text(pg_hash, raw_data, llm_client_list):
|
||||
def process_single_text(pg_hash, raw_data):
|
||||
"""处理单个文本的函数,用于线程池"""
|
||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||
|
||||
@@ -77,8 +82,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
entity_list, rdf_triple_list = info_extract_from_str(
|
||||
llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
|
||||
llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
|
||||
lpmm_entity_extract_llm,
|
||||
lpmm_rdf_build_llm,
|
||||
raw_data,
|
||||
)
|
||||
if entity_list is None or rdf_triple_list is None:
|
||||
@@ -130,50 +135,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
ensure_dirs() # 确保目录存在
|
||||
logger.info("--------进行信息提取--------\n")
|
||||
|
||||
logger.info("创建LLM客户端")
|
||||
llm_client_list = {
|
||||
key: LLMClient(
|
||||
global_config["llm_providers"][key]["base_url"],
|
||||
global_config["llm_providers"][key]["api_key"],
|
||||
)
|
||||
for key in global_config["llm_providers"]
|
||||
}
|
||||
# 检查 openie 输出目录
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||
|
||||
# 确保 TEMP_DIR 目录存在
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
os.makedirs(TEMP_DIR)
|
||||
logger.info(f"已创建缓存目录: {TEMP_DIR}")
|
||||
|
||||
# 遍历IMPORTED_DATA_PATH下所有json文件
|
||||
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
|
||||
if not imported_files:
|
||||
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
|
||||
sys.exit(1)
|
||||
|
||||
all_sha256_list = []
|
||||
all_raw_datas = []
|
||||
|
||||
for imported_file in imported_files:
|
||||
logger.info(f"正在处理文件: {imported_file}")
|
||||
try:
|
||||
sha256_list, raw_datas = load_raw_data(imported_file)
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
|
||||
continue
|
||||
all_sha256_list.extend(sha256_list)
|
||||
all_raw_datas.extend(raw_datas)
|
||||
# 加载原始数据
|
||||
logger.info("正在加载原始数据")
|
||||
all_sha256_list, all_raw_datas = load_raw_data()
|
||||
|
||||
failed_sha256 = []
|
||||
open_ie_doc = []
|
||||
|
||||
workers = global_config["info_extraction"]["workers"]
|
||||
workers = global_config.lpmm_knowledge.info_extraction_workers
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_hash = {
|
||||
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
||||
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
|
||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user