Merge pull request #949 from UnCLAS-Prommer/dev

fix: Fix some issues with the lpmm template and import
UnCLAS-Prommer
2025-05-13 22:21:43 +08:00
committed by GitHub
3 changed files with 44 additions and 54 deletions

View File

@@ -21,11 +21,7 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
 # Add the project root directory to sys.path
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-OPENIE_DIR = (
-    global_config["persistence"]["openie_data_path"]
-    if global_config["persistence"]["openie_data_path"]
-    else os.path.join(ROOT_PATH, "data/openie")
-)
+OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
 logger = get_module_logger("OpenIE导入")
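The ternary fallback and the `or` form are interchangeable here: both take the default whenever the configured value is falsy (None or an empty string). A minimal sketch of the equivalence, using a hypothetical config dict in place of global_config:

import os

ROOT_PATH = "/tmp/project"  # hypothetical root, just for this example
config = {"persistence": {"openie_data_path": ""}}  # empty string counts as unset

# Old form: conditional expression
verbose = (
    config["persistence"]["openie_data_path"]
    if config["persistence"]["openie_data_path"]
    else os.path.join(ROOT_PATH, "data/openie")
)
# New form: `or` short-circuits to the fallback on any falsy value
concise = config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
assert verbose == concise == os.path.join(ROOT_PATH, "data/openie")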
@@ -49,14 +45,14 @@ def hash_deduplicate(
         new_triple_list_data: Deduplicated triples
     """
     # Keep the deduplicated paragraphs
-    new_raw_paragraphs = dict()
+    new_raw_paragraphs = {}
     # Keep the deduplicated triples
-    new_triple_list_data = dict()
+    new_triple_list_data = {}
     for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
         # Paragraph hash
         paragraph_hash = get_sha256(raw_paragraph)
-        if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes):
+        if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
             continue
         new_raw_paragraphs[paragraph_hash] = raw_paragraph
         new_triple_list_data[paragraph_hash] = triple_list
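A paragraph is skipped only when both stores already know it: the embedding store under a namespaced key and the KG under the bare digest. A self-contained sketch of that check; the PG_NAMESPACE value and the get_sha256 body are assumptions standing in for the real definitions:

import hashlib

PG_NAMESPACE = "pg"  # hypothetical value; the real constant lives in the importer

def get_sha256(text: str) -> str:
    # Assumed behavior of src.plugins.knowledge.src.utils.hash.get_sha256
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

stored_pg_hashes = set()
stored_paragraph_hashes = set()

paragraph_hash = get_sha256("示例段落")
# Equivalent to PG_NAMESPACE + "-" + paragraph_hash, just easier to read
key = f"{PG_NAMESPACE}-{paragraph_hash}"
if key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
    print("duplicate, skipping")
else:
    stored_pg_hashes.add(key)
    stored_paragraph_hashes.add(paragraph_hash)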
@@ -65,6 +61,7 @@ def hash_deduplicate(
 def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
+    # sourcery skip: extract-method
     # Extract raw paragraphs and triple lists from the OpenIE data
     # Raw paragraphs to be indexed
     raw_paragraphs = openie_data.extract_raw_paragraph_dict()
@@ -117,7 +114,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
         return False
     # New: prompt the user to choose whether to delete the invalid passages and continue importing
     # The print comes after all logger.error calls so it cannot be drowned out
-    logger.info("\n检测到非法文段,共{}条。".format(len(missing_idxs)))
+    logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
     logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
     user_choice = input().strip().lower()
     if user_choice != "y":
@@ -133,10 +130,10 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     raw_paragraphs = openie_data.extract_raw_paragraph_dict()
     entity_list_data = openie_data.extract_entity_dict()
     triple_list_data = openie_data.extract_triple_dict()
     # Validate again
     if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
         logger.error("删除非法文段后,数据仍不一致,程序终止。")
         sys.exit(1)
     # Replace each index with the hash of the corresponding paragraph
     logger.info("正在进行段落去重与重索引")
     raw_paragraphs, triple_list_data = hash_deduplicate(
@@ -166,7 +163,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     return True


-def main():
+def main():  # sourcery skip: dict-comprehension
     # New confirmation prompt
     print("=== 重要操作确认 ===")
     print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@@ -185,7 +182,7 @@ def main():
logger.info("----开始导入openie数据----\n") logger.info("----开始导入openie数据----\n")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
llm_client_list = dict() llm_client_list = {}
for key in global_config["llm_providers"]: for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient( llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"], global_config["llm_providers"][key]["base_url"],
@@ -198,7 +195,7 @@ def main():
     try:
         embed_manager.load_from_file()
     except Exception as e:
-        logger.error("从文件加载Embedding库时发生错误{}".format(e))
+        logger.error(f"从文件加载Embedding库时发生错误{e}")
         if "嵌入模型与本地存储不一致" in str(e):
             logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
             logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
@@ -213,7 +210,7 @@ def main():
     try:
         kg_manager.load_from_file()
     except Exception as e:
-        logger.error("从文件加载KG时发生错误{}".format(e))
+        logger.error(f"从文件加载KG时发生错误{e}")
         logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("KG加载完成")
@@ -222,7 +219,7 @@ def main():
     # Data check: compare the paragraph hash sets of the Embedding store and the KG
     for pg_hash in kg_manager.stored_paragraph_hashes:
-        key = PG_NAMESPACE + "-" + pg_hash
+        key = f"{PG_NAMESPACE}-{pg_hash}"
         if key not in embed_manager.stored_pg_hashes:
             logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
@@ -230,7 +227,7 @@ def main():
     try:
         openie_data = OpenIE.load()
     except Exception as e:
-        logger.error("导入OpenIE数据文件时发生错误{}".format(e))
+        logger.error(f"导入OpenIE数据文件时发生错误{e}")
         return False
     if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
         logger.error("处理OpenIE数据时发生错误")

View File

@@ -33,16 +33,10 @@ logger = get_module_logger("LPMM知识库-信息提取")
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")
-IMPORTED_DATA_PATH = (
-    global_config["persistence"]["raw_data_path"]
-    if global_config["persistence"]["raw_data_path"]
-    else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
-)
-OPENIE_OUTPUT_DIR = (
-    global_config["persistence"]["openie_data_path"]
-    if global_config["persistence"]["openie_data_path"]
-    else os.path.join(ROOT_PATH, "data/openie")
-)
+IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
+    ROOT_PATH, "data/imported_lpmm_data"
+)
+OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")

 # Create a thread-safe lock to protect file operations and shared data
 file_lock = Lock()
@@ -76,26 +70,25 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
         )
     if entity_list is None or rdf_triple_list is None:
         return None, pg_hash
-    else:
-        doc_item = {
-            "idx": pg_hash,
-            "passage": raw_data,
-            "extracted_entities": entity_list,
-            "extracted_triples": rdf_triple_list,
-        }
-        # Save the temporary extraction result
-        with file_lock:
-            try:
-                with open(temp_file_path, "w", encoding="utf-8") as f:
-                    json.dump(doc_item, f, ensure_ascii=False, indent=4)
-            except Exception as e:
-                logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
-                # If saving fails, make sure no corrupted file is left behind
-                if os.path.exists(temp_file_path):
-                    os.remove(temp_file_path)
-                sys.exit(0)
-                return None, pg_hash
-        return doc_item, None
+    doc_item = {
+        "idx": pg_hash,
+        "passage": raw_data,
+        "extracted_entities": entity_list,
+        "extracted_triples": rdf_triple_list,
+    }
+    # Save the temporary extraction result
+    with file_lock:
+        try:
+            with open(temp_file_path, "w", encoding="utf-8") as f:
+                json.dump(doc_item, f, ensure_ascii=False, indent=4)
+        except Exception as e:
+            logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
+            # If saving fails, make sure no corrupted file is left behind
+            if os.path.exists(temp_file_path):
+                os.remove(temp_file_path)
+            sys.exit(0)
+            return None, pg_hash
+    return doc_item, None


 def signal_handler(_signum, _frame):
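The change above is a guard-clause flattening: with the early return on the failure case, the happy path no longer needs an else block, and behavior is unchanged. A minimal standalone sketch of the same write-then-clean-up pattern (save_temp is a hypothetical helper, not from the repo):

import json
import os
from threading import Lock

file_lock = Lock()  # mirrors the module-level lock in the script

def save_temp(doc_item: dict, temp_file_path: str) -> bool:
    # Serialize under the lock; on failure, remove the partial file so no
    # corrupted cache entry is left behind (same cleanup as the except branch above)
    with file_lock:
        try:
            with open(temp_file_path, "w", encoding="utf-8") as f:
                json.dump(doc_item, f, ensure_ascii=False, indent=4)
        except OSError:
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
            return False
    return True

save_temp({"idx": "abc"}, "/tmp/abc.json")  # usage example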
@@ -104,7 +97,7 @@ def signal_handler(_signum, _frame):
     sys.exit(0)


-def main():
+def main():  # sourcery skip: comprehension-to-generator, extract-method
     # Set up the signal handler
     signal.signal(signal.SIGINT, signal_handler)
@@ -125,13 +118,13 @@ def main():
logger.info("--------进行信息提取--------\n") logger.info("--------进行信息提取--------\n")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
llm_client_list = dict() llm_client_list = {
for key in global_config["llm_providers"]: key: LLMClient(
llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"], global_config["llm_providers"][key]["base_url"],
global_config["llm_providers"][key]["api_key"], global_config["llm_providers"][key]["api_key"],
) )
for key in global_config["llm_providers"]
}
# 检查 openie 输出目录 # 检查 openie 输出目录
if not os.path.exists(OPENIE_OUTPUT_DIR): if not os.path.exists(OPENIE_OUTPUT_DIR):
os.makedirs(OPENIE_OUTPUT_DIR) os.makedirs(OPENIE_OUTPUT_DIR)
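The loop-built dict and the comprehension produce the same mapping; the comprehension just constructs it in a single expression. A sketch with a stand-in client class and a hypothetical provider entry:

class LLMClient:  # stand-in for the real client class
    def __init__(self, base_url: str, api_key: str):
        self.base_url = base_url
        self.api_key = api_key

# Hypothetical provider config mirroring global_config["llm_providers"]
providers = {"main": {"base_url": "https://example.invalid/v1", "api_key": "sk-test"}}

llm_client_list = {
    key: LLMClient(providers[key]["base_url"], providers[key]["api_key"])
    for key in providers
}
assert llm_client_list["main"].base_url == "https://example.invalid/v1"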

View File

@@ -54,7 +54,7 @@ res_top_k = 3 # 最终提供的文段TopK
 [persistence]
 # Persistence settings (store intermediate data to avoid recomputation)
 data_root_path = "data"  # Data root directory
-raw_data_path = "data/imported_lpmm_data"  # Raw data path
+imported_data_path = "data/imported_lpmm_data"  # Path for raw files converted to JSON
 openie_data_path = "data/openie"  # OpenIE data path
 embedding_data_dir = "data/embedding"  # Embedding data directory
 rag_data_dir = "data/rag"  # RAG data directory
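Note the key rename: existing configs that still set raw_data_path presumably need to switch to imported_data_path, since the scripts above now read the new key. A sketch of loading it with the standard library (the file name is hypothetical):

import tomllib  # standard library since Python 3.11

with open("lpmm_config.toml", "rb") as f:  # hypothetical name for the template above
    cfg = tomllib.load(f)

# Renamed in this commit: raw_data_path -> imported_data_path.
# Configs still using the old key will fail this lookup with a KeyError.
imported_data_path = cfg["persistence"]["imported_data_path"]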