加上tools的enum属性
This commit is contained in:
@@ -192,17 +192,6 @@ def _build_stream_api_resp(
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
|
|
||||||
"""
|
|
||||||
将迭代器转换为异步迭代器
|
|
||||||
:param iterable: 迭代器对象
|
|
||||||
:return: 异步迭代器对象
|
|
||||||
"""
|
|
||||||
for item in iterable:
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
yield item
|
|
||||||
|
|
||||||
|
|
||||||
async def _default_stream_response_handler(
|
async def _default_stream_response_handler(
|
||||||
resp_stream: AsyncIterator[GenerateContentResponse],
|
resp_stream: AsyncIterator[GenerateContentResponse],
|
||||||
interrupt_flag: asyncio.Event | None,
|
interrupt_flag: asyncio.Event | None,
|
||||||
|
|||||||
@@ -94,16 +94,19 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
|
|||||||
:return: 转换后的工具选项列表
|
:return: 转换后的工具选项列表
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, str]:
|
def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
转换单个工具参数格式
|
转换单个工具参数格式
|
||||||
:param tool_option_param: 工具参数对象
|
:param tool_option_param: 工具参数对象
|
||||||
:return: 转换后的工具参数字典
|
:return: 转换后的工具参数字典
|
||||||
"""
|
"""
|
||||||
return {
|
return_dict: dict[str, Any] = {
|
||||||
"type": tool_option_param.param_type.value,
|
"type": tool_option_param.param_type.value,
|
||||||
"description": tool_option_param.description,
|
"description": tool_option_param.description,
|
||||||
}
|
}
|
||||||
|
if tool_option_param.enum_values:
|
||||||
|
return_dict["enum"] = tool_option_param.enum_values
|
||||||
|
return return_dict
|
||||||
|
|
||||||
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
|
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ class ToolParamType(Enum):
|
|||||||
工具调用参数类型
|
工具调用参数类型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
String = "string" # 字符串
|
STRING = "string" # 字符串
|
||||||
Int = "integer" # 整型
|
INTEGER = "integer" # 整型
|
||||||
Float = "float" # 浮点型
|
FLOAT = "float" # 浮点型
|
||||||
Boolean = "bool" # 布尔型
|
BOOLEAN = "bool" # 布尔型
|
||||||
|
|
||||||
|
|
||||||
class ToolParam:
|
class ToolParam:
|
||||||
@@ -18,7 +18,12 @@ class ToolParam:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, param_type: ToolParamType, description: str, required: bool
|
self,
|
||||||
|
name: str,
|
||||||
|
param_type: ToolParamType,
|
||||||
|
description: str,
|
||||||
|
required: bool,
|
||||||
|
enum_values: list[str] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化工具调用参数
|
初始化工具调用参数
|
||||||
@@ -32,6 +37,7 @@ class ToolParam:
|
|||||||
self.param_type: ToolParamType = param_type
|
self.param_type: ToolParamType = param_type
|
||||||
self.description: str = description
|
self.description: str = description
|
||||||
self.required: bool = required
|
self.required: bool = required
|
||||||
|
self.enum_values: list[str] | None = enum_values
|
||||||
|
|
||||||
|
|
||||||
class ToolOption:
|
class ToolOption:
|
||||||
@@ -95,6 +101,7 @@ class ToolOptionBuilder:
|
|||||||
param_type: ToolParamType,
|
param_type: ToolParamType,
|
||||||
description: str,
|
description: str,
|
||||||
required: bool = False,
|
required: bool = False,
|
||||||
|
enum_values: list[str] | None = None,
|
||||||
) -> "ToolOptionBuilder":
|
) -> "ToolOptionBuilder":
|
||||||
"""
|
"""
|
||||||
添加工具参数
|
添加工具参数
|
||||||
@@ -113,6 +120,7 @@ class ToolOptionBuilder:
|
|||||||
param_type=param_type,
|
param_type=param_type,
|
||||||
description=description,
|
description=description,
|
||||||
required=required,
|
required=required,
|
||||||
|
enum_values=enum_values,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,9 @@ class LLMRequest:
|
|||||||
# 请求体构建
|
# 请求体构建
|
||||||
message_builder = MessageBuilder()
|
message_builder = MessageBuilder()
|
||||||
message_builder.add_text_content(prompt)
|
message_builder.add_text_content(prompt)
|
||||||
message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats())
|
message_builder.add_image_content(
|
||||||
|
image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
|
||||||
|
)
|
||||||
messages = [message_builder.build()]
|
messages = [message_builder.build()]
|
||||||
|
|
||||||
# 请求并处理返回值
|
# 请求并处理返回值
|
||||||
@@ -458,6 +460,7 @@ class LLMRequest:
|
|||||||
return -1, None
|
return -1, None
|
||||||
|
|
||||||
def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
|
def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
|
||||||
|
# sourcery skip: extract-method
|
||||||
"""构建工具选项列表"""
|
"""构建工具选项列表"""
|
||||||
if not tools:
|
if not tools:
|
||||||
return None
|
return None
|
||||||
@@ -467,18 +470,25 @@ class LLMRequest:
|
|||||||
tool_options_builder = ToolOptionBuilder()
|
tool_options_builder = ToolOptionBuilder()
|
||||||
tool_options_builder.set_name(tool.get("name", ""))
|
tool_options_builder.set_name(tool.get("name", ""))
|
||||||
tool_options_builder.set_description(tool.get("description", ""))
|
tool_options_builder.set_description(tool.get("description", ""))
|
||||||
parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", [])
|
parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", [])
|
||||||
for param in parameters:
|
for param in parameters:
|
||||||
try:
|
try:
|
||||||
|
assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
|
||||||
|
assert isinstance(param[0], str), "参数名称必须是字符串"
|
||||||
|
assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举"
|
||||||
|
assert isinstance(param[2], str), "参数描述必须是字符串"
|
||||||
|
assert isinstance(param[3], bool), "参数是否必填必须是布尔值"
|
||||||
|
assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None"
|
||||||
tool_options_builder.add_param(
|
tool_options_builder.add_param(
|
||||||
name=param[0],
|
name=param[0],
|
||||||
param_type=ToolParamType(param[1]),
|
param_type=param[1],
|
||||||
description=param[2],
|
description=param[2],
|
||||||
required=param[3],
|
required=param[3],
|
||||||
|
enum_values=param[4],
|
||||||
)
|
)
|
||||||
except ValueError as ve:
|
except AssertionError as ae:
|
||||||
tool_legal = False
|
tool_legal = False
|
||||||
logger.error(f"{param[1]} 参数类型错误: {str(ve)}")
|
logger.error(f"{param[0]} 参数定义错误: {str(ae)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tool_legal = False
|
tool_legal = False
|
||||||
logger.error(f"构建工具参数失败: {str(e)}")
|
logger.error(f"构建工具参数失败: {str(e)}")
|
||||||
|
|||||||
@@ -18,11 +18,13 @@ from .base import (
|
|||||||
ActionInfo,
|
ActionInfo,
|
||||||
CommandInfo,
|
CommandInfo,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
|
ToolInfo,
|
||||||
PythonDependency,
|
PythonDependency,
|
||||||
BaseEventHandler,
|
BaseEventHandler,
|
||||||
EventHandlerInfo,
|
EventHandlerInfo,
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
|
ToolParamType,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入工具模块
|
# 导入工具模块
|
||||||
@@ -83,9 +85,11 @@ __all__ = [
|
|||||||
"ActionInfo",
|
"ActionInfo",
|
||||||
"CommandInfo",
|
"CommandInfo",
|
||||||
"PluginInfo",
|
"PluginInfo",
|
||||||
|
"ToolInfo",
|
||||||
"PythonDependency",
|
"PythonDependency",
|
||||||
"EventHandlerInfo",
|
"EventHandlerInfo",
|
||||||
"EventType",
|
"EventType",
|
||||||
|
"ToolParamType",
|
||||||
# 消息
|
# 消息
|
||||||
"MaiMessages",
|
"MaiMessages",
|
||||||
# 装饰器
|
# 装饰器
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from .component_types import (
|
|||||||
EventHandlerInfo,
|
EventHandlerInfo,
|
||||||
EventType,
|
EventType,
|
||||||
MaiMessages,
|
MaiMessages,
|
||||||
|
ToolParamType,
|
||||||
)
|
)
|
||||||
from .config_types import ConfigField
|
from .config_types import ConfigField
|
||||||
|
|
||||||
@@ -44,4 +45,5 @@ __all__ = [
|
|||||||
"EventType",
|
"EventType",
|
||||||
"BaseEventHandler",
|
"BaseEventHandler",
|
||||||
"MaiMessages",
|
"MaiMessages",
|
||||||
|
"ToolParamType",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, List, Tuple
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo
|
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -17,8 +17,15 @@ class BaseTool(ABC):
|
|||||||
"""工具的名称"""
|
"""工具的名称"""
|
||||||
description: str = ""
|
description: str = ""
|
||||||
"""工具的描述"""
|
"""工具的描述"""
|
||||||
parameters: List[Tuple[str, str, str, bool]] = []
|
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
|
||||||
"""工具的参数定义,为[("param_name", "param_type", "description", required)]"""
|
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
||||||
|
param_name: 参数名称
|
||||||
|
param_type: 参数类型
|
||||||
|
description: 参数描述
|
||||||
|
required: 是否必填
|
||||||
|
enum_values: 枚举值列表
|
||||||
|
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
|
||||||
|
"""
|
||||||
available_for_llm: bool = False
|
available_for_llm: bool = False
|
||||||
"""是否可供LLM使用"""
|
"""是否可供LLM使用"""
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Dict, Any, List, Optional, Tuple
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from maim_message import Seg
|
from maim_message import Seg
|
||||||
|
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||||
|
|
||||||
# 组件类型枚举
|
# 组件类型枚举
|
||||||
class ComponentType(Enum):
|
class ComponentType(Enum):
|
||||||
@@ -145,17 +146,19 @@ class CommandInfo(ComponentInfo):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
self.component_type = ComponentType.COMMAND
|
self.component_type = ComponentType.COMMAND
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolInfo(ComponentInfo):
|
class ToolInfo(ComponentInfo):
|
||||||
"""工具组件信息"""
|
"""工具组件信息"""
|
||||||
|
|
||||||
tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义
|
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
||||||
tool_description: str = "" # 工具描述
|
tool_description: str = "" # 工具描述
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
self.component_type = ComponentType.TOOL
|
self.component_type = ComponentType.TOOL
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EventHandlerInfo(ComponentInfo):
|
class EventHandlerInfo(ComponentInfo):
|
||||||
|
|||||||
Reference in New Issue
Block a user