import asyncio import os import sys import orjson import datetime from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock from typing import Optional # 将项目根目录添加到 sys.path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import get_logger from src.chat.knowledge.utils.hash import get_sha256 from src.llm_models.utils_model import LLMRequest from src.config.config import model_config from src.chat.knowledge.open_ie import OpenIE from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.kg_manager import KGManager from rich.progress import ( Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn, ) logger = get_logger("LPMM_LearningTool") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") file_lock = Lock() # --- 模块一:数据预处理 --- def process_text_file(file_path): with open(file_path, "r", encoding="utf-8") as f: raw = f.read() return [p.strip() for p in raw.split("\n\n") if p.strip()] def preprocess_raw_data(): logger.info("--- 步骤 1: 开始数据预处理 ---") os.makedirs(RAW_DATA_PATH, exist_ok=True) raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) if not raw_files: logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件") return [] all_paragraphs = [] for file in raw_files: logger.info(f"正在处理文件: {file.name}") all_paragraphs.extend(process_text_file(file)) unique_paragraphs = {get_sha256(p): p for p in all_paragraphs} logger.info(f"共找到 {len(all_paragraphs)} 个段落,去重后剩余 {len(unique_paragraphs)} 个。") logger.info("--- 数据预处理完成 ---") return unique_paragraphs # --- 模块二:信息提取 --- def get_extraction_prompt(paragraph: str) -> str: return f""" 请从以下段落中提取关键信息。你需要提取两种类型的信息: 1. **实体 (Entities)**: 识别并列出段落中所有重要的名词或名词短语。 2. **三元组 (Triples)**: 以 [主语, 谓语, 宾语] 的格式,提取段落中描述关系或事实的核心信息。 请严格按照以下 JSON 格式返回结果,不要添加任何额外的解释或注释: {{ "entities": ["实体1", "实体2"], "triples": [["主语1", "谓语1", "宾语1"]] }} 这是你需要处理的段落: --- {paragraph} --- """ async def extract_info_async(pg_hash, paragraph, llm_api): temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") with file_lock: if os.path.exists(temp_file_path): try: with open(temp_file_path, "rb") as f: return orjson.loads(f.read()), None except orjson.JSONDecodeError: os.remove(temp_file_path) prompt = get_extraction_prompt(paragraph) content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) extracted_data = orjson.loads(content) doc_item = { "idx": pg_hash, "passage": paragraph, "extracted_entities": extracted_data.get("entities", []), "extracted_triples": extracted_data.get("triples", []), } with file_lock: with open(temp_file_path, "wb") as f: f.write(orjson.dumps(doc_item)) return doc_item, None except Exception as e: logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") if content: logger.error(f"导致解析失败的原始输出: {content}") return None, pg_hash def extract_info_sync(pg_hash, paragraph, llm_api): return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) def extract_information(paragraphs_dict, model_set): logger.info("--- 步骤 2: 开始信息提取 ---") os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True) llm_api = LLMRequest(model_set=model_set) failed_hashes, open_ie_docs = [], [] with ThreadPoolExecutor(max_workers=5) as executor: f_to_hash = { executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items() } with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn(), ) as progress: task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict)) for future in as_completed(f_to_hash): doc_item, failed_hash = future.result() if failed_hash: failed_hashes.append(failed_hash) elif doc_item: open_ie_docs.append(doc_item) progress.update(task, advance=1) if open_ie_docs: all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]] num_entities = len(all_entities) avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0 avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0 openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words) now = datetime.datetime.now() filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json") output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) with open(output_path, "wb") as f: f.write(orjson.dumps(openie_obj._to_dict())) logger.info(f"信息提取结果已保存到: {output_path}") if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}") logger.info("--- 信息提取完成 ---") # --- 模块三:数据导入 --- async def import_data(openie_obj: Optional[OpenIE] = None): """ 将OpenIE数据导入知识库(Embedding Store 和 KG) Args: openie_obj (Optional[OpenIE], optional): 如果提供,则直接使用这个OpenIE对象; 否则,将自动从默认文件夹加载最新的OpenIE文件。 默认为 None. """ logger.info("--- 步骤 3: 开始数据导入 ---") embed_manager, kg_manager = EmbeddingManager(), KGManager() logger.info("正在加载现有的 Embedding 库...") try: embed_manager.load_from_file() except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。") logger.info("正在加载现有的 KG...") try: kg_manager.load_from_file() except Exception as e: logger.warning(f"加载 KG 失败: {e}。") try: if openie_obj: openie_data = openie_obj logger.info("已使用指定的 OpenIE 对象。") else: openie_data = OpenIE.load() except Exception as e: logger.error(f"加载OpenIE数据文件失败: {e}") return raw_paragraphs = openie_data.extract_raw_paragraph_dict() triple_list_data = openie_data.extract_triple_dict() new_raw_paragraphs, new_triple_list_data = {}, {} stored_embeds = embed_manager.stored_pg_hashes stored_kgs = kg_manager.stored_paragraph_hashes for p_hash, raw_p in raw_paragraphs.items(): if p_hash not in stored_embeds and p_hash not in stored_kgs: new_raw_paragraphs[p_hash] = raw_p new_triple_list_data[p_hash] = triple_list_data.get(p_hash, []) if not new_raw_paragraphs: logger.info("没有新的段落需要处理。") else: logger.info(f"去重完成,发现 {len(new_raw_paragraphs)} 个新段落。") logger.info("开始生成 Embedding...") embed_manager.store_new_data_set(new_raw_paragraphs, new_triple_list_data) embed_manager.rebuild_faiss_index() embed_manager.save_to_file() logger.info("Embedding 处理完成!") logger.info("开始构建 KG...") kg_manager.build_kg(new_triple_list_data, embed_manager) kg_manager.save_to_file() logger.info("KG 构建完成!") logger.info("--- 数据导入完成 ---") def import_from_specific_file(): """从用户指定的 openie.json 文件导入数据""" file_path = input("请输入 openie.json 文件的完整路径: ").strip() if not os.path.exists(file_path): logger.error(f"文件路径不存在: {file_path}") return if not file_path.endswith(".json"): logger.error("请输入一个有效的 .json 文件路径。") return try: logger.info(f"正在从 {file_path} 加载 OpenIE 数据...") openie_obj = OpenIE.load() asyncio.run(import_data(openie_obj=openie_obj)) except Exception as e: logger.error(f"从指定文件导入数据时发生错误: {e}") # --- 主函数 --- def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")) print("=== LPMM 知识库学习工具 ===") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)") print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") print("0. [退出]") print("-" * 30) choice = input("请输入你的选择 (0-5): ").strip() if choice == "1": preprocess_raw_data() elif choice == "2": paragraphs = preprocess_raw_data() if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) elif choice == "3": asyncio.run(import_data()) elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa) asyncio.run(import_data()) elif choice == "5": import_from_specific_file() elif choice == "0": sys.exit(0) else: print("无效输入,请重新运行脚本。") if __name__ == "__main__": main()