Fix some issues with the lpmm template
@@ -21,11 +21,7 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
 
 # 添加项目根目录到 sys.path
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-OPENIE_DIR = (
-    global_config["persistence"]["openie_data_path"]
-    if global_config["persistence"]["openie_data_path"]
-    else os.path.join(ROOT_PATH, "data/openie")
-)
+OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
 
 logger = get_module_logger("OpenIE导入")
 
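Note on the hunk above: the four-line conditional expression is collapsed into Python's `or` fallback idiom, which behaves the same because `a or b` yields `b` whenever `a` is falsy (None or an empty string). A minimal standalone sketch of the pattern, with an illustrative config dict standing in for global_config:

import os

ROOT_PATH = os.path.abspath(os.path.dirname(__file__))
# Illustrative stand-in for global_config; only the ["persistence"]["openie_data_path"] shape matters.
config = {"persistence": {"openie_data_path": ""}}

# Falsy values (None, "") fall through to the default, matching the removed
# `x if x else default` expression.
openie_dir = config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
print(openie_dir)  # -> <script dir>/data/openie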
@@ -49,14 +45,14 @@ def hash_deduplicate(
         new_triple_list_data: 去重后的三元组
     """
     # 保存去重后的段落
-    new_raw_paragraphs = dict()
+    new_raw_paragraphs = {}
     # 保存去重后的三元组
-    new_triple_list_data = dict()
+    new_triple_list_data = {}
 
     for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
         # 段落hash
         paragraph_hash = get_sha256(raw_paragraph)
-        if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes):
+        if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
             continue
         new_raw_paragraphs[paragraph_hash] = raw_paragraph
         new_triple_list_data[paragraph_hash] = triple_list
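The rewritten check builds the namespaced key with an f-string instead of concatenation and drops the redundant parentheses; the membership test itself is unchanged. A tiny sketch with made-up values (the real PG_NAMESPACE constant is defined elsewhere in the module):

PG_NAMESPACE = "paragraph"  # placeholder value, not taken from the repository
paragraph_hash = "3a7bd3e2360a3d29eea436fcfb7e44c7"  # made-up hash

old_key = (PG_NAMESPACE + "-" + paragraph_hash)
new_key = f"{PG_NAMESPACE}-{paragraph_hash}"
assert old_key == new_key  # only how the key is built changes, not the key itself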
@@ -65,6 +61,7 @@ def hash_deduplicate(
 
 
 def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
+    # sourcery skip: extract-method
     # 从OpenIE数据中提取段落原文与三元组列表
     # 索引的段落原文
     raw_paragraphs = openie_data.extract_raw_paragraph_dict()
@@ -117,7 +114,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
             return False
         # 新增:提示用户是否删除非法文段继续导入
         # 将print移到所有logger.error之后,确保不会被冲掉
-        logger.info("\n检测到非法文段,共{}条。".format(len(missing_idxs)))
+        logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
         logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
         user_choice = input().strip().lower()
         if user_choice != "y":
@@ -166,7 +163,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     return True
 
 
-def main():
+def main():  # sourcery skip: dict-comprehension
     # 新增确认提示
     print("=== 重要操作确认 ===")
     print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
@@ -185,7 +182,7 @@ def main():
     logger.info("----开始导入openie数据----\n")
 
     logger.info("创建LLM客户端")
-    llm_client_list = dict()
+    llm_client_list = {}
     for key in global_config["llm_providers"]:
         llm_client_list[key] = LLMClient(
             global_config["llm_providers"][key]["base_url"],
@@ -198,7 +195,7 @@ def main():
     try:
         embed_manager.load_from_file()
     except Exception as e:
-        logger.error("从文件加载Embedding库时发生错误:{}".format(e))
+        logger.error(f"从文件加载Embedding库时发生错误:{e}")
         if "嵌入模型与本地存储不一致" in str(e):
             logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
             logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
@@ -213,7 +210,7 @@ def main():
     try:
         kg_manager.load_from_file()
     except Exception as e:
-        logger.error("从文件加载KG时发生错误:{}".format(e))
+        logger.error(f"从文件加载KG时发生错误:{e}")
         logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("KG加载完成")
 
@@ -222,7 +219,7 @@ def main():
 
     # 数据比对:Embedding库与KG的段落hash集合
     for pg_hash in kg_manager.stored_paragraph_hashes:
-        key = PG_NAMESPACE + "-" + pg_hash
+        key = f"{PG_NAMESPACE}-{pg_hash}"
         if key not in embed_manager.stored_pg_hashes:
             logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
 
@@ -230,7 +227,7 @@ def main():
     try:
         openie_data = OpenIE.load()
     except Exception as e:
-        logger.error("导入OpenIE数据文件时发生错误:{}".format(e))
+        logger.error(f"导入OpenIE数据文件时发生错误:{e}")
         return False
     if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
         logger.error("处理OpenIE数据时发生错误")
@@ -33,16 +33,10 @@ logger = get_module_logger("LPMM知识库-信息提取")
 
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")
-IMPORTED_DATA_PATH = (
-    global_config["persistence"]["raw_data_path"]
-    if global_config["persistence"]["raw_data_path"]
-    else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
-)
-OPENIE_OUTPUT_DIR = (
-    global_config["persistence"]["openie_data_path"]
-    if global_config["persistence"]["openie_data_path"]
-    else os.path.join(ROOT_PATH, "data/openie")
+IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
+    ROOT_PATH, "data/imported_lpmm_data"
 )
+OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
 
 # 创建一个线程安全的锁,用于保护文件操作和共享数据
 file_lock = Lock()
@@ -76,7 +70,6 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
     )
     if entity_list is None or rdf_triple_list is None:
         return None, pg_hash
-    else:
     doc_item = {
         "idx": pg_hash,
         "passage": raw_data,
@@ -104,7 +97,7 @@ def signal_handler(_signum, _frame):
     sys.exit(0)
 
 
-def main():
+def main():  # sourcery skip: comprehension-to-generator, extract-method
     # 设置信号处理器
     signal.signal(signal.SIGINT, signal_handler)
 
@@ -125,13 +118,13 @@ def main():
     logger.info("--------进行信息提取--------\n")
 
     logger.info("创建LLM客户端")
-    llm_client_list = dict()
-    for key in global_config["llm_providers"]:
-        llm_client_list[key] = LLMClient(
+    llm_client_list = {
+        key: LLMClient(
             global_config["llm_providers"][key]["base_url"],
             global_config["llm_providers"][key]["api_key"],
         )
-
+        for key in global_config["llm_providers"]
+    }
     # 检查 openie 输出目录
     if not os.path.exists(OPENIE_OUTPUT_DIR):
         os.makedirs(OPENIE_OUTPUT_DIR)
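In the hunk above the imperative loop that filled llm_client_list is replaced by an equivalent dict comprehension; the `# sourcery skip:` comments added elsewhere in this commit suppress further Sourcery suggestions of that kind. A self-contained sketch of the equivalence, with a stub class standing in for LLMClient and a made-up provider entry:

class StubClient:
    # Stand-in for LLMClient; only the constructor signature matters here.
    def __init__(self, base_url, api_key):
        self.base_url = base_url
        self.api_key = api_key

providers = {"openai": {"base_url": "https://example.invalid/v1", "api_key": "sk-test"}}

# Old style: build the dict imperatively.
clients_loop = {}
for key in providers:
    clients_loop[key] = StubClient(providers[key]["base_url"], providers[key]["api_key"])

# New style: the equivalent dict comprehension.
clients_comp = {
    key: StubClient(providers[key]["base_url"], providers[key]["api_key"])
    for key in providers
}

assert clients_loop.keys() == clients_comp.keys()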
@@ -54,7 +54,7 @@ res_top_k = 3 # 最终提供的文段TopK
 [persistence]
 # 持久化配置(存储中间数据,防止重复计算)
 data_root_path = "data" # 数据根目录
-raw_data_path = "data/imported_lpmm_data" # 原始数据路径
+imported_data_path = "data/imported_lpmm_data" # 转换为json的raw文件数据路径
 openie_data_path = "data/openie" # OpenIE数据路径
 embedding_data_dir = "data/embedding" # 嵌入数据目录
 rag_data_dir = "data/rag" # RAG数据目录
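The config hunk renames raw_data_path to imported_data_path, matching the code change above; anything that still reads the old key will no longer find it. A hedged sketch of reading the new key with the standard-library TOML parser (the file name here is hypothetical, since the template's path is not shown in this diff):

import tomllib  # Python 3.11+; older interpreters need a third-party TOML package

with open("lpmm_config.toml", "rb") as fp:  # hypothetical file name
    cfg = tomllib.load(fp)

# The import scripts now expect this key instead of the removed raw_data_path.
imported_data_path = cfg["persistence"].get("imported_data_path", "data/imported_lpmm_data")
print(imported_data_path)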