Merge remote-tracking branch 'origin/dev' into HFC-para
This commit is contained in:
@@ -33,16 +33,10 @@ logger = get_module_logger("LPMM知识库-信息提取")
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
IMPORTED_DATA_PATH = (
|
||||
global_config["persistence"]["raw_data_path"]
|
||||
if global_config["persistence"]["raw_data_path"]
|
||||
else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
)
|
||||
OPENIE_OUTPUT_DIR = (
|
||||
global_config["persistence"]["openie_data_path"]
|
||||
if global_config["persistence"]["openie_data_path"]
|
||||
else os.path.join(ROOT_PATH, "data/openie")
|
||||
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
|
||||
ROOT_PATH, "data/imported_lpmm_data"
|
||||
)
|
||||
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
@@ -76,26 +70,25 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
||||
)
|
||||
if entity_list is None or rdf_triple_list is None:
|
||||
return None, pg_hash
|
||||
else:
|
||||
doc_item = {
|
||||
"idx": pg_hash,
|
||||
"passage": raw_data,
|
||||
"extracted_entities": entity_list,
|
||||
"extracted_triples": rdf_triple_list,
|
||||
}
|
||||
# 保存临时提取结果
|
||||
with file_lock:
|
||||
try:
|
||||
with open(temp_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(doc_item, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
|
||||
# 如果保存失败,确保不会留下损坏的文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
sys.exit(0)
|
||||
return None, pg_hash
|
||||
return doc_item, None
|
||||
doc_item = {
|
||||
"idx": pg_hash,
|
||||
"passage": raw_data,
|
||||
"extracted_entities": entity_list,
|
||||
"extracted_triples": rdf_triple_list,
|
||||
}
|
||||
# 保存临时提取结果
|
||||
with file_lock:
|
||||
try:
|
||||
with open(temp_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(doc_item, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
|
||||
# 如果保存失败,确保不会留下损坏的文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
sys.exit(0)
|
||||
return None, pg_hash
|
||||
return doc_item, None
|
||||
|
||||
|
||||
def signal_handler(_signum, _frame):
|
||||
@@ -104,7 +97,7 @@ def signal_handler(_signum, _frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
# 设置信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -125,13 +118,13 @@ def main():
|
||||
logger.info("--------进行信息提取--------\n")
|
||||
|
||||
logger.info("创建LLM客户端")
|
||||
llm_client_list = dict()
|
||||
for key in global_config["llm_providers"]:
|
||||
llm_client_list[key] = LLMClient(
|
||||
llm_client_list = {
|
||||
key: LLMClient(
|
||||
global_config["llm_providers"][key]["base_url"],
|
||||
global_config["llm_providers"][key]["api_key"],
|
||||
)
|
||||
|
||||
for key in global_config["llm_providers"]
|
||||
}
|
||||
# 检查 openie 输出目录
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
|
||||
Reference in New Issue
Block a user