🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-05-02 05:42:41 +00:00
parent 03961b71a2
commit b117e87687
5 changed files with 56 additions and 11 deletions

View File

@@ -20,7 +20,11 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
# 添加项目根目录到 sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie")
OPENIE_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
logger = get_module_logger("LPMM知识库-OpenIE导入")

View File

@@ -18,15 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
from src.plugins.knowledge.src.llm_client import LLMClient
from src.plugins.knowledge.src.open_ie import OpenIE
from src.plugins.knowledge.src.raw_processing import load_raw_data
from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
from rich.progress import (
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
logger = get_module_logger("LPMM知识库-信息提取")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
IMPORTED_DATA_PATH = global_config["persistence"]["raw_data_path"] if global_config["persistence"]["raw_data_path"] else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] if global_config["persistence"]["openie_data_path"] else os.path.join(ROOT_PATH, "data/openie")
IMPORTED_DATA_PATH = (
global_config["persistence"]["raw_data_path"]
if global_config["persistence"]["raw_data_path"]
else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
)
OPENIE_OUTPUT_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
# 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
@@ -206,7 +222,12 @@ def main():
filename = now.strftime("%m-%d-%H-%S-openie.json")
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, f, ensure_ascii=False, indent=4)
json.dump(
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
f,
ensure_ascii=False,
indent=4,
)
logger.info(f"信息提取结果已保存到: {output_path}")
else:
logger.warning("没有可保存的信息提取结果")

View File

@@ -13,12 +13,22 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
install(extra_lines=3)
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
@dataclass
class EmbeddingStoreItem:
"""嵌入库中的项"""
@@ -208,7 +218,7 @@ class EmbeddingManager:
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()),times=1)
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将实体编码存入Embedding库"""
@@ -217,7 +227,7 @@ class EmbeddingManager:
for triple in triple_list:
entities.add(triple[0])
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities),times=2)
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将关系编码存入Embedding库"""
@@ -225,7 +235,7 @@ class EmbeddingManager:
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples],times=3)
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
def load_from_file(self):
"""从文件加载"""

View File

@@ -5,7 +5,16 @@ from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from quick_algo import di_graph, pagerank

View File

@@ -155,6 +155,7 @@ class OpenIE:
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
return raw_paragraph_dict
if __name__ == "__main__":
# 测试代码
print(ROOT_PATH)