feat: refactor the information extraction module: drop the LLMClient dependency in favor of LLMRequest, and streamline data loading and processing
@@ -5,11 +5,13 @@ from .global_logger import logger
 # from . import prompt_template
 from .embedding_store import EmbeddingManager
-from .llm_client import LLMClient
+# from .llm_client import LLMClient
 from .kg_manager import KGManager
-from .lpmmconfig import global_config
+# from .lpmmconfig import global_config
 from .utils.dyn_topk import dyn_select_top_k
+
+from src.llm_models.utils_model import LLMRequest
+from src.chat.utils.utils import get_embedding
+from src.config.config import global_config

 MAX_KNOWLEDGE_LENGTH = 10000  # maximum knowledge length
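Alongside the client removal, the imports migrate configuration from the module-local lpmmconfig (dict-style lookups such as global_config["qa"]["params"]["relation_search_top_k"]) to the shared src.config.config, which the later hunks access as typed attributes like global_config.lpmm_knowledge.qa_relation_search_top_k. A minimal sketch of a dataclass-backed config that supports that access style, assuming nothing about the real module beyond the field names visible in this diff (class names and default values here are invented):

from dataclasses import dataclass, field

@dataclass
class LPMMKnowledgeConfig:
    # Field names mirror the diff; the defaults are placeholders only.
    qa_relation_search_top_k: int = 10
    qa_relation_threshold: float = 0.5
    qa_paragraph_search_top_k: int = 10

@dataclass
class GlobalConfig:
    lpmm_knowledge: LPMMKnowledgeConfig = field(default_factory=LPMMKnowledgeConfig)

global_config = GlobalConfig()

# Attribute access fails loudly on a typo, unlike a nested dict lookup,
# which is a common motivation for this kind of migration.
print(global_config.lpmm_knowledge.qa_relation_search_top_k)  # -> 10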
@@ -19,26 +21,25 @@ class QAManager:
         self,
         embed_manager: EmbeddingManager,
         kg_manager: KGManager,
-        llm_client_embedding: LLMClient,
-        llm_client_filter: LLMClient,
-        llm_client_qa: LLMClient,
     ):
         self.embed_manager = embed_manager
         self.kg_manager = kg_manager
-        self.llm_client_list = {
-            "embedding": llm_client_embedding,
-            "message_filter": llm_client_filter,
-            "qa": llm_client_qa,
-        }
+        # TODO: API-Adapter change marker
+        self.qa_model = LLMRequest(
+            model=global_config.model.lpmm_qa,
+            request_type="lpmm.qa"
+        )

     def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
         """Process a query."""

         # Generate the embedding for the question
         part_start_time = time.perf_counter()
-        question_embedding = self.llm_client_list["embedding"].send_embedding_request(
-            global_config["embedding"]["model"], question
-        )
+        question_embedding = get_embedding(question)
         if question_embedding is None:
             logger.error("Failed to generate the question embedding")
             return None
         part_end_time = time.perf_counter()
         logger.debug(f"Embedding took {part_end_time - part_start_time:.5f}s")
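The constructor now builds a single task-scoped LLMRequest instead of accepting three caller-supplied LLMClient instances, and embedding generation moves to the shared get_embedding helper. A self-contained sketch of that shape, where only LLMRequest's model and request_type arguments come from the diff and every other name is hypothetical:

from dataclasses import dataclass

@dataclass
class LLMRequestSketch:
    # Mirrors the two constructor arguments visible in the diff.
    model: str
    request_type: str

    def generate(self, prompt: str) -> str:
        # Stand-in for the real model call.
        return f"[{self.request_type}/{self.model}] answer to: {prompt}"

class QAManagerSketch:
    def __init__(self) -> None:
        # One task-scoped request object replaces the old dict of clients.
        self.qa_model = LLMRequestSketch(model="lpmm_qa", request_type="lpmm.qa")

    def answer(self, question: str) -> str:
        return self.qa_model.generate(question)

print(QAManagerSketch().answer("hello"))

One consequence of this design: call sites construct QAManager without wiring up per-task clients, since the QA model is derived from global configuration inside the class.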
@@ -46,14 +47,15 @@ class QAManager:
         part_start_time = time.perf_counter()
         relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
             question_embedding,
-            global_config["qa"]["params"]["relation_search_top_k"],
+            global_config.lpmm_knowledge.qa_relation_search_top_k,
         )
         if relation_search_res is not None:
             # Filter by threshold
             # Use a dynamic threshold: when some results stand out numerically, keep only those; otherwise keep all results
             relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
-            if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
+            if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
                 # No relevant relations found
                 logger.debug("No relevant relations found, skipping relation retrieval")
                 relation_search_res = []

         part_end_time = time.perf_counter()
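dyn_select_top_k implements the dynamic threshold described in the comment above; its real implementation lives in .utils.dyn_topk and is not part of this diff. The following is one plausible reading of the behavior and of the (0.5, 1.0) arguments, offered purely as an illustration:

from typing import List, Tuple

def dyn_select_top_k_sketch(
    results: List[Tuple[str, float]],
    ratio: float,
    scale: float,
) -> List[Tuple[str, float]]:
    """Keep results scoring at least ratio * scale * top_score.

    A guess at the semantics: drop the tail after a significant score
    gap; if no result falls below the cutoff, everything is kept.
    """
    if not results:
        return []
    ranked = sorted(results, key=lambda r: r[1], reverse=True)
    cutoff = ranked[0][1] * ratio * scale
    return [r for r in ranked if r[1] >= cutoff]

# With a clear gap, only the leading results survive:
print(dyn_select_top_k_sketch([("a", 0.9), ("b", 0.85), ("c", 0.2)], 0.5, 1.0))
# -> [('a', 0.9), ('b', 0.85)]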
@@ -71,7 +73,7 @@ class QAManager:
         part_start_time = time.perf_counter()
         paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
             question_embedding,
-            global_config["qa"]["params"]["paragraph_search_top_k"],
+            global_config.lpmm_knowledge.qa_paragraph_search_top_k,
         )
         part_end_time = time.perf_counter()
         logger.debug(f"Paragraph retrieval took {part_end_time - part_start_time:.5f}s")
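Both retrieval steps share the same contract: search_top_k(question_embedding, k) on an embedding store returns score-ranked (key, similarity) pairs. An illustrative cosine-similarity version of that contract, assuming nothing about the real stores behind EmbeddingManager (all names and data below are made up):

import math
from typing import Dict, List, Tuple

def cosine(a: List[float], b: List[float]) -> float:
    # Cosine similarity, guarding against zero-norm vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def search_top_k(
    store: Dict[str, List[float]], query: List[float], k: int
) -> List[Tuple[str, float]]:
    # Score every stored vector against the query, highest first.
    scored = [(key, cosine(query, vec)) for key, vec in store.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]

store = {"p1": [1.0, 0.0], "p2": [0.7, 0.7], "p3": [0.0, 1.0]}
print(search_top_k(store, [1.0, 0.1], k=2))  # p1 first, then p2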