diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 851cc8b31..eae0683db 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -85,6 +85,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k logger.error("系统将于2秒后开始检查数据完整性") sleep(2) found_missing = False + missing_idxs = [] for doc in getattr(openie_data, "docs", []): idx = doc.get("idx", "<无idx>") passage = doc.get("passage", "<无passage>") @@ -104,14 +105,36 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k # print(f"检查: idx={idx}") if missing: found_missing = True + missing_idxs.append(idx) logger.error("\n") logger.error("数据缺失:") logger.error(f"对应哈希值:{idx}") logger.error(f"对应文段内容内容:{passage}") logger.error(f"非法原因:{', '.join(missing)}") + # 确保提示在所有非法数据输出后再输出 if not found_missing: - print("所有数据均完整,没有发现缺失字段。") - return False + logger.info("所有数据均完整,没有发现缺失字段。") + return False + # 新增:提示用户是否删除非法文段继续导入 + # 将print移到所有logger.error之后,确保不会被冲掉 + logger.info("\n检测到非法文段,共{}条。".format(len(missing_idxs))) + logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") + user_choice = input().strip().lower() + if user_choice != "y": + logger.info("用户选择不删除非法文段,程序终止。") + sys.exit(1) + # 删除非法文段 + logger.info("正在删除非法文段并继续导入...") + # 过滤掉非法文段 + openie_data.docs = [doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs] + # 重新提取数据 + raw_paragraphs = openie_data.extract_raw_paragraph_dict() + entity_list_data = openie_data.extract_entity_dict() + triple_list_data = openie_data.extract_triple_dict() + # 再次校验 + if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): + logger.error("删除非法文段后,数据仍不一致,程序终止。") + sys.exit(1) # 将索引换为对应段落的hash值 logger.info("正在进行段落去重与重索引") raw_paragraphs, triple_list_data = hash_deduplicate( diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index c87c30ca8..33fdede9e 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -5,12 +5,21 @@ import sys # 新增系统模块导入 import datetime # 新增导入 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.common.logger import get_module_logger +from src.common.logger_manager import get_logger +from src.plugins.knowledge.src.lpmmconfig import global_config -logger = get_module_logger("LPMM数据库-原始数据处理") +logger = get_logger("lpmm") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") -IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") +# 新增:确保 RAW_DATA_PATH 存在 +if not os.path.exists(RAW_DATA_PATH): + os.makedirs(RAW_DATA_PATH, exist_ok=True) + logger.info(f"已创建目录: {RAW_DATA_PATH}") + +if global_config.get("persistence", {}).get("raw_data_path") is not None: + IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"]) +else: + IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") # 添加项目根目录到 sys.path @@ -54,7 +63,7 @@ def main(): print("请确保原始数据已放置在正确的目录中。") confirm = input("确认继续执行?(y/n): ").strip().lower() if confirm != "y": - logger.error("操作已取消") + logger.info("操作已取消") sys.exit(1) print("\n" + "=" * 40 + "\n") @@ -94,6 +103,6 @@ def main(): if __name__ == "__main__": - print(f"Raw Data Path: {RAW_DATA_PATH}") - print(f"Imported Data Path: {IMPORTED_DATA_PATH}") + logger.info(f"原始数据路径: {RAW_DATA_PATH}") + logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}") main()