feat: 知识库小重构
This commit is contained in:
125
src/chat/knowledge/qa_manager.py
Normal file
125
src/chat/knowledge/qa_manager.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import time
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
|
||||
from .global_logger import logger
|
||||
|
||||
# from . import prompt_template
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .kg_manager import KGManager
|
||||
from .lpmmconfig import global_config
|
||||
from src.chat.knowledge.utils import dyn_select_top_k
|
||||
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
|
||||
|
||||
class QAManager:
|
||||
def __init__(
|
||||
self,
|
||||
embed_manager: EmbeddingManager,
|
||||
kg_manager: KGManager,
|
||||
llm_client_embedding: LLMClient,
|
||||
llm_client_filter: LLMClient,
|
||||
llm_client_qa: LLMClient,
|
||||
):
|
||||
self.embed_manager = embed_manager
|
||||
self.kg_manager = kg_manager
|
||||
self.llm_client_list = {
|
||||
"embedding": llm_client_embedding,
|
||||
"message_filter": llm_client_filter,
|
||||
"qa": llm_client_qa,
|
||||
}
|
||||
|
||||
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||
"""处理查询"""
|
||||
|
||||
# 生成问题的Embedding
|
||||
part_start_time = time.perf_counter()
|
||||
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
||||
global_config["embedding"]["model"], question
|
||||
)
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
# 根据问题Embedding查询Relation Embedding库
|
||||
part_start_time = time.perf_counter()
|
||||
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
||||
question_embedding,
|
||||
global_config["qa"]["params"]["relation_search_top_k"],
|
||||
)
|
||||
if relation_search_res is not None:
|
||||
# 过滤阈值
|
||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
||||
# 未找到相关关系
|
||||
relation_search_res = []
|
||||
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
for res in relation_search_res:
|
||||
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
|
||||
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
||||
|
||||
# TODO: 使用LLM过滤三元组结果
|
||||
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
||||
# part_start_time = time.time()
|
||||
|
||||
# 根据问题Embedding查询Paragraph Embedding库
|
||||
part_start_time = time.perf_counter()
|
||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||
question_embedding,
|
||||
global_config["qa"]["params"]["paragraph_search_top_k"],
|
||||
)
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
if len(relation_search_res) != 0:
|
||||
logger.info("找到相关关系,将使用RAG进行检索")
|
||||
# 使用KG检索
|
||||
part_start_time = time.perf_counter()
|
||||
result, ppr_node_weights = self.kg_manager.kg_search(
|
||||
relation_search_res, paragraph_search_res, self.embed_manager
|
||||
)
|
||||
part_end_time = time.perf_counter()
|
||||
logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
|
||||
else:
|
||||
logger.info("未找到相关关系,将使用文段检索结果")
|
||||
result = paragraph_search_res
|
||||
ppr_node_weights = None
|
||||
|
||||
# 过滤阈值
|
||||
result = dyn_select_top_k(result, 0.5, 1.0)
|
||||
|
||||
for res in result:
|
||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||
|
||||
return result, ppr_node_weights
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_knowledge(self, question: str) -> str:
|
||||
"""获取知识"""
|
||||
# 处理查询
|
||||
processed_result = self.process_query(question)
|
||||
if processed_result is not None:
|
||||
query_res = processed_result[0]
|
||||
knowledge = [
|
||||
(
|
||||
self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
|
||||
res[1],
|
||||
)
|
||||
for res in query_res
|
||||
]
|
||||
found_knowledge = "\n".join(
|
||||
[f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(knowledge)]
|
||||
)
|
||||
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
|
||||
found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
|
||||
return found_knowledge
|
||||
else:
|
||||
logger.info("LPMM知识库并未初始化,可能是从未导入过知识...")
|
||||
return None
|
||||
Reference in New Issue
Block a user