re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions

View File

@@ -1,33 +1,31 @@
from dataclasses import dataclass
import orjson
import os
import math
import asyncio
import math
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from dataclasses import dataclass
# import tqdm
import faiss
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
import numpy as np
import orjson
import pandas as pd
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from rich.traceback import install
from src.common.config_helpers import resolve_embedding_dimension
from src.config.config import global_config
from .global_logger import logger
from .utils.hash import get_sha256
install(extra_lines=3)
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
class EmbeddingStoreItem:
"""嵌入库中的项"""
def __init__(self, item_hash: str, embedding: List[float], content: str):
def __init__(self, item_hash: str, embedding: list[float], content: str):
self.hash = item_hash
self.embedding = embedding
self.str = content
@@ -127,7 +125,7 @@ class EmbeddingStore:
self.idx2hash = None
@staticmethod
def _get_embedding(s: str) -> List[float]:
def _get_embedding(s: str) -> list[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
@@ -135,8 +133,8 @@ class EmbeddingStore:
try:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
@@ -161,8 +159,8 @@ class EmbeddingStore:
@staticmethod
def _get_embeddings_batch_threaded(
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> list[tuple[str, list[float]]]:
"""使用多线程批量获取嵌入向量
Args:
@@ -192,8 +190,8 @@ class EmbeddingStore:
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
try:
# 创建线程专用的LLM实例
@@ -303,7 +301,7 @@ class EmbeddingStore:
path = self.get_test_file_path()
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
return orjson.loads(f.read())
def check_embedding_model_consistency(self):
@@ -345,7 +343,7 @@ class EmbeddingStore:
logger.info("嵌入模型一致性校验通过。")
return True
def batch_insert_strs(self, strs: List[str], times: int) -> None:
def batch_insert_strs(self, strs: list[str], times: int) -> None:
"""向库中存入字符串(使用多线程优化)"""
if not strs:
return
@@ -481,7 +479,7 @@ class EmbeddingStore:
if os.path.exists(self.idx2hash_file_path):
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
with open(self.idx2hash_file_path, "r") as f:
with open(self.idx2hash_file_path) as f:
self.idx2hash = orjson.loads(f.read())
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
else:
@@ -511,7 +509,7 @@ class EmbeddingStore:
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
"""搜索最相似的k个项以余弦相似度为度量
Args:
query: 查询的embedding
@@ -575,11 +573,11 @@ class EmbeddingManager:
"""对所有嵌入库做模型一致性校验"""
return self.paragraphs_embedding_store.check_embedding_model_consistency()
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
"""将实体编码存入Embedding库"""
entities = set()
for triple_list in triple_list_data.values():
@@ -588,7 +586,7 @@ class EmbeddingManager:
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
"""将关系编码存入Embedding库"""
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
for triples in triple_list_data.values():
@@ -606,8 +604,8 @@ class EmbeddingManager:
def store_new_data_set(
self,
raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]],
raw_paragraphs: dict[str, str],
triple_list_data: dict[str, list[list[str]]],
):
if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")

View File

@@ -1,14 +1,15 @@
import asyncio
import orjson
import time
from typing import List, Union
from .global_logger import logger
from . import prompt_template
from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from . import prompt_template
from .global_logger import logger
from .knowledge_lib import INVALID_ENTITY
def _extract_json_from_text(text: str):
# sourcery skip: assign-if-exp, extract-method
@@ -46,7 +47,7 @@ def _extract_json_from_text(text: str):
return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
@@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
return entity_extract_result
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=orjson.dumps(entities).decode("utf-8")
@@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
def info_extract_from_str(
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
) -> tuple[None, None] | tuple[list[str], list[list[str]]]:
try_count = 0
while True:
try:

View File

@@ -1,28 +1,26 @@
import orjson
import os
import time
from typing import Dict, List, Tuple
import numpy as np
import orjson
import pandas as pd
from quick_algo import di_graph, pagerank
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from src.config.config import global_config
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .global_logger import logger
from .utils.hash import get_sha256
def _get_kg_dir():
@@ -87,7 +85,7 @@ class KGManager:
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
# 加载段落hash
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
with open(self.pg_hash_file_path, encoding="utf-8") as f:
data = orjson.loads(f.read())
self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
@@ -100,8 +98,8 @@ class KGManager:
def _build_edges_between_ent(
self,
node_to_node: Dict[Tuple[str, str], float],
triple_list_data: Dict[str, List[List[str]]],
node_to_node: dict[tuple[str, str], float],
triple_list_data: dict[str, list[list[str]]],
):
"""构建实体节点之间的关系,同时统计实体出现次数"""
for triple_list in triple_list_data.values():
@@ -124,8 +122,8 @@ class KGManager:
@staticmethod
def _build_edges_between_ent_pg(
node_to_node: Dict[Tuple[str, str], float],
triple_list_data: Dict[str, List[List[str]]],
node_to_node: dict[tuple[str, str], float],
triple_list_data: dict[str, list[list[str]]],
):
"""构建实体节点与文段节点之间的关系"""
for idx in triple_list_data:
@@ -136,8 +134,8 @@ class KGManager:
@staticmethod
def _synonym_connect(
node_to_node: Dict[Tuple[str, str], float],
triple_list_data: Dict[str, List[List[str]]],
node_to_node: dict[tuple[str, str], float],
triple_list_data: dict[str, list[list[str]]],
embedding_manager: EmbeddingManager,
) -> int:
"""同义词连接"""
@@ -208,7 +206,7 @@ class KGManager:
def _update_graph(
self,
node_to_node: Dict[Tuple[str, str], float],
node_to_node: dict[tuple[str, str], float],
embedding_manager: EmbeddingManager,
):
"""更新KG图结构
@@ -280,7 +278,7 @@ class KGManager:
def build_kg(
self,
triple_list_data: Dict[str, List[List[str]]],
triple_list_data: dict[str, list[list[str]]],
embedding_manager: EmbeddingManager,
):
"""增量式构建KG
@@ -317,8 +315,8 @@ class KGManager:
def kg_search(
self,
relation_search_result: List[Tuple[Tuple[str, str, str], float]],
paragraph_search_result: List[Tuple[str, float]],
relation_search_result: list[tuple[tuple[str, str, str], float]],
paragraph_search_result: list[tuple[str, float]],
embed_manager: EmbeddingManager,
):
"""RAG搜索与PageRank

View File

@@ -1,10 +1,11 @@
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
from src.config.config import global_config
import os
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.global_logger import logger
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.qa_manager import QAManager
from src.config.config import global_config
INVALID_ENTITY = [
"",
"",

View File

@@ -1,14 +1,15 @@
import orjson
import os
import glob
from typing import Any, Dict, List
import os
from typing import Any
import orjson
from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
# from src.manager.local_store_manager import local_storage
def _filter_invalid_entities(entities: List[str]) -> List[str]:
def _filter_invalid_entities(entities: list[str]) -> list[str]:
"""过滤无效的实体"""
valid_entities = set()
for entity in entities:
@@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
return list(valid_entities)
def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]:
def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]:
"""过滤无效的三元组"""
unique_triples = set()
valid_triples = []
@@ -62,7 +63,7 @@ class OpenIE:
def __init__(
self,
docs: List[Dict[str, Any]],
docs: list[dict[str, Any]],
avg_ent_chars,
avg_ent_words,
):
@@ -112,7 +113,7 @@ class OpenIE:
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
data_list = []
for file in json_files:
with open(file, "r", encoding="utf-8") as f:
with open(file, encoding="utf-8") as f:
data = orjson.loads(f.read())
data_list.append(data)
if not data_list:

View File

@@ -1,15 +1,16 @@
import time
from typing import Tuple, List, Dict, Optional, Any
from typing import Any
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from .global_logger import logger
from .embedding_store import EmbeddingManager
from .global_logger import logger
from .kg_manager import KGManager
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@@ -26,7 +27,7 @@ class QAManager:
async def process_query(
self, question: str
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None:
"""处理查询"""
# 生成问题的Embedding
@@ -98,7 +99,7 @@ class QAManager:
return result, ppr_node_weights
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
async def get_knowledge(self, question: str) -> dict[str, Any] | None:
"""
获取知识,返回结构化字典

View File

@@ -1,9 +1,9 @@
from typing import List, Any, Tuple
from typing import Any
def dyn_select_top_k(
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
) -> List[Tuple[Any, float, float]]:
score: list[tuple[Any, float]], jmp_factor: float, var_factor: float
) -> list[tuple[Any, float, float]]:
"""动态TopK选择"""
# 检查输入列表是否为空
if not score: