re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -1,33 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
import orjson
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.config.config import global_config
|
||||
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
|
||||
class EmbeddingStoreItem:
|
||||
"""嵌入库中的项"""
|
||||
|
||||
def __init__(self, item_hash: str, embedding: List[float], content: str):
|
||||
def __init__(self, item_hash: str, embedding: list[float], content: str):
|
||||
self.hash = item_hash
|
||||
self.embedding = embedding
|
||||
self.str = content
|
||||
@@ -127,7 +125,7 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding(s: str) -> List[float]:
|
||||
def _get_embedding(s: str) -> list[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -135,8 +133,8 @@ class EmbeddingStore:
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
@@ -161,8 +159,8 @@ class EmbeddingStore:
|
||||
|
||||
@staticmethod
|
||||
def _get_embeddings_batch_threaded(
|
||||
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> list[tuple[str, list[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
Args:
|
||||
@@ -192,8 +190,8 @@ class EmbeddingStore:
|
||||
chunk_results = []
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
@@ -303,7 +301,7 @@ class EmbeddingStore:
|
||||
path = self.get_test_file_path()
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
|
||||
def check_embedding_model_consistency(self):
|
||||
@@ -345,7 +343,7 @@ class EmbeddingStore:
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||
def batch_insert_strs(self, strs: list[str], times: int) -> None:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
@@ -481,7 +479,7 @@ class EmbeddingStore:
|
||||
if os.path.exists(self.idx2hash_file_path):
|
||||
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
|
||||
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
|
||||
with open(self.idx2hash_file_path, "r") as f:
|
||||
with open(self.idx2hash_file_path) as f:
|
||||
self.idx2hash = orjson.loads(f.read())
|
||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
||||
else:
|
||||
@@ -511,7 +509,7 @@ class EmbeddingStore:
|
||||
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.faiss_index.add(embeddings)
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
Args:
|
||||
query: 查询的embedding
|
||||
@@ -575,11 +573,11 @@ class EmbeddingManager:
|
||||
"""对所有嵌入库做模型一致性校验"""
|
||||
return self.paragraphs_embedding_store.check_embedding_model_consistency()
|
||||
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
|
||||
"""将段落编码存入Embedding库"""
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||
|
||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将实体编码存入Embedding库"""
|
||||
entities = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -588,7 +586,7 @@ class EmbeddingManager:
|
||||
entities.add(triple[2])
|
||||
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||
|
||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将关系编码存入Embedding库"""
|
||||
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
|
||||
for triples in triple_list_data.values():
|
||||
@@ -606,8 +604,8 @@ class EmbeddingManager:
|
||||
|
||||
def store_new_data_set(
|
||||
self,
|
||||
raw_paragraphs: Dict[str, str],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
raw_paragraphs: dict[str, str],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
if not self.check_all_embedding_model_consistency():
|
||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||
|
||||
Reference in New Issue
Block a user