fix typing

This commit is contained in:
UnCLAS-Prommer
2025-08-03 13:08:28 +08:00
parent a5631fd23a
commit 44f53213af
2 changed files with 49 additions and 49 deletions

View File

@@ -11,6 +11,7 @@ from src.plugin_system import (
BaseEventHandler, BaseEventHandler,
EventType, EventType,
MaiMessages, MaiMessages,
ToolParamType
) )
@@ -20,8 +21,8 @@ class CompareNumbersTool(BaseTool):
name = "compare_numbers" name = "compare_numbers"
description = "使用工具 比较两个数的大小,返回较大的数" description = "使用工具 比较两个数的大小,返回较大的数"
parameters = [ parameters = [
("num1", "number", "第一个数字", True), ("num1", ToolParamType.FLOAT, "第一个数字", True, None),
("num2", "number", "第二个数字", True), ("num2", ToolParamType.FLOAT, "第二个数字", True, None),
] ]
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:

View File

@@ -28,7 +28,7 @@ class QAManager:
self.kg_manager = kg_manager self.kg_manager = kg_manager
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa") self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: async def process_query(self, question: str) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
"""处理查询""" """处理查询"""
# 生成问题的Embedding # 生成问题的Embedding
@@ -46,7 +46,8 @@ class QAManager:
question_embedding, question_embedding,
global_config.lpmm_knowledge.qa_relation_search_top_k, global_config.lpmm_knowledge.qa_relation_search_top_k,
) )
if relation_search_res is not None: if relation_search_res is None:
return None
# 过滤阈值 # 过滤阈值
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
@@ -97,10 +98,8 @@ class QAManager:
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights return result, ppr_node_weights
else:
return None
async def get_knowledge(self, question: str) -> str: async def get_knowledge(self, question: str) -> Optional[str]:
"""获取知识""" """获取知识"""
# 处理查询 # 处理查询
processed_result = await self.process_query(question) processed_result = await self.process_query(question)