From 1f53ecff1007c7b37119b972e31757f4e3ba8a82 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 3 Aug 2025 10:27:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E4=B8=8Atools=E7=9A=84enum=E5=B1=9E?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/model_client/gemini_client.py | 11 ---------- src/llm_models/model_client/openai_client.py | 7 +++++-- src/llm_models/payload_content/tool_option.py | 18 ++++++++++++----- src/llm_models/utils_model.py | 20 ++++++++++++++----- src/plugin_system/__init__.py | 4 ++++ src/plugin_system/base/__init__.py | 2 ++ src/plugin_system/base/base_tool.py | 13 +++++++++--- src/plugin_system/base/component_types.py | 9 ++++++--- 8 files changed, 55 insertions(+), 29 deletions(-) diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 286f4648b..6a89cc0af 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -192,17 +192,6 @@ def _build_stream_api_resp( return resp -async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]: - """ - 将迭代器转换为异步迭代器 - :param iterable: 迭代器对象 - :return: 异步迭代器对象 - """ - for item in iterable: - await asyncio.sleep(0) - yield item - - async def _default_stream_response_handler( resp_stream: AsyncIterator[GenerateContentResponse], interrupt_flag: asyncio.Event | None, diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 7f097e2c0..ad9cbf177 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -94,16 +94,19 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any] :return: 转换后的工具选项列表 """ - def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, str]: + def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]: """ 转换单个工具参数格式 :param tool_option_param: 工具参数对象 :return: 转换后的工具参数字典 """ - return { + return_dict: dict[str, Any] = { "type": tool_option_param.param_type.value, "description": tool_option_param.description, } + if tool_option_param.enum_values: + return_dict["enum"] = tool_option_param.enum_values + return return_dict def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: """ diff --git a/src/llm_models/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py index 8a9bbdb31..9fedbc86d 100644 --- a/src/llm_models/payload_content/tool_option.py +++ b/src/llm_models/payload_content/tool_option.py @@ -6,10 +6,10 @@ class ToolParamType(Enum): 工具调用参数类型 """ - String = "string" # 字符串 - Int = "integer" # 整型 - Float = "float" # 浮点型 - Boolean = "bool" # 布尔型 + STRING = "string" # 字符串 + INTEGER = "integer" # 整型 + FLOAT = "float" # 浮点型 + BOOLEAN = "bool" # 布尔型 class ToolParam: @@ -18,7 +18,12 @@ class ToolParam: """ def __init__( - self, name: str, param_type: ToolParamType, description: str, required: bool + self, + name: str, + param_type: ToolParamType, + description: str, + required: bool, + enum_values: list[str] | None = None, ): """ 初始化工具调用参数 @@ -32,6 +37,7 @@ class ToolParam: self.param_type: ToolParamType = param_type self.description: str = description self.required: bool = required + self.enum_values: list[str] | None = enum_values class ToolOption: @@ -95,6 +101,7 @@ class ToolOptionBuilder: param_type: ToolParamType, description: str, required: bool = False, + enum_values: list[str] | None = None, ) -> "ToolOptionBuilder": """ 添加工具参数 @@ -113,6 +120,7 @@ class ToolOptionBuilder: param_type=param_type, description=description, required=required, + enum_values=enum_values, ) ) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index ad66252f6..d2a960f1d 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -77,7 +77,9 @@ class LLMRequest: # 请求体构建 message_builder = MessageBuilder() message_builder.add_text_content(prompt) - message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()) + message_builder.add_image_content( + image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats() + ) messages = [message_builder.build()] # 请求并处理返回值 @@ -458,6 +460,7 @@ class LLMRequest: return -1, None def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + # sourcery skip: extract-method """构建工具选项列表""" if not tools: return None @@ -467,18 +470,25 @@ class LLMRequest: tool_options_builder = ToolOptionBuilder() tool_options_builder.set_name(tool.get("name", "")) tool_options_builder.set_description(tool.get("description", "")) - parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", []) + parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) for param in parameters: try: + assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" + assert isinstance(param[0], str), "参数名称必须是字符串" + assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" + assert isinstance(param[2], str), "参数描述必须是字符串" + assert isinstance(param[3], bool), "参数是否必填必须是布尔值" + assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" tool_options_builder.add_param( name=param[0], - param_type=ToolParamType(param[1]), + param_type=param[1], description=param[2], required=param[3], + enum_values=param[4], ) - except ValueError as ve: + except AssertionError as ae: tool_legal = False - logger.error(f"{param[1]} 参数类型错误: {str(ve)}") + logger.error(f"{param[0]} 参数定义错误: {str(ae)}") except Exception as e: tool_legal = False logger.error(f"构建工具参数失败: {str(e)}") diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index f8c71af42..a102ecd06 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -18,11 +18,13 @@ from .base import ( ActionInfo, CommandInfo, PluginInfo, + ToolInfo, PythonDependency, BaseEventHandler, EventHandlerInfo, EventType, MaiMessages, + ToolParamType, ) # 导入工具模块 @@ -83,9 +85,11 @@ __all__ = [ "ActionInfo", "CommandInfo", "PluginInfo", + "ToolInfo", "PythonDependency", "EventHandlerInfo", "EventType", + "ToolParamType", # 消息 "MaiMessages", # 装饰器 diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index b9a2893e4..bc63d35d1 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -22,6 +22,7 @@ from .component_types import ( EventHandlerInfo, EventType, MaiMessages, + ToolParamType, ) from .config_types import ConfigField @@ -44,4 +45,5 @@ __all__ = [ "EventType", "BaseEventHandler", "MaiMessages", + "ToolParamType", ] diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 5b996d375..1d589eca9 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -3,7 +3,7 @@ from typing import Any, List, Tuple from rich.traceback import install from src.common.logger import get_logger -from src.plugin_system.base.component_types import ComponentType, ToolInfo +from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType install(extra_lines=3) @@ -17,8 +17,15 @@ class BaseTool(ABC): """工具的名称""" description: str = "" """工具的描述""" - parameters: List[Tuple[str, str, str, bool]] = [] - """工具的参数定义,为[("param_name", "param_type", "description", required)]""" + parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = [] + """工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式 + param_name: 参数名称 + param_type: 参数类型 + description: 参数描述 + required: 是否必填 + enum_values: 枚举值列表 + 例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])] + """ available_for_llm: bool = False """是否可供LLM使用""" diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 5ed75a7bb..7775f5fb8 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -3,6 +3,7 @@ from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass, field from maim_message import Seg +from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType # 组件类型枚举 class ComponentType(Enum): @@ -145,17 +146,19 @@ class CommandInfo(ComponentInfo): def __post_init__(self): super().__post_init__() self.component_type = ComponentType.COMMAND - + + @dataclass class ToolInfo(ComponentInfo): """工具组件信息""" - tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义 + tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义 tool_description: str = "" # 工具描述 def __post_init__(self): super().__post_init__() - self.component_type = ComponentType.TOOL + self.component_type = ComponentType.TOOL + @dataclass class EventHandlerInfo(ComponentInfo):