diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 25a1a8779..dd4b50ece 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -20,7 +20,11 @@ from src.plugins.knowledge.src.utils.hash import get_sha256 # 添加项目根目录到 sys.path ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -OPENIE_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie") +OPENIE_DIR = ( + global_config["persistence"]["openie_data_path"] + if global_config["persistence"]["openie_data_path"] + else os.path.join(ROOT_PATH, "data/openie") +) logger = get_module_logger("LPMM知识库-OpenIE导入") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 00f7a2a21..ee0d789aa 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -18,15 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str from src.plugins.knowledge.src.llm_client import LLMClient from src.plugins.knowledge.src.open_ie import OpenIE from src.plugins.knowledge.src.raw_processing import load_raw_data -from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn +from rich.progress import ( + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) logger = get_module_logger("LPMM知识库-信息提取") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") -IMPORTED_DATA_PATH = global_config["persistence"]["raw_data_path"] if global_config["persistence"]["raw_data_path"] else os.path.join(ROOT_PATH, "data/imported_lpmm_data") -OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie") +IMPORTED_DATA_PATH = ( + global_config["persistence"]["raw_data_path"] + if global_config["persistence"]["raw_data_path"] + else os.path.join(ROOT_PATH, "data/imported_lpmm_data") +) +OPENIE_OUTPUT_DIR = ( + global_config["persistence"]["openie_data_path"] + if global_config["persistence"]["openie_data_path"] + else os.path.join(ROOT_PATH, "data/openie") +) # 创建一个线程安全的锁,用于保护文件操作和共享数据 file_lock = Lock() @@ -206,7 +222,12 @@ def main(): filename = now.strftime("%m-%d-%H-%S-openie.json") output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) with open(output_path, "w", encoding="utf-8") as f: - json.dump(openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, f, ensure_ascii=False, indent=4) + json.dump( + openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, + f, + ensure_ascii=False, + indent=4, + ) logger.info(f"信息提取结果已保存到: {output_path}") else: logger.warning("没有可保存的信息提取结果") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index d2791ca48..e734f4e9a 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -13,12 +13,22 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install -from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn +from rich.progress import ( + Progress, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) install(extra_lines=3) TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 + @dataclass class EmbeddingStoreItem: """嵌入库中的项""" @@ -208,7 +218,7 @@ class EmbeddingManager: def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" - self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()),times=1) + self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): """将实体编码存入Embedding库""" @@ -217,7 +227,7 @@ class EmbeddingManager: for triple in triple_list: entities.add(triple[0]) entities.add(triple[2]) - self.entities_embedding_store.batch_insert_strs(list(entities),times=2) + self.entities_embedding_store.batch_insert_strs(list(entities), times=2) def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): """将关系编码存入Embedding库""" @@ -225,7 +235,7 @@ class EmbeddingManager: for triples in triple_list_data.values(): graph_triples.extend([tuple(t) for t in triples]) graph_triples = list(set(graph_triples)) - self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples],times=3) + self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3) def load_from_file(self): """从文件加载""" diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index ccaf7aa83..fd922af48 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -5,7 +5,16 @@ from typing import Dict, List, Tuple import numpy as np import pandas as pd -from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn +from rich.progress import ( + Progress, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) from quick_algo import di_graph, pagerank diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py index ea84af4ac..75fd18545 100644 --- a/src/plugins/knowledge/src/open_ie.py +++ b/src/plugins/knowledge/src/open_ie.py @@ -154,7 +154,8 @@ class OpenIE: """提取原始段落""" raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}) return raw_paragraph_dict - + + if __name__ == "__main__": # 测试代码 print(ROOT_PATH)