tools系统
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from typing import Any, List, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -17,8 +17,8 @@ class BaseTool(ABC):
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: Dict[str, Any] = {}
|
||||
"""工具的参数定义"""
|
||||
parameters: List[Tuple[str, str, str, bool]] = []
|
||||
"""工具的参数定义,为[("param_name", "param_type", "description", required)]"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
@@ -32,10 +32,7 @@ class BaseTool(ABC):
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
@@ -79,7 +76,9 @@ class BaseTool(ABC):
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
|
||||
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
|
||||
for param_name in parameter_required:
|
||||
if param_name not in function_args:
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
|
||||
|
||||
return await self.execute(function_args)
|
||||
|
||||
Reference in New Issue
Block a user