文档和tool适配
This commit is contained in:
@@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
|
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
|
||||||
```python
|
```python
|
||||||
from src.plugin_system import BaseTool
|
from src.plugin_system import BaseTool, ToolParamType
|
||||||
|
|
||||||
class MyTool(BaseTool):
|
class MyTool(BaseTool):
|
||||||
# 工具名称,必须唯一
|
# 工具名称,必须唯一
|
||||||
@@ -45,13 +45,14 @@ class MyTool(BaseTool):
|
|||||||
# "limit": {
|
# "limit": {
|
||||||
# "type": "integer",
|
# "type": "integer",
|
||||||
# "description": "结果数量限制"
|
# "description": "结果数量限制"
|
||||||
|
# "enum": [10, 20, 50] # 可选值
|
||||||
# }
|
# }
|
||||||
# },
|
# },
|
||||||
# "required": ["query"]
|
# "required": ["query"]
|
||||||
# }
|
# }
|
||||||
parameters = [
|
parameters = [
|
||||||
("query", "string", "查询参数", True), # 必填参数
|
("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数
|
||||||
("limit", "integer", "结果数量限制", False) # 可选参数
|
("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数
|
||||||
]
|
]
|
||||||
|
|
||||||
available_for_llm = True # 是否对LLM可用
|
available_for_llm = True # 是否对LLM可用
|
||||||
@@ -104,8 +105,8 @@ class WeatherTool(BaseTool):
|
|||||||
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
|
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
|
||||||
available_for_llm = True # 允许LLM调用此工具
|
available_for_llm = True # 允许LLM调用此工具
|
||||||
parameters = [
|
parameters = [
|
||||||
("city", "string", "要查询天气的城市名称,如:北京、上海、纽约", True),
|
("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None),
|
||||||
("country", "string", "国家代码,如:CN、US,可选参数", False)
|
("country", ToolParamType.STRING, "国家代码,如:CN、US,可选参数", False, None)
|
||||||
]
|
]
|
||||||
|
|
||||||
async def execute(self, function_args: dict):
|
async def execute(self, function_args: dict):
|
||||||
@@ -214,8 +215,8 @@ description = "获取信息" # 不够具体
|
|||||||
#### ✅ 合理的参数设计
|
#### ✅ 合理的参数设计
|
||||||
```python
|
```python
|
||||||
parameters = [
|
parameters = [
|
||||||
("city", "string", "城市名称,如:北京、上海", True),
|
("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None),
|
||||||
("unit", "string", "温度单位:celsius 或 fahrenheit", False)
|
("unit", ToolParamType.STRING, "温度单位:celsius 或 fahrenheit", False, ["celsius", "fahrenheit"])
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
#### ❌ 避免的参数设计
|
#### ❌ 避免的参数设计
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from src.plugin_system.base.base_tool import BaseTool
|
import json # Added for parsing embedding
|
||||||
|
import math # Added for cosine similarity
|
||||||
|
from typing import Any, Union, List # Added List
|
||||||
|
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding
|
||||||
from src.common.database.database_model import Knowledges # Updated import
|
from src.common.database.database_model import Knowledges # Updated import
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from typing import Any, Union, List # Added List
|
from src.plugin_system import BaseTool, ToolParamType
|
||||||
import json # Added for parsing embedding
|
|
||||||
import math # Added for cosine similarity
|
|
||||||
|
|
||||||
logger = get_logger("get_knowledge_tool")
|
logger = get_logger("get_knowledge_tool")
|
||||||
|
|
||||||
@@ -15,8 +17,8 @@ class SearchKnowledgeTool(BaseTool):
|
|||||||
name = "search_knowledge"
|
name = "search_knowledge"
|
||||||
description = "使用工具从知识库中搜索相关信息"
|
description = "使用工具从知识库中搜索相关信息"
|
||||||
parameters = [
|
parameters = [
|
||||||
("query", "string", "搜索查询关键词", True),
|
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
|
||||||
("threshold", "float", "相似度阈值,0.0到1.0之间", False),
|
("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
from src.plugin_system.base.base_tool import BaseTool
|
|
||||||
|
|
||||||
# from src.common.database import db
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||||
|
from src.plugin_system import BaseTool, ToolParamType
|
||||||
|
|
||||||
logger = get_logger("lpmm_get_knowledge_tool")
|
logger = get_logger("lpmm_get_knowledge_tool")
|
||||||
|
|
||||||
@@ -15,8 +13,8 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
|||||||
name = "lpmm_search_knowledge"
|
name = "lpmm_search_knowledge"
|
||||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
||||||
parameters = [
|
parameters = [
|
||||||
("query", "string", "搜索查询关键词", True),
|
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
|
||||||
("threshold", "float", "相似度阈值,0.0到1.0之间", False),
|
("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
|||||||
Reference in New Issue
Block a user