diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 4ff01879d..f9855481f 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -11,6 +11,7 @@ from src.plugin_system import ( BaseEventHandler, EventType, MaiMessages, + ToolParamType ) @@ -20,8 +21,8 @@ class CompareNumbersTool(BaseTool): name = "compare_numbers" description = "使用工具 比较两个数的大小,返回较大的数" parameters = [ - ("num1", "number", "第一个数字", True), - ("num2", "number", "第二个数字", True), + ("num1", ToolParamType.FLOAT, "第一个数字", True, None), + ("num2", ToolParamType.FLOAT, "第二个数字", True, None), ] async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 587775755..1a47767cb 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -28,7 +28,7 @@ class QAManager: self.kg_manager = kg_manager self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa") - async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: + async def process_query(self, question: str) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]: """处理查询""" # 生成问题的Embedding @@ -46,61 +46,60 @@ class QAManager: question_embedding, global_config.lpmm_knowledge.qa_relation_search_top_k, ) - if relation_search_res is not None: - # 过滤阈值 - # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 - relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: - # 未找到相关关系 - logger.debug("未找到相关关系,跳过关系检索") - relation_search_res = [] + if relation_search_res is None: + return None + # 过滤阈值 + # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 + relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) + if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: + # 未找到相关关系 + logger.debug("未找到相关关系,跳过关系检索") + relation_search_res = [] - part_end_time = time.perf_counter() - logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") + part_end_time = time.perf_counter() + logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") - for res in relation_search_res: - rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str - print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") + for res in relation_search_res: + rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str + print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") - # TODO: 使用LLM过滤三元组结果 - # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s") - # part_start_time = time.time() + # TODO: 使用LLM过滤三元组结果 + # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s") + # part_start_time = time.time() - # 根据问题Embedding查询Paragraph Embedding库 + # 根据问题Embedding查询Paragraph Embedding库 + part_start_time = time.perf_counter() + paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( + question_embedding, + global_config.lpmm_knowledge.qa_paragraph_search_top_k, + ) + part_end_time = time.perf_counter() + logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") + + if len(relation_search_res) != 0: + logger.info("找到相关关系,将使用RAG进行检索") + # 使用KG检索 part_start_time = time.perf_counter() - paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( - question_embedding, - global_config.lpmm_knowledge.qa_paragraph_search_top_k, + result, ppr_node_weights = self.kg_manager.kg_search( + relation_search_res, paragraph_search_res, self.embed_manager ) part_end_time = time.perf_counter() - logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") - - if len(relation_search_res) != 0: - logger.info("找到相关关系,将使用RAG进行检索") - # 使用KG检索 - part_start_time = time.perf_counter() - result, ppr_node_weights = self.kg_manager.kg_search( - relation_search_res, paragraph_search_res, self.embed_manager - ) - part_end_time = time.perf_counter() - logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s") - else: - logger.info("未找到相关关系,将使用文段检索结果") - result = paragraph_search_res - ppr_node_weights = None - - # 过滤阈值 - result = dyn_select_top_k(result, 0.5, 1.0) - - for res in result: - raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str - print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") - - return result, ppr_node_weights + logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s") else: - return None + logger.info("未找到相关关系,将使用文段检索结果") + result = paragraph_search_res + ppr_node_weights = None - async def get_knowledge(self, question: str) -> str: + # 过滤阈值 + result = dyn_select_top_k(result, 0.5, 1.0) + + for res in result: + raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str + print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") + + return result, ppr_node_weights + + async def get_knowledge(self, question: str) -> Optional[str]: """获取知识""" # 处理查询 processed_result = await self.process_query(question)