import time
from typing import Tuple, List, Dict, Optional

from .global_logger import logger

# from . import prompt_template
from .embedding_store import EmbeddingManager
# from .llm_client import LLMClient
from .kg_manager import KGManager
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config

MAX_KNOWLEDGE_LENGTH = 10000  # maximum length of the returned knowledge string
class QAManager:
    """Question-answering retrieval manager.

    Answers a query by combining embedding similarity search over the
    relation and paragraph stores with an optional knowledge-graph (RAG)
    search that is used whenever sufficiently related relations are found.
    """

    def __init__(
        self,
        embed_manager: EmbeddingManager,
        kg_manager: KGManager,
    ):
        self.embed_manager = embed_manager
        self.kg_manager = kg_manager
        # TODO: API-Adapter modification marker
        self.qa_model = LLMRequest(model=global_config.model.lpmm_qa, request_type="lpmm.qa")

    def process_query(
        self, question: str
    ) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
        """Process a query against the knowledge stores.

        Args:
            question: the natural-language question to search for.

        Returns:
            A ``(results, ppr_node_weights)`` pair, where ``results`` is a
            score-filtered list of paragraph search hits and
            ``ppr_node_weights`` is the PPR node-weight map from the KG
            search (``None`` when plain paragraph retrieval was used), or
            ``None`` when the question embedding could not be generated or
            the paragraph search itself failed.
        """
        # Generate the embedding for the question
        part_start_time = time.perf_counter()
        question_embedding = get_embedding(question)
        if question_embedding is None:
            logger.error("生成问题Embedding失败")
            return None
        part_end_time = time.perf_counter()
        logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")

        # Query the relation embedding store with the question embedding
        part_start_time = time.perf_counter()
        relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
            question_embedding,
            global_config.lpmm_knowledge.qa_relation_search_top_k,
        )
        if relation_search_res is None:
            # Paragraph search would be pointless without a usable store;
            # the original contract is to return None here.
            return None

        # Dynamic threshold: when some results are significantly better than
        # the rest, keep only those; otherwise keep everything.
        relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
        # Guard the empty case before indexing [0] (previously a latent IndexError)
        if (
            not relation_search_res
            or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold
        ):
            # No sufficiently related relation was found
            logger.debug("未找到相关关系,跳过关系检索")
            relation_search_res = []
        part_end_time = time.perf_counter()
        logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")

        for res in relation_search_res:
            rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
            # logger.debug instead of print: library code must respect the logging config
            logger.debug(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")

        # TODO: filter the triple results with an LLM
        # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
        # part_start_time = time.time()

        # Query the paragraph embedding store with the question embedding
        part_start_time = time.perf_counter()
        paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
            question_embedding,
            global_config.lpmm_knowledge.qa_paragraph_search_top_k,
        )
        part_end_time = time.perf_counter()
        logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")

        if relation_search_res:
            logger.info("找到相关关系,将使用RAG进行检索")
            # KG (RAG) search seeded by the matched relations
            part_start_time = time.perf_counter()
            result, ppr_node_weights = self.kg_manager.kg_search(
                relation_search_res, paragraph_search_res, self.embed_manager
            )
            part_end_time = time.perf_counter()
            logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
        else:
            logger.info("未找到相关关系,将使用文段检索结果")
            result = paragraph_search_res
            ppr_node_weights = None

        # Dynamic-threshold filtering of the final result set
        result = dyn_select_top_k(result, 0.5, 1.0)

        for res in result:
            raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
            # logger.debug instead of print: library code must respect the logging config
            logger.debug(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")

        return result, ppr_node_weights

    def get_knowledge(self, question: str) -> Optional[str]:
        """Return knowledge relevant to *question* as a formatted string.

        Args:
            question: the natural-language question.

        Returns:
            A newline-joined, length-capped listing of matching paragraphs
            and their relevance scores, or ``None`` when the query could not
            be processed (e.g. the LPMM knowledge base was never initialized).
        """
        processed_result = self.process_query(question)
        if processed_result is None:
            logger.debug("LPMM知识库并未初始化,可能是从未导入过知识...")
            return None

        query_res = processed_result[0]
        # Resolve each paragraph hash to its raw text, keeping the score
        knowledge = [
            (
                self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
                res[1],
            )
            for res in query_res
        ]
        found_knowledge = "\n".join(
            f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(knowledge)
        )
        # Cap the overall length so downstream prompts stay bounded
        if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
            found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
        return found_knowledge