re-style: 格式化代码
This commit is contained in:
@@ -1,28 +1,26 @@
|
||||
import orjson
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from quick_algo import di_graph, pagerank
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from quick_algo import di_graph, pagerank
|
||||
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from src.config.config import global_config
|
||||
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
def _get_kg_dir():
|
||||
@@ -87,7 +85,7 @@ class KGManager:
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
# 加载段落hash
|
||||
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
|
||||
with open(self.pg_hash_file_path, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
|
||||
|
||||
@@ -100,8 +98,8 @@ class KGManager:
|
||||
|
||||
def _build_edges_between_ent(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点之间的关系,同时统计实体出现次数"""
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -124,8 +122,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _build_edges_between_ent_pg(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
@@ -136,8 +134,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _synonym_connect(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
) -> int:
|
||||
"""同义词连接"""
|
||||
@@ -208,7 +206,7 @@ class KGManager:
|
||||
|
||||
def _update_graph(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""更新KG图结构
|
||||
@@ -280,7 +278,7 @@ class KGManager:
|
||||
|
||||
def build_kg(
|
||||
self,
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""增量式构建KG
|
||||
@@ -317,8 +315,8 @@ class KGManager:
|
||||
|
||||
def kg_search(
|
||||
self,
|
||||
relation_search_result: List[Tuple[Tuple[str, str, str], float]],
|
||||
paragraph_search_result: List[Tuple[str, float]],
|
||||
relation_search_result: list[tuple[tuple[str, str, str], float]],
|
||||
paragraph_search_result: list[tuple[str, float]],
|
||||
embed_manager: EmbeddingManager,
|
||||
):
|
||||
"""RAG搜索与PageRank
|
||||
|
||||
Reference in New Issue
Block a user