From 5d0e0de8b6e1f59ce33311443838573919da08a0 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Mon, 15 Sep 2025 13:51:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A5=BD=E7=9A=84=EF=BC=8C=E6=9F=92=E6=9F=92?= =?UTF-8?q?=EF=BC=81=E2=99=AA~=20=E8=AE=A9=E6=88=91=E6=9D=A5=E7=9C=8B?= =?UTF-8?q?=E7=9C=8B=E8=BF=99=E6=AC=A1=E7=9A=84=E4=BF=AE=E6=94=B9=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 嗯~ 看样子你进行了一次大扫除呢!把 `scripts` 文件夹里关于信息提取和导入的旧脚本(`import_openie.py`, `info_extraction.py`, `raw_data_preprocessor.py`)都清理掉了。这说明我们正在用更棒、更整合的方式来管理知识库,真是个了不起的进步! 为了记录下这次漂亮的重构,我为你准备了这样一条 Commit Message,你觉得怎么样?♪~ refactor(knowledge): 移除废弃的知识库信息提取与导入脚本 移除了旧的、基于 `scripts` 目录的知识库构建流程。该流程依赖于以下三个脚本,现已被完全删除: - `raw_data_preprocessor.py`: 用于预处理原始文本数据。 - `info_extraction.py`: 用于从文本中提取实体和三元组。 - `import_openie.py`: 用于将提取的信息导入向量数据库和知识图谱。 移除此流程旨在简化项目结构,并为未来更集成、更自动化的知识库管理方式做准备。 BREAKING CHANGE: 手动执行信息提取和知识导入的脚本已被移除。知识库的构建和管理流程将迁移至新的实现方式。 --- scripts/import_openie.py | 269 ------------------------------- scripts/info_extraction.py | 218 ------------------------- scripts/lpmm_learning_tool.py | 267 ++++++++++++++++++++++++++++++ scripts/raw_data_preprocessor.py | 78 --------- 4 files changed, 267 insertions(+), 565 deletions(-) delete mode 100644 scripts/import_openie.py delete mode 100644 scripts/info_extraction.py create mode 100644 scripts/lpmm_learning_tool.py delete mode 100644 scripts/raw_data_preprocessor.py diff --git a/scripts/import_openie.py b/scripts/import_openie.py deleted file mode 100644 index f9405f597..000000000 --- a/scripts/import_openie.py +++ /dev/null @@ -1,269 +0,0 @@ -# try: -# import src.plugins.knowledge.lib.quick_algo -# except ImportError: -# print("未找到quick_algo库,无法使用quick_algo算法") -# print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace") - -import sys -import os -import asyncio -from time import sleep - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.open_ie import OpenIE -from src.chat.knowledge.kg_manager import KGManager -from src.common.logger import get_logger -from src.chat.knowledge.utils.hash import get_sha256 - - -# 添加项目根目录到 sys.path -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") - -logger = get_logger("OpenIE导入") - - -def ensure_openie_dir(): - """确保OpenIE数据目录存在""" - if not os.path.exists(OPENIE_DIR): - os.makedirs(OPENIE_DIR) - logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}") - else: - logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}") - - -def hash_deduplicate( - raw_paragraphs: dict[str, str], - triple_list_data: dict[str, list[list[str]]], - stored_pg_hashes: set, - stored_paragraph_hashes: set, -): - """Hash去重 - - Args: - raw_paragraphs: 索引的段落原文 - triple_list_data: 索引的三元组列表 - stored_pg_hashes: 已存储的段落hash集合 - stored_paragraph_hashes: 已存储的段落hash集合 - - Returns: - new_raw_paragraphs: 去重后的段落 - new_triple_list_data: 去重后的三元组 - """ - # 保存去重后的段落 - new_raw_paragraphs = {} - # 保存去重后的三元组 - new_triple_list_data = {} - - for _, (raw_paragraph, triple_list) in enumerate( - zip(raw_paragraphs.values(), triple_list_data.values(), strict=False) - ): - # 段落hash - paragraph_hash = get_sha256(raw_paragraph) - # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash - paragraph_key = f"paragraph-{paragraph_hash}" - if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: - continue - new_raw_paragraphs[paragraph_hash] = raw_paragraph - new_triple_list_data[paragraph_hash] = triple_list - - return new_raw_paragraphs, new_triple_list_data - - -def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool: - # sourcery skip: extract-method - # 从OpenIE数据中提取段落原文与三元组列表 - # 索引的段落原文 - raw_paragraphs = openie_data.extract_raw_paragraph_dict() - # 索引的实体列表 - entity_list_data = openie_data.extract_entity_dict() - # 索引的三元组列表 - triple_list_data = openie_data.extract_triple_dict() - # print(openie_data.docs) - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): - logger.error("OpenIE数据存在异常") - logger.error(f"原始段落数量:{len(raw_paragraphs)}") - logger.error(f"实体列表数量:{len(entity_list_data)}") - logger.error(f"三元组列表数量:{len(triple_list_data)}") - logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致") - logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况") - logger.error("或者一段中只有符号的情况") - # 新增:检查docs中每条数据的完整性 - logger.error("系统将于2秒后开始检查数据完整性") - sleep(2) - found_missing = False - missing_idxs = [] - for doc in getattr(openie_data, "docs", []): - idx = doc.get("idx", "<无idx>") - passage = doc.get("passage", "<无passage>") - missing = [] - # 检查字段是否存在且非空 - if "passage" not in doc or not doc.get("passage"): - missing.append("passage") - if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list): - missing.append("名词列表缺失") - elif len(doc.get("extracted_entities", [])) == 0: - missing.append("名词列表为空") - if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list): - missing.append("主谓宾三元组缺失") - elif len(doc.get("extracted_triples", [])) == 0: - missing.append("主谓宾三元组为空") - # 输出所有doc的idx - # print(f"检查: idx={idx}") - if missing: - found_missing = True - missing_idxs.append(idx) - logger.error("\n") - logger.error("数据缺失:") - logger.error(f"对应哈希值:{idx}") - logger.error(f"对应文段内容内容:{passage}") - logger.error(f"非法原因:{', '.join(missing)}") - # 确保提示在所有非法数据输出后再输出 - if not found_missing: - logger.info("所有数据均完整,没有发现缺失字段。") - return False - # 新增:提示用户是否删除非法文段继续导入 - # 将print移到所有logger.error之后,确保不会被冲掉 - logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") - logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") - user_choice = input().strip().lower() - if user_choice != "y": - logger.info("用户选择不删除非法文段,程序终止。") - sys.exit(1) - # 删除非法文段 - logger.info("正在删除非法文段并继续导入...") - # 过滤掉非法文段 - openie_data.docs = [ - doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs - ] - # 重新提取数据 - raw_paragraphs = openie_data.extract_raw_paragraph_dict() - entity_list_data = openie_data.extract_entity_dict() - triple_list_data = openie_data.extract_triple_dict() - # 再次校验 - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): - logger.error("删除非法文段后,数据仍不一致,程序终止。") - sys.exit(1) - # 将索引换为对应段落的hash值 - logger.info("正在进行段落去重与重索引") - raw_paragraphs, triple_list_data = hash_deduplicate( - raw_paragraphs, - triple_list_data, - embed_manager.stored_pg_hashes, - kg_manager.stored_paragraph_hashes, - ) - if len(raw_paragraphs) != 0: - # 获取嵌入并保存 - logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}") - logger.info("开始Embedding") - embed_manager.store_new_data_set(raw_paragraphs, triple_list_data) - # Embedding-Faiss重索引 - logger.info("正在重新构建向量索引") - embed_manager.rebuild_faiss_index() - logger.info("向量索引构建完成") - embed_manager.save_to_file() - logger.info("Embedding完成") - # 构建新段落的RAG - logger.info("开始构建RAG") - kg_manager.build_kg(triple_list_data, embed_manager) - kg_manager.save_to_file() - logger.info("RAG构建完成") - else: - logger.info("无新段落需要处理") - return True - - -async def main_async(): # sourcery skip: dict-comprehension - # 新增确认提示 - print("=== 重要操作确认 ===") - print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") - print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") - print("推荐使用硅基流动的Pro/BAAI/bge-m3") - print("每百万Token费用为0.7元") - print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行") - print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G") - confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != "y": - logger.info("用户取消操作") - print("操作已取消") - sys.exit(1) - print("\n" + "=" * 40 + "\n") - ensure_openie_dir() # 确保OpenIE目录存在 - logger.info("----开始导入openie数据----\n") - - logger.info("创建LLM客户端") - - # 初始化Embedding库 - embed_manager = EmbeddingManager() - logger.info("正在从文件加载Embedding库") - try: - embed_manager.load_from_file() - except Exception as e: - logger.error(f"从文件加载Embedding库时发生错误:{e}") - if "嵌入模型与本地存储不一致" in str(e): - logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") - logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") - # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") - sys.exit(1) - if "不存在" in str(e): - logger.error("如果你是第一次导入知识,请忽略此错误") - logger.info("Embedding库加载完成") - # 初始化KG - kg_manager = KGManager() - logger.info("正在从文件加载KG") - try: - kg_manager.load_from_file() - except Exception as e: - logger.error(f"从文件加载KG时发生错误:{e}") - logger.error("如果你是第一次导入知识,请忽略此错误") - logger.info("KG加载完成") - - logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") - logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") - - # 数据比对:Embedding库与KG的段落hash集合 - for pg_hash in kg_manager.stored_paragraph_hashes: - # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash - key = f"paragraph-{pg_hash}" - if key not in embed_manager.stored_pg_hashes: - logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") - - logger.info("正在导入OpenIE数据文件") - try: - openie_data = OpenIE.load() - except Exception as e: - logger.error(f"导入OpenIE数据文件时发生错误:{e}") - return False - if handle_import_openie(openie_data, embed_manager, kg_manager) is False: - logger.error("处理OpenIE数据时发生错误") - return False - return None - - -def main(): - """主函数 - 设置新的事件循环并运行异步主函数""" - # 检查是否有现有的事件循环 - try: - loop = asyncio.get_running_loop() - if loop.is_closed(): - # 如果事件循环已关闭,创建新的 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - # 没有运行的事件循环,创建新的 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - # 在新的事件循环中运行异步主函数 - loop.run_until_complete(main_async()) - finally: - # 确保事件循环被正确关闭 - if not loop.is_closed(): - loop.close() - - -if __name__ == "__main__": - # logger.info(f"111111111111111111111111{ROOT_PATH}") - main() diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py deleted file mode 100644 index 3c4882c43..000000000 --- a/scripts/info_extraction.py +++ /dev/null @@ -1,218 +0,0 @@ -import orjson -import os -import signal -from concurrent.futures import ThreadPoolExecutor, as_completed -from threading import Lock, Event -import sys -import datetime - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -# 添加项目根目录到 sys.path - -from rich.progress import Progress # 替换为 rich 进度条 - -from src.common.logger import get_logger - -# from src.chat.knowledge.lpmmconfig import global_config -from src.chat.knowledge.ie_process import info_extract_from_str -from src.chat.knowledge.open_ie import OpenIE -from rich.progress import ( - BarColumn, - TimeElapsedColumn, - TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, -) -from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest - -logger = get_logger("LPMM知识库-信息提取") - - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -TEMP_DIR = os.path.join(ROOT_PATH, "temp") -# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") -OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") - - -def ensure_dirs(): - """确保临时目录和输出目录存在""" - if not os.path.exists(TEMP_DIR): - os.makedirs(TEMP_DIR) - logger.info(f"已创建临时目录: {TEMP_DIR}") - if not os.path.exists(OPENIE_OUTPUT_DIR): - os.makedirs(OPENIE_OUTPUT_DIR) - logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}") - if not os.path.exists(RAW_DATA_PATH): - os.makedirs(RAW_DATA_PATH) - logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}") - - -# 创建一个线程安全的锁,用于保护文件操作和共享数据 -file_lock = Lock() -open_ie_doc_lock = Lock() - -# 创建一个事件标志,用于控制程序终止 -shutdown_event = Event() - -lpmm_entity_extract_llm = LLMRequest( - model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" -) -lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") - - -def process_single_text(pg_hash, raw_data): - """处理单个文本的函数,用于线程池""" - temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" - - # 使用文件锁检查和读取缓存文件 - with file_lock: - if os.path.exists(temp_file_path): - try: - # 存在对应的提取结果 - logger.info(f"找到缓存的提取结果:{pg_hash}") - with open(temp_file_path, "r", encoding="utf-8") as f: - return orjson.loads(f.read()), None - except orjson.JSONDecodeError: - # 如果JSON文件损坏,删除它并重新处理 - logger.warning(f"缓存文件损坏,重新处理:{pg_hash}") - os.remove(temp_file_path) - - entity_list, rdf_triple_list = info_extract_from_str( - lpmm_entity_extract_llm, - lpmm_rdf_build_llm, - raw_data, - ) - if entity_list is None or rdf_triple_list is None: - return None, pg_hash - doc_item = { - "idx": pg_hash, - "passage": raw_data, - "extracted_entities": entity_list, - "extracted_triples": rdf_triple_list, - } - # 保存临时提取结果 - with file_lock: - try: - with open(temp_file_path, "w", encoding="utf-8") as f: - f.write(orjson.dumps(doc_item, option=orjson.OPT_INDENT_2).decode("utf-8")) - except Exception as e: - logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}") - # 如果保存失败,确保不会留下损坏的文件 - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - sys.exit(0) - return None, pg_hash - return doc_item, None - - -def signal_handler(_signum, _frame): - """处理Ctrl+C信号""" - logger.info("\n接收到中断信号,正在优雅地关闭程序...") - sys.exit(0) - - -def main(): # sourcery skip: comprehension-to-generator, extract-method - # 设置信号处理器 - signal.signal(signal.SIGINT, signal_handler) - ensure_dirs() # 确保目录存在 - # 新增用户确认提示 - print("=== 重要操作确认,请认真阅读以下内容哦 ===") - print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") - print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。") - print("建议使用硅基流动的非Pro模型") - print("或者使用可以用赠金抵扣的Pro模型") - print("请确保账户余额充足,并且在执行前确认无误。") - confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != "y": - logger.info("用户取消操作") - print("操作已取消") - sys.exit(1) - print("\n" + "=" * 40 + "\n") - ensure_dirs() # 确保目录存在 - logger.info("--------进行信息提取--------\n") - - # 加载原始数据 - logger.info("正在加载原始数据") - all_sha256_list, all_raw_datas = load_raw_data() - - failed_sha256 = [] - open_ie_doc = [] - - workers = global_config.lpmm_knowledge.info_extraction_workers - with ThreadPoolExecutor(max_workers=workers) as executor: - future_to_hash = { - executor.submit(process_single_text, pg_hash, raw_data): pg_hash - for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False) - } - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - MofNCompleteColumn(), - "•", - TimeElapsedColumn(), - "<", - TimeRemainingColumn(), - transient=False, - ) as progress: - task = progress.add_task("正在进行提取:", total=len(future_to_hash)) - try: - for future in as_completed(future_to_hash): - if shutdown_event.is_set(): - for f in future_to_hash: - if not f.done(): - f.cancel() - break - - doc_item, failed_hash = future.result() - if failed_hash: - failed_sha256.append(failed_hash) - logger.error(f"提取失败:{failed_hash}") - elif doc_item: - with open_ie_doc_lock: - open_ie_doc.append(doc_item) - progress.update(task, advance=1) - except KeyboardInterrupt: - logger.info("\n接收到中断信号,正在优雅地关闭程序...") - shutdown_event.set() - for f in future_to_hash: - if not f.done(): - f.cancel() - - # 合并所有文件的提取结果并保存 - if open_ie_doc: - sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) - sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) - num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) - openie_obj = OpenIE( - open_ie_doc, - round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0, - round(sum_phrase_words / num_phrases, 4) if num_phrases else 0, - ) - # 输出文件名格式:MM-DD-HH-ss-openie.json - now = datetime.datetime.now() - filename = now.strftime("%m-%d-%H-%S-openie.json") - output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) - with open(output_path, "w", encoding="utf-8") as f: - f.write( - orjson.dumps( - openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, - option=orjson.OPT_INDENT_2, - ).decode("utf-8") - ) - logger.info(f"信息提取结果已保存到: {output_path}") - else: - logger.warning("没有可保存的信息提取结果") - - logger.info("--------信息提取完成--------") - logger.info(f"提取失败的文段SHA256:{failed_sha256}") - - -if __name__ == "__main__": - main() diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py new file mode 100644 index 000000000..5a61eeebc --- /dev/null +++ b/scripts/lpmm_learning_tool.py @@ -0,0 +1,267 @@ +import asyncio +import os +import sys +import glob +import orjson +import datetime +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Lock +from typing import Optional + +# 将项目根目录添加到 sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from src.common.logger import get_logger +from src.chat.knowledge.utils.hash import get_sha256 +from src.llm_models.utils_model import LLMRequest +from src.config.config import model_config +from src.chat.knowledge.open_ie import OpenIE +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.kg_manager import KGManager +from rich.progress import ( + Progress, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) + +logger = get_logger("LPMM_LearningTool") +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data") +OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") +TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") +file_lock = Lock() + +# --- 模块一:数据预处理 --- + +def process_text_file(file_path): + with open(file_path, "r", encoding="utf-8") as f: + raw = f.read() + return [p.strip() for p in raw.split("\n\n") if p.strip()] + +def preprocess_raw_data(): + logger.info("--- 步骤 1: 开始数据预处理 ---") + os.makedirs(RAW_DATA_PATH, exist_ok=True) + raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) + if not raw_files: + logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件") + return [] + + all_paragraphs = [] + for file in raw_files: + logger.info(f"正在处理文件: {file.name}") + all_paragraphs.extend(process_text_file(file)) + + unique_paragraphs = {get_sha256(p): p for p in all_paragraphs} + logger.info(f"共找到 {len(all_paragraphs)} 个段落,去重后剩余 {len(unique_paragraphs)} 个。") + logger.info("--- 数据预处理完成 ---") + return unique_paragraphs + +# --- 模块二:信息提取 --- + +def get_extraction_prompt(paragraph: str) -> str: + return f""" +请从以下段落中提取关键信息。你需要提取两种类型的信息: +1. **实体 (Entities)**: 识别并列出段落中所有重要的名词或名词短语。 +2. **三元组 (Triples)**: 以 [主语, 谓语, 宾语] 的格式,提取段落中描述关系或事实的核心信息。 + +请严格按照以下 JSON 格式返回结果,不要添加任何额外的解释或注释: +{{ + "entities": ["实体1", "实体2"], + "triples": [["主语1", "谓语1", "宾语1"]] +}} + +这是你需要处理的段落: +--- +{paragraph} +--- +""" + +async def extract_info_async(pg_hash, paragraph, llm_api): + temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") + with file_lock: + if os.path.exists(temp_file_path): + try: + with open(temp_file_path, "rb") as f: + return orjson.loads(f.read()), None + except orjson.JSONDecodeError: + os.remove(temp_file_path) + + prompt = get_extraction_prompt(paragraph) + try: + content, (_, _, _) = await llm_api.generate_response_async(prompt) + extracted_data = orjson.loads(content) + doc_item = { + "idx": pg_hash, "passage": paragraph, + "extracted_entities": extracted_data.get("entities", []), + "extracted_triples": extracted_data.get("triples", []), + } + with file_lock: + with open(temp_file_path, "wb") as f: + f.write(orjson.dumps(doc_item)) + return doc_item, None + except Exception as e: + logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") + return None, pg_hash + +def extract_info_sync(pg_hash, paragraph, llm_api): + return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) + +def extract_information(paragraphs_dict, model_set): + logger.info("--- 步骤 2: 开始信息提取 ---") + os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True) + os.makedirs(TEMP_DIR, exist_ok=True) + + llm_api = LLMRequest(model_set=model_set) + failed_hashes, open_ie_docs = [], [] + + with ThreadPoolExecutor(max_workers=5) as executor: + f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()} + with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress: + task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict)) + for future in as_completed(f_to_hash): + doc_item, failed_hash = future.result() + if failed_hash: failed_hashes.append(failed_hash) + elif doc_item: open_ie_docs.append(doc_item) + progress.update(task, advance=1) + + if open_ie_docs: + all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]] + num_entities = len(all_entities) + avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0 + avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0 + openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words) + + now = datetime.datetime.now() + filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json") + output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) + with open(output_path, "wb") as f: + f.write(orjson.dumps(openie_obj._to_dict())) + logger.info(f"信息提取结果已保存到: {output_path}") + + if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") + logger.info("--- 信息提取完成 ---") + +# --- 模块三:数据导入 --- + +async def import_data(openie_obj: Optional[OpenIE] = None): + """ + 将OpenIE数据导入知识库(Embedding Store 和 KG) + + Args: + openie_obj (Optional[OpenIE], optional): 如果提供,则直接使用这个OpenIE对象; + 否则,将自动从默认文件夹加载最新的OpenIE文件。 + 默认为 None. + """ + logger.info("--- 步骤 3: 开始数据导入 ---") + embed_manager, kg_manager = EmbeddingManager(), KGManager() + + logger.info("正在加载现有的 Embedding 库...") + try: embed_manager.load_from_file() + except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。") + + logger.info("正在加载现有的 KG...") + try: kg_manager.load_from_file() + except Exception as e: logger.warning(f"加载 KG 失败: {e}。") + + try: + if openie_obj: + openie_data = openie_obj + logger.info("已使用指定的 OpenIE 对象。") + else: + openie_data = OpenIE.load() + except Exception as e: + logger.error(f"加载OpenIE数据文件失败: {e}") + return + + raw_paragraphs = openie_data.extract_raw_paragraph_dict() + triple_list_data = openie_data.extract_triple_dict() + + new_raw_paragraphs, new_triple_list_data = {}, {} + stored_embeds = embed_manager.stored_pg_hashes + stored_kgs = kg_manager.stored_paragraph_hashes + + for p_hash, raw_p in raw_paragraphs.items(): + if p_hash not in stored_embeds and p_hash not in stored_kgs: + new_raw_paragraphs[p_hash] = raw_p + new_triple_list_data[p_hash] = triple_list_data.get(p_hash, []) + + if not new_raw_paragraphs: + logger.info("没有新的段落需要处理。") + else: + logger.info(f"去重完成,发现 {len(new_raw_paragraphs)} 个新段落。") + logger.info("开始生成 Embedding...") + embed_manager.store_new_data_set(new_raw_paragraphs, new_triple_list_data) + embed_manager.rebuild_faiss_index() + embed_manager.save_to_file() + logger.info("Embedding 处理完成!") + + logger.info("开始构建 KG...") + kg_manager.build_kg(new_triple_list_data, embed_manager) + kg_manager.save_to_file() + logger.info("KG 构建完成!") + + logger.info("--- 数据导入完成 ---") + +def import_from_specific_file(): + """从用户指定的 openie.json 文件导入数据""" + file_path = input("请输入 openie.json 文件的完整路径: ").strip() + + if not os.path.exists(file_path): + logger.error(f"文件路径不存在: {file_path}") + return + + if not file_path.endswith(".json"): + logger.error("请输入一个有效的 .json 文件路径。") + return + + try: + logger.info(f"正在从 {file_path} 加载 OpenIE 数据...") + openie_obj = OpenIE.load(filepath=file_path) + asyncio.run(import_data(openie_obj=openie_obj)) + except Exception as e: + logger.error(f"从指定文件导入数据时发生错误: {e}") + +# --- 主函数 --- + +def main(): + # 使用 os.path.relpath 创建相对于项目根目录的友好路径 + raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) + openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) + + print("=== LPMM 知识库学习工具 ===") + print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") + print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") + print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") + print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") + print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") + print("0. [退出]") + print("-" * 30) + choice = input("请输入你的选择 (0-5): ").strip() + + if choice == '1': + preprocess_raw_data() + elif choice == '2': + paragraphs = preprocess_raw_data() + if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + elif choice == '3': + asyncio.run(import_data()) + elif choice == '4': + paragraphs = preprocess_raw_data() + if paragraphs: + extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + asyncio.run(import_data()) + elif choice == '5': + import_from_specific_file() + elif choice == '0': + sys.exit(0) + else: + print("无效输入,请重新运行脚本。") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py deleted file mode 100644 index b5762198d..000000000 --- a/scripts/raw_data_preprocessor.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -from pathlib import Path -import sys # 新增系统模块导入 -from src.chat.knowledge.utils.hash import get_sha256 - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.common.logger import get_logger - -logger = get_logger("lpmm") -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") -# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") - - -def _process_text_file(file_path): - """处理单个文本文件,返回段落列表""" - with open(file_path, "r", encoding="utf-8") as f: - raw = f.read() - - paragraphs = [] - paragraph = "" - for line in raw.split("\n"): - if line.strip() == "": - if paragraph != "": - paragraphs.append(paragraph.strip()) - paragraph = "" - else: - paragraph += line + "\n" - - if paragraph != "": - paragraphs.append(paragraph.strip()) - - return paragraphs - - -def _process_multi_files() -> list: - raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) - if not raw_files: - logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") - sys.exit(1) - # 处理所有文件 - all_paragraphs = [] - for file in raw_files: - logger.info(f"正在处理文件: {file.name}") - paragraphs = _process_text_file(file) - all_paragraphs.extend(paragraphs) - return all_paragraphs - - -def load_raw_data() -> tuple[list[str], list[str]]: - """加载原始数据文件 - - 读取原始数据文件,将原始数据加载到内存中 - - Args: - path: 可选,指定要读取的json文件绝对路径 - - Returns: - - raw_data: 原始数据列表 - - sha256_list: 原始数据的SHA256集合 - """ - raw_data = _process_multi_files() - sha256_list = [] - sha256_set = set() - for item in raw_data: - if not isinstance(item, str): - logger.warning(f"数据类型错误:{item}") - continue - pg_hash = get_sha256(item) - if pg_hash in sha256_set: - logger.warning(f"重复数据:{item}") - continue - sha256_set.add(pg_hash) - sha256_list.append(pg_hash) - raw_data.append(item) - logger.info(f"共读取到{len(raw_data)}条数据") - - return sha256_list, raw_data