🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -13,12 +13,22 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingStoreItem:
|
||||
"""嵌入库中的项"""
|
||||
@@ -208,7 +218,7 @@ class EmbeddingManager:
|
||||
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||
"""将段落编码存入Embedding库"""
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()),times=1)
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||
|
||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
"""将实体编码存入Embedding库"""
|
||||
@@ -217,7 +227,7 @@ class EmbeddingManager:
|
||||
for triple in triple_list:
|
||||
entities.add(triple[0])
|
||||
entities.add(triple[2])
|
||||
self.entities_embedding_store.batch_insert_strs(list(entities),times=2)
|
||||
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||
|
||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
"""将关系编码存入Embedding库"""
|
||||
@@ -225,7 +235,7 @@ class EmbeddingManager:
|
||||
for triples in triple_list_data.values():
|
||||
graph_triples.extend([tuple(t) for t in triples])
|
||||
graph_triples = list(set(graph_triples))
|
||||
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples],times=3)
|
||||
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
|
||||
|
||||
def load_from_file(self):
|
||||
"""从文件加载"""
|
||||
|
||||
@@ -5,7 +5,16 @@ from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from quick_algo import di_graph, pagerank
|
||||
|
||||
|
||||
|
||||
@@ -154,7 +154,8 @@ class OpenIE:
|
||||
"""提取原始段落"""
|
||||
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
|
||||
return raw_paragraph_dict
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
print(ROOT_PATH)
|
||||
|
||||
Reference in New Issue
Block a user