feat: update data path configuration, enhance data processing, and improve error messages

Author: 墨梓柒
Date:   2025-05-02 13:42:28 +08:00
Parent: edda834538
Commit: 03961b71a2

9 changed files with 226 additions and 109 deletions

View File

@@ -19,7 +19,8 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
 # 添加项目根目录到 sys.path
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+OPENIE_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie")
 logger = get_module_logger("LPMM知识库-OpenIE导入")
@@ -131,6 +132,7 @@ def main():
         embed_manager.load_from_file()
     except Exception as e:
         logger.error("从文件加载Embedding库时发生错误{}".format(e))
+        logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("Embedding库加载完成")

     # 初始化KG
     kg_manager = KGManager()
@@ -139,6 +141,7 @@ def main():
         kg_manager.load_from_file()
     except Exception as e:
         logger.error("从文件加载KG时发生错误{}".format(e))
+        logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("KG加载完成")
     logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")
@@ -163,4 +166,5 @@ def main():
 if __name__ == "__main__":
+    # logger.info(f"111111111111111111111111{ROOT_PATH}")
     main()
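A note on the pattern this file introduces: the configured path wins when it is set, and an empty value falls back to a project-relative default. Below is a minimal, self-contained sketch of that resolution, with a stand-in global_config dict (the real one comes from src.plugins.knowledge.src.lpmmconfig):

```python
import os

# Stand-in for the plugin's global_config; the real structure is loaded from TOML
global_config = {"persistence": {"openie_data_path": ""}}

ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# An empty or missing config value falls back to the project-relative default;
# for falsy values, `or` is equivalent to the ternary used in the diff
OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")
print(OPENIE_DIR)
```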

View File

@@ -4,11 +4,13 @@ import signal
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from threading import Lock, Event
 import sys
+import glob
+import datetime

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 # 添加项目根目录到 sys.path
-import tqdm
+from rich.progress import Progress  # 替换为 rich 进度条

 from src.common.logger import get_module_logger
 from src.plugins.knowledge.src.lpmmconfig import global_config
@@ -16,10 +18,15 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
 from src.plugins.knowledge.src.llm_client import LLMClient
 from src.plugins.knowledge.src.open_ie import OpenIE
 from src.plugins.knowledge.src.raw_processing import load_raw_data
+from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn

 logger = get_module_logger("LPMM知识库-信息提取")

-TEMP_DIR = "./temp"
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+TEMP_DIR = os.path.join(ROOT_PATH, "temp")
+IMPORTED_DATA_PATH = global_config["persistence"]["raw_data_path"] if global_config["persistence"]["raw_data_path"] else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
+OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie")

 # 创建一个线程安全的锁,用于保护文件操作和共享数据
 file_lock = Lock()
@@ -70,8 +77,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
         # 如果保存失败,确保不会留下损坏的文件
         if os.path.exists(temp_file_path):
             os.remove(temp_file_path)
-        # 设置shutdown_event以终止程序
-        shutdown_event.set()
+        sys.exit(0)
         return None, pg_hash
     return doc_item, None
@@ -79,7 +85,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
 def signal_handler(_signum, _frame):
     """处理Ctrl+C信号"""
     logger.info("\n接收到中断信号,正在优雅地关闭程序...")
-    shutdown_event.set()
+    sys.exit(0)


 def main():
@@ -110,33 +116,61 @@ def main():
             global_config["llm_providers"][key]["api_key"],
         )

-    logger.info("正在加载原始数据")
-    sha256_list, raw_datas = load_raw_data()
-    logger.info("原始数据加载完成\n")
+    # 检查 openie 输出目录
+    if not os.path.exists(OPENIE_OUTPUT_DIR):
+        os.makedirs(OPENIE_OUTPUT_DIR)
+        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")

-    # 创建临时目录
-    if not os.path.exists(f"{TEMP_DIR}"):
-        os.makedirs(f"{TEMP_DIR}")
+    # 确保 TEMP_DIR 目录存在
+    if not os.path.exists(TEMP_DIR):
+        os.makedirs(TEMP_DIR)
+        logger.info(f"已创建缓存目录: {TEMP_DIR}")

+    # 遍历IMPORTED_DATA_PATH下所有json文件
+    imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
+    if not imported_files:
+        logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
+        sys.exit(1)
+
+    all_sha256_list = []
+    all_raw_datas = []
+    for imported_file in imported_files:
+        logger.info(f"正在处理文件: {imported_file}")
+        try:
+            sha256_list, raw_datas = load_raw_data(imported_file)
+        except Exception as e:
+            logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
+            continue
+        all_sha256_list.extend(sha256_list)
+        all_raw_datas.extend(raw_datas)

     failed_sha256 = []
     open_ie_doc = []

-    # 创建线程池最大线程数为50
     workers = global_config["info_extraction"]["workers"]
     with ThreadPoolExecutor(max_workers=workers) as executor:
-        # 提交所有任务到线程池
         future_to_hash = {
             executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
-            for pg_hash, raw_data in zip(sha256_list, raw_datas)
+            for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
         }
-        # 使用tqdm显示进度
-        with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar:
-            # 处理完成的任务
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task("正在进行提取:", total=len(future_to_hash))
             try:
                 for future in as_completed(future_to_hash):
                     if shutdown_event.is_set():
-                        # 取消所有未完成的任务
                         for f in future_to_hash:
                             if not f.done():
                                 f.cancel()
@@ -149,26 +183,33 @@ def main():
                     elif doc_item:
                         with open_ie_doc_lock:
                             open_ie_doc.append(doc_item)
-                    pbar.update(1)
+                    progress.update(task, advance=1)
             except KeyboardInterrupt:
-                # 如果在这里捕获到KeyboardInterrupt说明signal_handler可能没有正常工作
                 logger.info("\n接收到中断信号,正在优雅地关闭程序...")
                 shutdown_event.set()
-                # 取消所有未完成的任务
                 for f in future_to_hash:
                     if not f.done():
                         f.cancel()

-    # 保存信息提取结果
-    sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
-    sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
-    num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
-    openie_obj = OpenIE(
-        open_ie_doc,
-        round(sum_phrase_chars / num_phrases, 4),
-        round(sum_phrase_words / num_phrases, 4),
-    )
-    OpenIE.save(openie_obj)
+    # 合并所有文件的提取结果并保存
+    if open_ie_doc:
+        sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
+        sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
+        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
+        openie_obj = OpenIE(
+            open_ie_doc,
+            round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
+            round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
+        )
+        # 输出文件名格式MM-DD-HH-ss-openie.json
+        now = datetime.datetime.now()
+        filename = now.strftime("%m-%d-%H-%S-openie.json")
+        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
+        with open(output_path, "w", encoding="utf-8") as f:
+            json.dump(openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, f, ensure_ascii=False, indent=4)
+        logger.info(f"信息提取结果已保存到: {output_path}")
+    else:
+        logger.warning("没有可保存的信息提取结果")

     logger.info("--------信息提取完成--------")
     logger.info(f"提取失败的文段SHA256{failed_sha256}")

View File

@@ -2,18 +2,22 @@ import json
 import os
 from pathlib import Path
 import sys  # 新增系统模块导入
+import datetime  # 新增导入

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from src.common.logger import get_module_logger

 logger = get_module_logger("LPMM数据库-原始数据处理")

+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
+IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
 # 添加项目根目录到 sys.path

 def check_and_create_dirs():
     """检查并创建必要的目录"""
-    required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
+    required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]

     for dir_path in required_dirs:
         if not os.path.exists(dir_path):
@@ -58,17 +62,17 @@ def main():
     # 检查并创建必要的目录
     check_and_create_dirs()

-    # 检查输出文件是否存在
-    if os.path.exists("data/import.json"):
-        logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
-        sys.exit(1)
-    if os.path.exists("data/openie.json"):
-        logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
-        sys.exit(1)
+    # # 检查输出文件是否存在
+    # if os.path.exists(RAW_DATA_PATH):
+    #     logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
+    #     sys.exit(1)
+    # if os.path.exists(RAW_DATA_PATH):
+    #     logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
+    #     sys.exit(1)

     # 获取所有原始文本文件
-    raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
+    raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))

     if not raw_files:
         logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
         sys.exit(1)
@@ -80,8 +84,10 @@ def main():
         paragraphs = process_text_file(file)
         all_paragraphs.extend(paragraphs)

-    # 保存合并后的结果
-    output_path = "data/import.json"
+    # 保存合并后的结果到 IMPORTED_DATA_PATH文件名格式为 MM-DD-HH-ss-imported-data.json
+    now = datetime.datetime.now()
+    filename = now.strftime("%m-%d-%H-%S-imported-data.json")
+    output_path = os.path.join(IMPORTED_DATA_PATH, filename)
     with open(output_path, "w", encoding="utf-8") as f:
         json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
@@ -89,4 +95,6 @@ def main():
 if __name__ == "__main__":
+    print(f"Raw Data Path: {RAW_DATA_PATH}")
+    print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
     main()
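Outputs are now timestamped instead of writing a fixed data/import.json, so repeated imports accumulate files rather than aborting. Worth noting: %m-%d-%H-%S matches the comment's MM-DD-HH-ss scheme, which skips minutes entirely. A quick check of what the pattern produces:

```python
import datetime

now = datetime.datetime(2025, 5, 2, 13, 42, 28)
# Month-day-hour-second; %M (minutes) is not part of the pattern, so two runs
# in the same hour whose seconds coincide would produce the same filename.
print(now.strftime("%m-%d-%H-%S-imported-data.json"))  # 05-02-13-28-imported-data.json
```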

View File

@@ -26,6 +26,7 @@ try:
     embed_manager.load_from_file()
 except Exception as e:
     logger.error("从文件加载Embedding库时发生错误{}".format(e))
+    logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
 logger.info("Embedding库加载完成")

 # 初始化KG
 kg_manager = KGManager()
@@ -34,6 +35,7 @@ try:
     kg_manager.load_from_file()
 except Exception as e:
     logger.error("从文件加载KG时发生错误{}".format(e))
+    logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
 logger.info("KG加载完成")

 logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")

View File

@@ -13,9 +13,11 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi
 from .utils.hash import get_sha256
 from .global_logger import logger
 from rich.traceback import install
+from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn

 install(extra_lines=3)

+TOTAL_EMBEDDING_TIMES = 3  # 统计嵌入次数

 @dataclass
 class EmbeddingStoreItem:
@@ -52,13 +54,27 @@ class EmbeddingStore:
     def _get_embedding(self, s: str) -> List[float]:
         return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)

-    def batch_insert_strs(self, strs: List[str]) -> None:
+    def batch_insert_strs(self, strs: List[str], times: int) -> None:
         """向库中存入字符串"""
-        # 逐项处理
-        for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"):
+        total = len(strs)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
+            for s in strs:
                 # 计算hash去重
                 item_hash = self.namespace + "-" + get_sha256(s)
                 if item_hash in self.store:
+                    progress.update(task, advance=1)
                     continue

                 # 获取embedding
@@ -66,6 +82,7 @@ class EmbeddingStore:
                 # 存入
                 self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+                progress.update(task, advance=1)

     def save_to_file(self) -> None:
         """保存到文件"""
@@ -191,7 +208,7 @@ class EmbeddingManager:
     def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
         """将段落编码存入Embedding库"""
-        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
+        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)

     def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
         """将实体编码存入Embedding库"""
@@ -200,7 +217,7 @@ class EmbeddingManager:
             for triple in triple_list:
                 entities.add(triple[0])
                 entities.add(triple[2])
-        self.entities_embedding_store.batch_insert_strs(list(entities))
+        self.entities_embedding_store.batch_insert_strs(list(entities), times=2)

     def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
         """将关系编码存入Embedding库"""
@@ -208,7 +225,7 @@ class EmbeddingManager:
         for triples in triple_list_data.values():
             graph_triples.extend([tuple(t) for t in triples])
         graph_triples = list(set(graph_triples))
-        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
+        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)

     def load_from_file(self):
         """从文件加载"""

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pandas as pd
-import tqdm
+from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn

 from quick_algo import di_graph, pagerank
@@ -132,17 +132,31 @@ class KGManager:
         ent_hash_list = list(ent_hash_list)
         synonym_hash_set = set()
         synonym_result = dict()
-        # 对每个实体节点,查找其相似的实体节点,建立扩展连接
-        for ent_hash in tqdm.tqdm(ent_hash_list):
+        # rich 进度条
+        total = len(ent_hash_list)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task("同义词连接", total=total)
+            for ent_hash in ent_hash_list:
                 if ent_hash in synonym_hash_set:
-                    # 避免同一批次内重复添加
+                    progress.update(task, advance=1)
                     continue
                 ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
                 assert isinstance(ent, EmbeddingStoreItem)
                 if ent is None:
+                    progress.update(task, advance=1)
                     continue

                 # 查询相似实体
                 similar_ents = embedding_manager.entities_embedding_store.search_top_k(
@@ -167,6 +181,7 @@ class KGManager:
                 )
                 )  # Debug
                 synonym_result[ent.str] = res_ent
+                progress.update(task, advance=1)

         for k, v in synonym_result.items():
             print(f'"{k}"的相似实体为:{v}')

View File

@@ -1,9 +1,13 @@
 import json
+import os
+import glob
 from typing import Any, Dict, List

 from .lpmmconfig import INVALID_ENTITY, global_config

+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))

 def _filter_invalid_entities(entities: List[str]) -> List[str]:
     """过滤无效的实体"""
@@ -74,12 +78,22 @@ class OpenIE:
             doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])

     @staticmethod
-    def _from_dict(data):
-        """字典中获取OpenIE对象"""
+    def _from_dict(data_list):
+        """多个字典合并OpenIE对象"""
+        # data_list: List[dict]
+        all_docs = []
+        for data in data_list:
+            all_docs.extend(data.get("docs", []))
+        # 重新计算统计
+        sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
+        sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
+        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
+        avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
+        avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
         return OpenIE(
-            docs=data["docs"],
-            avg_ent_chars=data["avg_ent_chars"],
-            avg_ent_words=data["avg_ent_words"],
+            docs=all_docs,
+            avg_ent_chars=avg_ent_chars,
+            avg_ent_words=avg_ent_words,
         )

     def _to_dict(self):
@@ -92,12 +106,20 @@ class OpenIE:
     @staticmethod
     def load() -> "OpenIE":
-        """文件中加载OpenIE数据"""
-        with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
-            data = json.loads(f.read())
-
-        openie_data = OpenIE._from_dict(data)
+        """OPENIE_DIR下所有json文件合并加载OpenIE数据"""
+        openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
+        if not os.path.exists(openie_dir):
+            raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
+        json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
+        data_list = []
+        for file in json_files:
+            with open(file, "r", encoding="utf-8") as f:
+                data = json.load(f)
+            data_list.append(data)
+        if not data_list:
+            # print(f"111111111111111111111Root Path : \n{ROOT_PATH}")
+            raise Exception(f"未在 {openie_dir} 找到任何OpenIE json文件")
+        openie_data = OpenIE._from_dict(data_list)
         return openie_data

     @staticmethod
@@ -132,3 +154,7 @@ class OpenIE:
         """提取原始段落"""
         raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
         return raw_paragraph_dict
+
+if __name__ == "__main__":
+    # 测试代码
+    print(ROOT_PATH)
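load() now merges every *.json under the OpenIE directory, and _from_dict recomputes avg_ent_chars and avg_ent_words over the concatenated docs instead of reusing the per-file values; averaging the stored averages would weight each file equally regardless of how many phrases it holds. A compact sketch of that recomputation:

```python
from typing import Any, Dict, List

def merge_openie_dicts(data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Concatenate docs from every file, then recompute statistics over the
    # union, so each phrase (not each file) carries equal weight.
    all_docs = [doc for data in data_list for doc in data.get("docs", [])]
    phrases = [e for doc in all_docs for e in doc["extracted_entities"]]
    n = len(phrases)
    return {
        "docs": all_docs,
        "avg_ent_chars": round(sum(len(p) for p in phrases) / n, 4) if n else 0,
        "avg_ent_words": round(sum(len(p.split()) for p in phrases) / n, 4) if n else 0,
    }

merged = merge_openie_dicts([
    {"docs": [{"extracted_entities": ["Beijing", "capital of China"]}]},
    {"docs": [{"extracted_entities": ["Paris"]}]},
])
print(merged["avg_ent_chars"], merged["avg_ent_words"])
```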

View File

@@ -6,21 +6,25 @@ from .lpmmconfig import global_config
 from .utils.hash import get_sha256

-def load_raw_data() -> tuple[list[str], list[str]]:
+def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
     """加载原始数据文件

     读取原始数据文件,将原始数据加载到内存中

+    Args:
+        path: 可选指定要读取的json文件绝对路径
+
     Returns:
-        - raw_data: 原始数据字典
-        - md5_set: 原始数据的SHA256集合
+        - raw_data: 原始数据列表
+        - sha256_list: 原始数据的SHA256集合
     """
-    # 读取import.json文件
-    if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
-        with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
+    # 读取指定路径或默认路径的json文件
+    json_path = path if path else global_config["persistence"]["raw_data_path"]
+    if os.path.exists(json_path):
+        with open(json_path, "r", encoding="utf-8") as f:
             import_json = json.loads(f.read())
     else:
-        raise Exception("原始数据文件读取失败")
+        raise Exception(f"原始数据文件读取失败: {json_path}")

     # import_json内容示例
     # import_json = [
     #     "The capital of China is Beijing. The capital of France is Paris.",

View File

@@ -51,7 +51,7 @@ res_top_k = 3 # 最终提供的文段TopK

 [persistence]
 # 持久化配置(存储中间数据,防止重复计算)
 data_root_path = "data" # 数据根目录
-raw_data_path = "data/import.json" # 原始数据路径
-openie_data_path = "data/openie.json" # OpenIE数据路径
+raw_data_path = "data/imported_lpmm_data" # 原始数据路径
+openie_data_path = "data/openie" # OpenIE数据路径
 embedding_data_dir = "data/embedding" # 嵌入数据目录
 rag_data_dir = "data/rag" # RAG数据目录