diff --git a/docs/plugins/tool-components.md b/docs/plugins/tool-components.md index cd48a0541..059656aa4 100644 --- a/docs/plugins/tool-components.md +++ b/docs/plugins/tool-components.md @@ -24,7 +24,7 @@ 每个工具必须继承 `BaseTool` 基类并实现以下属性和方法: ```python -from src.plugin_system import BaseTool +from src.plugin_system import BaseTool, ToolParamType class MyTool(BaseTool): # 工具名称,必须唯一 @@ -45,13 +45,14 @@ class MyTool(BaseTool): # "limit": { # "type": "integer", # "description": "结果数量限制" + # "enum": [10, 20, 50] # 可选值 # } # }, # "required": ["query"] # } parameters = [ - ("query", "string", "查询参数", True), # 必填参数 - ("limit", "integer", "结果数量限制", False) # 可选参数 + ("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数 + ("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数 ] available_for_llm = True # 是否对LLM可用 @@ -104,8 +105,8 @@ class WeatherTool(BaseTool): description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等" available_for_llm = True # 允许LLM调用此工具 parameters = [ - ("city", "string", "要查询天气的城市名称,如:北京、上海、纽约", True), - ("country", "string", "国家代码,如:CN、US,可选参数", False) + ("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None), + ("country", ToolParamType.STRING, "国家代码,如:CN、US,可选参数", False, None) ] async def execute(self, function_args: dict): @@ -214,8 +215,8 @@ description = "获取信息" # 不够具体 #### ✅ 合理的参数设计 ```python parameters = [ - ("city", "string", "城市名称,如:北京、上海", True), - ("unit", "string", "温度单位:celsius 或 fahrenheit", False) + ("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None), + ("unit", ToolParamType.STRING, "温度单位:celsius 或 fahrenheit", False, ["celsius", "fahrenheit"]) ] ``` #### ❌ 避免的参数设计 diff --git a/src/plugins/built_in/knowledge/get_knowledge.py b/src/plugins/built_in/knowledge/get_knowledge.py index 54f93cddf..ce90cb680 100644 --- a/src/plugins/built_in/knowledge/get_knowledge.py +++ b/src/plugins/built_in/knowledge/get_knowledge.py @@ -1,10 +1,12 @@ -from src.plugin_system.base.base_tool import BaseTool +import json # Added for parsing embedding +import math # Added for cosine similarity +from typing import Any, Union, List # Added List + from src.chat.utils.utils import get_embedding from src.common.database.database_model import Knowledges # Updated import from src.common.logger import get_logger -from typing import Any, Union, List # Added List -import json # Added for parsing embedding -import math # Added for cosine similarity +from src.plugin_system import BaseTool, ToolParamType + logger = get_logger("get_knowledge_tool") @@ -15,8 +17,8 @@ class SearchKnowledgeTool(BaseTool): name = "search_knowledge" description = "使用工具从知识库中搜索相关信息" parameters = [ - ("query", "string", "搜索查询关键词", True), - ("threshold", "float", "相似度阈值,0.0到1.0之间", False), + ("query", ToolParamType.STRING, "搜索查询关键词", True, None), + ("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None), ] async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index ef74add92..da20c348b 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -1,10 +1,8 @@ -from src.plugin_system.base.base_tool import BaseTool - -# from src.common.database import db -from src.common.logger import get_logger from typing import Dict, Any -from src.chat.knowledge.knowledge_lib import qa_manager +from src.common.logger import get_logger +from src.chat.knowledge.knowledge_lib import qa_manager +from src.plugin_system import BaseTool, ToolParamType logger = get_logger("lpmm_get_knowledge_tool") @@ -15,8 +13,8 @@ class SearchKnowledgeFromLPMMTool(BaseTool): name = "lpmm_search_knowledge" description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具" parameters = [ - ("query", "string", "搜索查询关键词", True), - ("threshold", "float", "相似度阈值,0.0到1.0之间", False), + ("query", ToolParamType.STRING, "搜索查询关键词", True, None), + ("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None), ] async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: