feat: 增强数据导入处理,新增非法文段检测与用户确认删除功能;优化原始数据路径创建与日志记录
This commit is contained in:
@@ -85,6 +85,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
|||||||
logger.error("系统将于2秒后开始检查数据完整性")
|
logger.error("系统将于2秒后开始检查数据完整性")
|
||||||
sleep(2)
|
sleep(2)
|
||||||
found_missing = False
|
found_missing = False
|
||||||
|
missing_idxs = []
|
||||||
for doc in getattr(openie_data, "docs", []):
|
for doc in getattr(openie_data, "docs", []):
|
||||||
idx = doc.get("idx", "<无idx>")
|
idx = doc.get("idx", "<无idx>")
|
||||||
passage = doc.get("passage", "<无passage>")
|
passage = doc.get("passage", "<无passage>")
|
||||||
@@ -104,14 +105,36 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
|||||||
# print(f"检查: idx={idx}")
|
# print(f"检查: idx={idx}")
|
||||||
if missing:
|
if missing:
|
||||||
found_missing = True
|
found_missing = True
|
||||||
|
missing_idxs.append(idx)
|
||||||
logger.error("\n")
|
logger.error("\n")
|
||||||
logger.error("数据缺失:")
|
logger.error("数据缺失:")
|
||||||
logger.error(f"对应哈希值:{idx}")
|
logger.error(f"对应哈希值:{idx}")
|
||||||
logger.error(f"对应文段内容内容:{passage}")
|
logger.error(f"对应文段内容内容:{passage}")
|
||||||
logger.error(f"非法原因:{', '.join(missing)}")
|
logger.error(f"非法原因:{', '.join(missing)}")
|
||||||
|
# 确保提示在所有非法数据输出后再输出
|
||||||
if not found_missing:
|
if not found_missing:
|
||||||
print("所有数据均完整,没有发现缺失字段。")
|
logger.info("所有数据均完整,没有发现缺失字段。")
|
||||||
return False
|
return False
|
||||||
|
# 新增:提示用户是否删除非法文段继续导入
|
||||||
|
# 将print移到所有logger.error之后,确保不会被冲掉
|
||||||
|
logger.info("\n检测到非法文段,共{}条。".format(len(missing_idxs)))
|
||||||
|
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||||
|
user_choice = input().strip().lower()
|
||||||
|
if user_choice != "y":
|
||||||
|
logger.info("用户选择不删除非法文段,程序终止。")
|
||||||
|
sys.exit(1)
|
||||||
|
# 删除非法文段
|
||||||
|
logger.info("正在删除非法文段并继续导入...")
|
||||||
|
# 过滤掉非法文段
|
||||||
|
openie_data.docs = [doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs]
|
||||||
|
# 重新提取数据
|
||||||
|
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
||||||
|
entity_list_data = openie_data.extract_entity_dict()
|
||||||
|
triple_list_data = openie_data.extract_triple_dict()
|
||||||
|
# 再次校验
|
||||||
|
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
||||||
|
logger.error("删除非法文段后,数据仍不一致,程序终止。")
|
||||||
|
sys.exit(1)
|
||||||
# 将索引换为对应段落的hash值
|
# 将索引换为对应段落的hash值
|
||||||
logger.info("正在进行段落去重与重索引")
|
logger.info("正在进行段落去重与重索引")
|
||||||
raw_paragraphs, triple_list_data = hash_deduplicate(
|
raw_paragraphs, triple_list_data = hash_deduplicate(
|
||||||
|
|||||||
@@ -5,11 +5,20 @@ import sys # 新增系统模块导入
|
|||||||
import datetime # 新增导入
|
import datetime # 新增导入
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from src.plugins.knowledge.src.lpmmconfig import global_config
|
||||||
|
|
||||||
logger = get_module_logger("LPMM数据库-原始数据处理")
|
logger = get_logger("lpmm")
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||||
|
# 新增:确保 RAW_DATA_PATH 存在
|
||||||
|
if not os.path.exists(RAW_DATA_PATH):
|
||||||
|
os.makedirs(RAW_DATA_PATH, exist_ok=True)
|
||||||
|
logger.info(f"已创建目录: {RAW_DATA_PATH}")
|
||||||
|
|
||||||
|
if global_config.get("persistence", {}).get("raw_data_path") is not None:
|
||||||
|
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
|
||||||
|
else:
|
||||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
@@ -54,7 +63,7 @@ def main():
|
|||||||
print("请确保原始数据已放置在正确的目录中。")
|
print("请确保原始数据已放置在正确的目录中。")
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||||
if confirm != "y":
|
if confirm != "y":
|
||||||
logger.error("操作已取消")
|
logger.info("操作已取消")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
print("\n" + "=" * 40 + "\n")
|
print("\n" + "=" * 40 + "\n")
|
||||||
|
|
||||||
@@ -94,6 +103,6 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print(f"Raw Data Path: {RAW_DATA_PATH}")
|
logger.info(f"原始数据路径: {RAW_DATA_PATH}")
|
||||||
print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
|
logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user