re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions

View File

@@ -1,33 +1,31 @@
from dataclasses import dataclass
import orjson
import os
import math
import asyncio
import math
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from dataclasses import dataclass
# import tqdm
import faiss
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
import numpy as np
import orjson
import pandas as pd
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from rich.traceback import install
from src.common.config_helpers import resolve_embedding_dimension
from src.config.config import global_config
from .global_logger import logger
from .utils.hash import get_sha256
install(extra_lines=3)
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
class EmbeddingStoreItem:
"""嵌入库中的项"""
def __init__(self, item_hash: str, embedding: List[float], content: str):
def __init__(self, item_hash: str, embedding: list[float], content: str):
self.hash = item_hash
self.embedding = embedding
self.str = content
@@ -127,7 +125,7 @@ class EmbeddingStore:
self.idx2hash = None
@staticmethod
def _get_embedding(s: str) -> List[float]:
def _get_embedding(s: str) -> list[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
@@ -135,8 +133,8 @@ class EmbeddingStore:
try:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
@@ -161,8 +159,8 @@ class EmbeddingStore:
@staticmethod
def _get_embeddings_batch_threaded(
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> list[tuple[str, list[float]]]:
"""使用多线程批量获取嵌入向量
Args:
@@ -192,8 +190,8 @@ class EmbeddingStore:
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
try:
# 创建线程专用的LLM实例
@@ -303,7 +301,7 @@ class EmbeddingStore:
path = self.get_test_file_path()
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
return orjson.loads(f.read())
def check_embedding_model_consistency(self):
@@ -345,7 +343,7 @@ class EmbeddingStore:
logger.info("嵌入模型一致性校验通过。")
return True
def batch_insert_strs(self, strs: List[str], times: int) -> None:
def batch_insert_strs(self, strs: list[str], times: int) -> None:
"""向库中存入字符串(使用多线程优化)"""
if not strs:
return
@@ -481,7 +479,7 @@ class EmbeddingStore:
if os.path.exists(self.idx2hash_file_path):
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
with open(self.idx2hash_file_path, "r") as f:
with open(self.idx2hash_file_path) as f:
self.idx2hash = orjson.loads(f.read())
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
else:
@@ -511,7 +509,7 @@ class EmbeddingStore:
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
"""搜索最相似的k个项以余弦相似度为度量
Args:
query: 查询的embedding
@@ -575,11 +573,11 @@ class EmbeddingManager:
"""对所有嵌入库做模型一致性校验"""
return self.paragraphs_embedding_store.check_embedding_model_consistency()
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
"""将实体编码存入Embedding库"""
entities = set()
for triple_list in triple_list_data.values():
@@ -588,7 +586,7 @@ class EmbeddingManager:
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
"""将关系编码存入Embedding库"""
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
for triples in triple_list_data.values():
@@ -606,8 +604,8 @@ class EmbeddingManager:
def store_new_data_set(
self,
raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]],
raw_paragraphs: dict[str, str],
triple_list_data: dict[str, list[list[str]]],
):
if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")