fix typing
This commit is contained in:
@@ -11,6 +11,7 @@ from src.plugin_system import (
|
|||||||
BaseEventHandler,
|
BaseEventHandler,
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
|
ToolParamType
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -20,8 +21,8 @@ class CompareNumbersTool(BaseTool):
|
|||||||
name = "compare_numbers"
|
name = "compare_numbers"
|
||||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||||
parameters = [
|
parameters = [
|
||||||
("num1", "number", "第一个数字", True),
|
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
|
||||||
("num2", "number", "第二个数字", True),
|
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class QAManager:
|
|||||||
self.kg_manager = kg_manager
|
self.kg_manager = kg_manager
|
||||||
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
|
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
|
||||||
|
|
||||||
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
async def process_query(self, question: str) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
|
||||||
"""处理查询"""
|
"""处理查询"""
|
||||||
|
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
@@ -46,61 +46,60 @@ class QAManager:
|
|||||||
question_embedding,
|
question_embedding,
|
||||||
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
||||||
)
|
)
|
||||||
if relation_search_res is not None:
|
if relation_search_res is None:
|
||||||
# 过滤阈值
|
return None
|
||||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
# 过滤阈值
|
||||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||||
if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||||
# 未找到相关关系
|
if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
||||||
logger.debug("未找到相关关系,跳过关系检索")
|
# 未找到相关关系
|
||||||
relation_search_res = []
|
logger.debug("未找到相关关系,跳过关系检索")
|
||||||
|
relation_search_res = []
|
||||||
|
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
for res in relation_search_res:
|
for res in relation_search_res:
|
||||||
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
|
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
|
||||||
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
||||||
|
|
||||||
# TODO: 使用LLM过滤三元组结果
|
# TODO: 使用LLM过滤三元组结果
|
||||||
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
||||||
# part_start_time = time.time()
|
# part_start_time = time.time()
|
||||||
|
|
||||||
# 根据问题Embedding查询Paragraph Embedding库
|
# 根据问题Embedding查询Paragraph Embedding库
|
||||||
|
part_start_time = time.perf_counter()
|
||||||
|
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||||
|
question_embedding,
|
||||||
|
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
||||||
|
)
|
||||||
|
part_end_time = time.perf_counter()
|
||||||
|
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
|
if len(relation_search_res) != 0:
|
||||||
|
logger.info("找到相关关系,将使用RAG进行检索")
|
||||||
|
# 使用KG检索
|
||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
result, ppr_node_weights = self.kg_manager.kg_search(
|
||||||
question_embedding,
|
relation_search_res, paragraph_search_res, self.embed_manager
|
||||||
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
|
||||||
)
|
)
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
if len(relation_search_res) != 0:
|
|
||||||
logger.info("找到相关关系,将使用RAG进行检索")
|
|
||||||
# 使用KG检索
|
|
||||||
part_start_time = time.perf_counter()
|
|
||||||
result, ppr_node_weights = self.kg_manager.kg_search(
|
|
||||||
relation_search_res, paragraph_search_res, self.embed_manager
|
|
||||||
)
|
|
||||||
part_end_time = time.perf_counter()
|
|
||||||
logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
|
|
||||||
else:
|
|
||||||
logger.info("未找到相关关系,将使用文段检索结果")
|
|
||||||
result = paragraph_search_res
|
|
||||||
ppr_node_weights = None
|
|
||||||
|
|
||||||
# 过滤阈值
|
|
||||||
result = dyn_select_top_k(result, 0.5, 1.0)
|
|
||||||
|
|
||||||
for res in result:
|
|
||||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
|
||||||
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
|
||||||
|
|
||||||
return result, ppr_node_weights
|
|
||||||
else:
|
else:
|
||||||
return None
|
logger.info("未找到相关关系,将使用文段检索结果")
|
||||||
|
result = paragraph_search_res
|
||||||
|
ppr_node_weights = None
|
||||||
|
|
||||||
async def get_knowledge(self, question: str) -> str:
|
# 过滤阈值
|
||||||
|
result = dyn_select_top_k(result, 0.5, 1.0)
|
||||||
|
|
||||||
|
for res in result:
|
||||||
|
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||||
|
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||||
|
|
||||||
|
return result, ppr_node_weights
|
||||||
|
|
||||||
|
async def get_knowledge(self, question: str) -> Optional[str]:
|
||||||
"""获取知识"""
|
"""获取知识"""
|
||||||
# 处理查询
|
# 处理查询
|
||||||
processed_result = await self.process_query(question)
|
processed_result = await self.process_query(question)
|
||||||
|
|||||||
Reference in New Issue
Block a user