re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -1,33 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
import orjson
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.config.config import global_config
|
||||
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
|
||||
class EmbeddingStoreItem:
|
||||
"""嵌入库中的项"""
|
||||
|
||||
def __init__(self, item_hash: str, embedding: List[float], content: str):
|
||||
def __init__(self, item_hash: str, embedding: list[float], content: str):
|
||||
self.hash = item_hash
|
||||
self.embedding = embedding
|
||||
self.str = content
|
||||
@@ -127,7 +125,7 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding(s: str) -> List[float]:
|
||||
def _get_embedding(s: str) -> list[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -135,8 +133,8 @@ class EmbeddingStore:
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
@@ -161,8 +159,8 @@ class EmbeddingStore:
|
||||
|
||||
@staticmethod
|
||||
def _get_embeddings_batch_threaded(
|
||||
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> list[tuple[str, list[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
Args:
|
||||
@@ -192,8 +190,8 @@ class EmbeddingStore:
|
||||
chunk_results = []
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
@@ -303,7 +301,7 @@ class EmbeddingStore:
|
||||
path = self.get_test_file_path()
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
|
||||
def check_embedding_model_consistency(self):
|
||||
@@ -345,7 +343,7 @@ class EmbeddingStore:
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||
def batch_insert_strs(self, strs: list[str], times: int) -> None:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
@@ -481,7 +479,7 @@ class EmbeddingStore:
|
||||
if os.path.exists(self.idx2hash_file_path):
|
||||
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
|
||||
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
|
||||
with open(self.idx2hash_file_path, "r") as f:
|
||||
with open(self.idx2hash_file_path) as f:
|
||||
self.idx2hash = orjson.loads(f.read())
|
||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
||||
else:
|
||||
@@ -511,7 +509,7 @@ class EmbeddingStore:
|
||||
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.faiss_index.add(embeddings)
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
Args:
|
||||
query: 查询的embedding
|
||||
@@ -575,11 +573,11 @@ class EmbeddingManager:
|
||||
"""对所有嵌入库做模型一致性校验"""
|
||||
return self.paragraphs_embedding_store.check_embedding_model_consistency()
|
||||
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
|
||||
"""将段落编码存入Embedding库"""
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||
|
||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将实体编码存入Embedding库"""
|
||||
entities = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -588,7 +586,7 @@ class EmbeddingManager:
|
||||
entities.add(triple[2])
|
||||
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||
|
||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将关系编码存入Embedding库"""
|
||||
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
|
||||
for triples in triple_list_data.values():
|
||||
@@ -606,8 +604,8 @@ class EmbeddingManager:
|
||||
|
||||
def store_new_data_set(
|
||||
self,
|
||||
raw_paragraphs: Dict[str, str],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
raw_paragraphs: dict[str, str],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
if not self.check_all_embedding_model_consistency():
|
||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
import orjson
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from . import prompt_template
|
||||
from .global_logger import logger
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str):
|
||||
# sourcery skip: assign-if-exp, extract-method
|
||||
@@ -46,7 +47,7 @@ def _extract_json_from_text(text: str):
|
||||
return []
|
||||
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
@@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=orjson.dumps(entities).decode("utf-8")
|
||||
@@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
) -> tuple[None, None] | tuple[list[str], list[list[str]]]:
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
import orjson
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from quick_algo import di_graph, pagerank
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from quick_algo import di_graph, pagerank
|
||||
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from src.config.config import global_config
|
||||
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
def _get_kg_dir():
|
||||
@@ -87,7 +85,7 @@ class KGManager:
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
# 加载段落hash
|
||||
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
|
||||
with open(self.pg_hash_file_path, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
|
||||
|
||||
@@ -100,8 +98,8 @@ class KGManager:
|
||||
|
||||
def _build_edges_between_ent(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点之间的关系,同时统计实体出现次数"""
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -124,8 +122,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _build_edges_between_ent_pg(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
@@ -136,8 +134,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _synonym_connect(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
) -> int:
|
||||
"""同义词连接"""
|
||||
@@ -208,7 +206,7 @@ class KGManager:
|
||||
|
||||
def _update_graph(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""更新KG图结构
|
||||
@@ -280,7 +278,7 @@ class KGManager:
|
||||
|
||||
def build_kg(
|
||||
self,
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""增量式构建KG
|
||||
@@ -317,8 +315,8 @@ class KGManager:
|
||||
|
||||
def kg_search(
|
||||
self,
|
||||
relation_search_result: List[Tuple[Tuple[str, str, str], float]],
|
||||
paragraph_search_result: List[Tuple[str, float]],
|
||||
relation_search_result: list[tuple[tuple[str, str, str], float]],
|
||||
paragraph_search_result: list[tuple[str, float]],
|
||||
embed_manager: EmbeddingManager,
|
||||
):
|
||||
"""RAG搜索与PageRank
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.config.config import global_config
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import orjson
|
||||
import os
|
||||
import glob
|
||||
from typing import Any, Dict, List
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH
|
||||
|
||||
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||
# from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
def _filter_invalid_entities(entities: list[str]) -> list[str]:
|
||||
"""过滤无效的实体"""
|
||||
valid_entities = set()
|
||||
for entity in entities:
|
||||
@@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
return list(valid_entities)
|
||||
|
||||
|
||||
def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]:
|
||||
def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]:
|
||||
"""过滤无效的三元组"""
|
||||
unique_triples = set()
|
||||
valid_triples = []
|
||||
@@ -62,7 +63,7 @@ class OpenIE:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
docs: List[Dict[str, Any]],
|
||||
docs: list[dict[str, Any]],
|
||||
avg_ent_chars,
|
||||
avg_ent_words,
|
||||
):
|
||||
@@ -112,7 +113,7 @@ class OpenIE:
|
||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||
data_list = []
|
||||
for file in json_files:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
with open(file, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
data_list.append(data)
|
||||
if not data_list:
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import time
|
||||
from typing import Tuple, List, Dict, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from .global_logger import logger
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .global_logger import logger
|
||||
from .kg_manager import KGManager
|
||||
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
|
||||
@@ -26,7 +27,7 @@ class QAManager:
|
||||
|
||||
async def process_query(
|
||||
self, question: str
|
||||
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
|
||||
) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None:
|
||||
"""处理查询"""
|
||||
|
||||
# 生成问题的Embedding
|
||||
@@ -98,7 +99,7 @@ class QAManager:
|
||||
|
||||
return result, ppr_node_weights
|
||||
|
||||
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_knowledge(self, question: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
获取知识,返回结构化字典
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import List, Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
|
||||
def dyn_select_top_k(
|
||||
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> List[Tuple[Any, float, float]]:
|
||||
score: list[tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> list[tuple[Any, float, float]]:
|
||||
"""动态TopK选择"""
|
||||
# 检查输入列表是否为空
|
||||
if not score:
|
||||
|
||||
Reference in New Issue
Block a user