diff --git a/scripts/import_openie.py b/scripts/import_openie.py
index 595f22ec2..25a1a8779 100644
--- a/scripts/import_openie.py
+++ b/scripts/import_openie.py
@@ -19,7 +19,8 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
 
 # 添加项目根目录到 sys.path
-
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+OPENIE_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie")
 
 logger = get_module_logger("LPMM知识库-OpenIE导入")
@@ -131,6 +132,7 @@ def main():
         embed_manager.load_from_file()
     except Exception as e:
         logger.error("从文件加载Embedding库时发生错误:{}".format(e))
+        logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("Embedding库加载完成")
     # 初始化KG
     kg_manager = KGManager()
@@ -139,6 +141,7 @@ def main():
         kg_manager.load_from_file()
     except Exception as e:
         logger.error("从文件加载KG时发生错误:{}".format(e))
+        logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("KG加载完成")
 
     logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
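The resolve-against-repo-root pattern above recurs in every touched script. In isolation it is just this — a minimal sketch, where `global_config` stands in for the project's parsed TOML config:

```python
import os

# Illustrative stand-in for the parsed lpmm_config TOML; in the real scripts
# this comes from src.plugins.knowledge.src.lpmmconfig.global_config.
global_config = {"persistence": {"openie_data_path": ""}}

# Anchor paths to the repo root (one level above scripts/) so the scripts
# behave the same regardless of the current working directory.
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Fall back to a default under the repo root when the config value is empty.
configured = global_config["persistence"]["openie_data_path"]
OPENIE_DIR = configured if configured else os.path.join(ROOT_PATH, "data/openie")
```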
{OPENIE_OUTPUT_DIR}") - # 创建临时目录 - if not os.path.exists(f"{TEMP_DIR}"): - os.makedirs(f"{TEMP_DIR}") + # 确保 TEMP_DIR 目录存在 + if not os.path.exists(TEMP_DIR): + os.makedirs(TEMP_DIR) + logger.info(f"已创建缓存目录: {TEMP_DIR}") + + # 遍历IMPORTED_DATA_PATH下所有json文件 + imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json"))) + if not imported_files: + logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件") + sys.exit(1) + + all_sha256_list = [] + all_raw_datas = [] + + for imported_file in imported_files: + logger.info(f"正在处理文件: {imported_file}") + try: + sha256_list, raw_datas = load_raw_data(imported_file) + except Exception as e: + logger.error(f"读取文件失败: {imported_file}, 错误: {e}") + continue + all_sha256_list.extend(sha256_list) + all_raw_datas.extend(raw_datas) failed_sha256 = [] open_ie_doc = [] - # 创建线程池,最大线程数为50 workers = global_config["info_extraction"]["workers"] with ThreadPoolExecutor(max_workers=workers) as executor: - # 提交所有任务到线程池 future_to_hash = { executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash - for pg_hash, raw_data in zip(sha256_list, raw_datas) + for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas) } - # 使用tqdm显示进度 - with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar: - # 处理完成的任务 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("正在进行提取:", total=len(future_to_hash)) try: for future in as_completed(future_to_hash): if shutdown_event.is_set(): - # 取消所有未完成的任务 for f in future_to_hash: if not f.done(): f.cancel() @@ -149,26 +183,33 @@ def main(): elif doc_item: with open_ie_doc_lock: open_ie_doc.append(doc_item) - pbar.update(1) + progress.update(task, advance=1) except KeyboardInterrupt: - # 如果在这里捕获到KeyboardInterrupt,说明signal_handler可能没有正常工作 logger.info("\n接收到中断信号,正在优雅地关闭程序...") shutdown_event.set() - # 取消所有未完成的任务 for f in future_to_hash: if not f.done(): f.cancel() - # 保存信息提取结果 - sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) - sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) - num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) - openie_obj = OpenIE( - open_ie_doc, - round(sum_phrase_chars / num_phrases, 4), - round(sum_phrase_words / num_phrases, 4), - ) - OpenIE.save(openie_obj) + # 合并所有文件的提取结果并保存 + if open_ie_doc: + sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) + openie_obj = OpenIE( + open_ie_doc, + round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0, + round(sum_phrase_words / num_phrases, 4) if num_phrases else 0, + ) + # 输出文件名格式:MM-DD-HH-ss-openie.json + now = datetime.datetime.now() + filename = now.strftime("%m-%d-%H-%S-openie.json") + output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, f, ensure_ascii=False, indent=4) + logger.info(f"信息提取结果已保存到: {output_path}") + else: + logger.warning("没有可保存的信息提取结果") logger.info("--------信息提取完成--------") 
logger.info(f"提取失败的文段SHA256:{failed_sha256}") diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index 2fc30352e..d808fb0ee 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -2,18 +2,22 @@ import json import os from pathlib import Path import sys # 新增系统模块导入 +import datetime # 新增导入 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import get_module_logger logger = get_module_logger("LPMM数据库-原始数据处理") +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") +IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") # 添加项目根目录到 sys.path def check_and_create_dirs(): """检查并创建必要的目录""" - required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"] + required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH] for dir_path in required_dirs: if not os.path.exists(dir_path): @@ -58,17 +62,17 @@ def main(): # 检查并创建必要的目录 check_and_create_dirs() - # 检查输出文件是否存在 - if os.path.exists("data/import.json"): - logger.error("错误: data/import.json 已存在,请先处理或删除该文件") - sys.exit(1) + # # 检查输出文件是否存在 + # if os.path.exists(RAW_DATA_PATH): + # logger.error("错误: data/import.json 已存在,请先处理或删除该文件") + # sys.exit(1) - if os.path.exists("data/openie.json"): - logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") - sys.exit(1) + # if os.path.exists(RAW_DATA_PATH): + # logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") + # sys.exit(1) # 获取所有原始文本文件 - raw_files = list(Path("data/lpmm_raw_data").glob("*.txt")) + raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) if not raw_files: logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") sys.exit(1) @@ -80,8 +84,10 @@ def main(): paragraphs = process_text_file(file) all_paragraphs.extend(paragraphs) - # 保存合并后的结果 - output_path = "data/import.json" + # 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json + now = datetime.datetime.now() + filename = now.strftime("%m-%d-%H-%S-imported-data.json") + output_path = os.path.join(IMPORTED_DATA_PATH, filename) with open(output_path, "w", encoding="utf-8") as f: json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) @@ -89,4 +95,6 @@ def main(): if __name__ == "__main__": + print(f"Raw Data Path: {RAW_DATA_PATH}") + print(f"Imported Data Path: {IMPORTED_DATA_PATH}") main() diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py index c0d2fe610..df82970a7 100644 --- a/src/plugins/knowledge/knowledge_lib.py +++ b/src/plugins/knowledge/knowledge_lib.py @@ -26,6 +26,7 @@ try: embed_manager.load_from_file() except Exception as e: logger.error("从文件加载Embedding库时发生错误:{}".format(e)) + logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG kg_manager = KGManager() @@ -34,6 +35,7 @@ try: kg_manager.load_from_file() except Exception as e: logger.error("从文件加载KG时发生错误:{}".format(e)) + logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") logger.info("KG加载完成") logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 8e0d116b6..d2791ca48 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -13,9 +13,11 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install +from rich.progress 
diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py
index 8e0d116b6..d2791ca48 100644
--- a/src/plugins/knowledge/src/embedding_store.py
+++ b/src/plugins/knowledge/src/embedding_store.py
@@ -13,9 +13,11 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
 from .utils.hash import get_sha256
 from .global_logger import logger
 from rich.traceback import install
+from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
 
 install(extra_lines=3)
 
+TOTAL_EMBEDDING_TIMES = 3  # 嵌入阶段总数(段落/实体/关系),用于进度条显示
 
 @dataclass
 class EmbeddingStoreItem:
@@ -52,20 +54,35 @@ class EmbeddingStore:
     def _get_embedding(self, s: str) -> List[float]:
         return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
 
-    def batch_insert_strs(self, strs: List[str]) -> None:
-        """向库中存入字符串"""
-        # 逐项处理
-        for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"):
-            # 计算hash去重
-            item_hash = self.namespace + "-" + get_sha256(s)
-            if item_hash in self.store:
-                continue
-
-            # 获取embedding
-            embedding = self._get_embedding(s)
-
-            # 存入
-            self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+    def batch_insert_strs(self, strs: List[str], times: int) -> None:
+        """向库中存入字符串(times 为当前嵌入阶段序号,仅用于进度条描述)"""
+        total = len(strs)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
+            for s in strs:
+                # 计算hash去重
+                item_hash = self.namespace + "-" + get_sha256(s)
+                if item_hash in self.store:
+                    progress.update(task, advance=1)
+                    continue
+
+                # 获取embedding
+                embedding = self._get_embedding(s)
+
+                # 存入
+                self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+                progress.update(task, advance=1)
 
     def save_to_file(self) -> None:
         """保存到文件"""
@@ -191,7 +208,7 @@ class EmbeddingManager:
 
     def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
         """将段落编码存入Embedding库"""
-        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
+        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
 
     def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
         """将实体编码存入Embedding库"""
@@ -200,7 +217,7 @@ class EmbeddingManager:
             for triple in triple_list:
                 entities.add(triple[0])
                 entities.add(triple[2])
-        self.entities_embedding_store.batch_insert_strs(list(entities))
+        self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
 
     def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
         """将关系编码存入Embedding库"""
@@ -208,7 +225,7 @@ class EmbeddingManager:
         for triples in triple_list_data.values():
             graph_triples.extend([tuple(t) for t in triples])
         graph_triples = list(set(graph_triples))
-        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
+        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
 
     def load_from_file(self):
         """从文件加载"""
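The dedup key in `batch_insert_strs` is the store's namespace plus a SHA-256 of the string. Reduced to its essentials (assuming `get_sha256` hashes the UTF-8 bytes, which matches its usage here):

```python
import hashlib

def get_sha256(s: str) -> str:
    # Assumed implementation of src/plugins/knowledge/src/utils/hash.get_sha256.
    return hashlib.sha256(s.encode("utf-8")).hexdigest()

def make_item_hash(namespace: str, s: str) -> str:
    # The namespace prefix keeps paragraph/entity/relation entries apart
    # even when the underlying strings are identical.
    return f"{namespace}-{get_sha256(s)}"

assert make_item_hash("ent", "北京") != make_item_hash("rel", "北京")
```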
diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py
index 71ce65ef2..ccaf7aa83 100644
--- a/src/plugins/knowledge/src/kg_manager.py
+++ b/src/plugins/knowledge/src/kg_manager.py
@@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
-import tqdm
+from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
 
 from quick_algo import di_graph, pagerank
 
@@ -132,41 +132,57 @@ class KGManager:
         ent_hash_list = list(ent_hash_list)
 
         synonym_hash_set = set()
         synonym_result = dict()
-        # 对每个实体节点,查找其相似的实体节点,建立扩展连接
-        for ent_hash in tqdm.tqdm(ent_hash_list):
-            if ent_hash in synonym_hash_set:
-                # 避免同一批次内重复添加
-                continue
-            ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
-            assert isinstance(ent, EmbeddingStoreItem)
-            if ent is None:
-                continue
-            # 查询相似实体
-            similar_ents = embedding_manager.entities_embedding_store.search_top_k(
-                ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
-            )
-            res_ent = []  # Debug
-            for res_ent_hash, similarity in similar_ents:
-                if res_ent_hash == ent_hash:
-                    # 避免自连接
-                    continue
-                if similarity < global_config["rag"]["params"]["synonym_threshold"]:
-                    # 相似度阈值
-                    continue
-                node_to_node[(res_ent_hash, ent_hash)] = similarity
-                node_to_node[(ent_hash, res_ent_hash)] = similarity
-                synonym_hash_set.add(res_ent_hash)
-                new_edge_cnt += 1
-                res_ent.append(
-                    (
-                        embedding_manager.entities_embedding_store.store[res_ent_hash].str,
-                        similarity,
-                    )
-                )  # Debug
-            synonym_result[ent.str] = res_ent
+        # 对每个实体节点,查找其相似的实体节点,建立扩展连接(rich 进度条)
+        total = len(ent_hash_list)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task("同义词连接", total=total)
+            for ent_hash in ent_hash_list:
+                if ent_hash in synonym_hash_set:
+                    # 避免同一批次内重复添加
+                    progress.update(task, advance=1)
+                    continue
+                ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
+                if ent is None:
+                    # 先判空再断言,否则 assert 会在空值时直接抛出
+                    progress.update(task, advance=1)
+                    continue
+                assert isinstance(ent, EmbeddingStoreItem)
+                # 查询相似实体
+                similar_ents = embedding_manager.entities_embedding_store.search_top_k(
+                    ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
+                )
+                res_ent = []  # Debug
+                for res_ent_hash, similarity in similar_ents:
+                    if res_ent_hash == ent_hash:
+                        # 避免自连接
+                        continue
+                    if similarity < global_config["rag"]["params"]["synonym_threshold"]:
+                        # 相似度阈值
+                        continue
+                    node_to_node[(res_ent_hash, ent_hash)] = similarity
+                    node_to_node[(ent_hash, res_ent_hash)] = similarity
+                    synonym_hash_set.add(res_ent_hash)
+                    new_edge_cnt += 1
+                    res_ent.append(
+                        (
+                            embedding_manager.entities_embedding_store.store[res_ent_hash].str,
+                            similarity,
+                        )
+                    )  # Debug
+                synonym_result[ent.str] = res_ent
+                progress.update(task, advance=1)
 
         for k, v in synonym_result.items():
             print(f'"{k}"的相似实体为:{v}')
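Stripped of the embedding-store plumbing, the synonym-linking rule implemented by the loop above is: take each entity's top-k nearest neighbours, drop self-matches and anything under the similarity threshold, and record the edge in both directions. A self-contained sketch with cosine similarity (illustrative, not the project's `search_top_k`):

```python
import numpy as np

def synonym_edges(vecs, top_k=10, threshold=0.9):
    """Return {(i, j): similarity} for near-duplicate entities, in both directions."""
    vecs = np.asarray(vecs, dtype=float)
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # normalize for cosine
    sims = vecs @ vecs.T
    edges = {}
    for i in range(len(vecs)):
        # top_k + 1 because the best match is always the entity itself
        for j in np.argsort(-sims[i])[: top_k + 1]:
            j = int(j)
            if j == i or sims[i, j] < threshold:
                continue  # skip self-loops and weak matches
            edges[(i, j)] = edges[(j, i)] = float(sims[i, j])
    return edges
```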
diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py
index 5fe163bb2..ea84af4ac 100644
--- a/src/plugins/knowledge/src/open_ie.py
+++ b/src/plugins/knowledge/src/open_ie.py
@@ -1,9 +1,14 @@
 import json
+import os
+import glob
 from typing import Any, Dict, List
 
 from .lpmmconfig import INVALID_ENTITY, global_config
 
+# 项目根目录:从 src/plugins/knowledge/src 向上四级
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
+
 
 def _filter_invalid_entities(entities: List[str]) -> List[str]:
     """过滤无效的实体"""
@@ -74,12 +79,21 @@ class OpenIE:
             doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
 
     @staticmethod
-    def _from_dict(data):
-        """从字典中获取OpenIE对象"""
+    def _from_dict(data_list: List[dict]):
+        """从多个字典合并OpenIE对象,并重新计算实体统计信息"""
+        all_docs = []
+        for data in data_list:
+            all_docs.extend(data.get("docs", []))
+        # 重新计算统计
+        sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
+        sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
+        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
+        avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
+        avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
         return OpenIE(
-            docs=data["docs"],
-            avg_ent_chars=data["avg_ent_chars"],
-            avg_ent_words=data["avg_ent_words"],
+            docs=all_docs,
+            avg_ent_chars=avg_ent_chars,
+            avg_ent_words=avg_ent_words,
         )
 
     def _to_dict(self):
@@ -92,12 +106,18 @@ class OpenIE:
 
     @staticmethod
     def load() -> "OpenIE":
-        """从文件中加载OpenIE数据"""
-        with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
-            data = json.loads(f.read())
-
-        openie_data = OpenIE._from_dict(data)
-
+        """从OpenIE数据目录下所有json文件合并加载OpenIE数据"""
+        openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
+        if not os.path.exists(openie_dir):
+            raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
+        json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
+        if not json_files:
+            raise Exception(f"未在 {openie_dir} 找到任何OpenIE json文件")
+        data_list = []
+        for file in json_files:
+            with open(file, "r", encoding="utf-8") as f:
+                data_list.append(json.load(f))
+        openie_data = OpenIE._from_dict(data_list)
         return openie_data
 
     @staticmethod
diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py
index 91e681c7c..a333ef996 100644
--- a/src/plugins/knowledge/src/raw_processing.py
+++ b/src/plugins/knowledge/src/raw_processing.py
@@ -6,21 +6,25 @@ from .lpmmconfig import global_config
 from .utils.hash import get_sha256
 
 
-def load_raw_data() -> tuple[list[str], list[str]]:
+def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
     """加载原始数据文件
 
     读取原始数据文件,将原始数据加载到内存中
 
+    Args:
+        path: 可选,指定要读取的json文件绝对路径
+
     Returns:
-        - raw_data: 原始数据字典
-        - md5_set: 原始数据的SHA256集合
+        - sha256_list: 原始数据的SHA256列表
+        - raw_data: 原始数据列表
     """
-    # 读取import.json文件
-    if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
-        with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
+    # 读取指定路径或默认路径的json文件
+    json_path = path if path else global_config["persistence"]["raw_data_path"]
+    if os.path.exists(json_path):
+        with open(json_path, "r", encoding="utf-8") as f:
             import_json = json.loads(f.read())
     else:
-        raise Exception("原始数据文件读取失败")
+        raise Exception(f"原始数据文件读取失败: {json_path}")
     # import_json内容示例:
     # import_json = [
    #     "The capital of China is Beijing. The capital of France is Paris.",
diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml
index 43785e794..8563b7caf 100644
--- a/template/lpmm_config_template.toml
+++ b/template/lpmm_config_template.toml
@@ -51,7 +51,7 @@ res_top_k = 3 # 最终提供的文段TopK
 
 [persistence] # 持久化配置(存储中间数据,防止重复计算)
 data_root_path = "data" # 数据根目录
-raw_data_path = "data/import.json" # 原始数据路径
-openie_data_path = "data/openie.json" # OpenIE数据路径
+raw_data_path = "data/imported_lpmm_data" # 预处理后导入数据目录(原为单个 import.json)
+openie_data_path = "data/openie" # OpenIE数据目录(原为单个 openie.json)
 embedding_data_dir = "data/embedding" # 嵌入数据目录
 rag_data_dir = "data/rag" # RAG数据目录
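Finally, the merge-and-recompute logic shared by `OpenIE._from_dict` and the extraction script boils down to the following, shown on toy data (a standalone sketch, not the project API):

```python
def merge_openie_dicts(data_list):
    """Concatenate the docs of several OpenIE dumps and recompute entity stats."""
    docs = [doc for data in data_list for doc in data.get("docs", [])]
    entities = [e for doc in docs for e in doc["extracted_entities"]]
    n = len(entities)
    return {
        "docs": docs,
        "avg_ent_chars": round(sum(len(e) for e in entities) / n, 4) if n else 0,
        "avg_ent_words": round(sum(len(e.split()) for e in entities) / n, 4) if n else 0,
    }

merged = merge_openie_dicts([
    {"docs": [{"extracted_entities": ["Beijing", "capital of China"]}]},
    {"docs": [{"extracted_entities": ["Paris"]}]},
])
assert merged["avg_ent_words"] == round(5 / 3, 4)  # (1 + 3 + 1) words / 3 entities
```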