tools system
@@ -1,4 +1,4 @@
-raise DeprecationWarning("Genimi Client is not fully available yet.")
+raise DeprecationWarning("Gemini Client is not fully available yet. Please remove your Gemini API Provider.")
 import asyncio
 import io
 from collections.abc import Iterable
@@ -0,0 +1,3 @@
+from .tool_option import ToolCall
+
+__all__ = ["ToolCall"]
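The three lines above are a new package `__init__.py`; its `payload_content/` location is implied by the relative imports in the next hunk. With the re-export, `ToolCall` is importable from the package root as well as from its submodule. A minimal sketch, assuming package-internal code:

from .payload_content import ToolCall              # via the new re-export
from .payload_content.tool_option import ToolCall  # submodule path, still used below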
@@ -11,7 +11,7 @@ from src.config.config import model_config
 from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
 from .payload_content.message import MessageBuilder, Message
 from .payload_content.resp_format import RespFormat
-from .payload_content.tool_option import ToolOption, ToolCall
+from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
 from .model_client.base_client import BaseClient, APIResponse, client_registry
 from .utils import compress_messages, llm_usage_recorder
 from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
@@ -60,7 +60,7 @@ class LLMRequest:
         image_format: str,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
         Generate a response for an image
         Args:
@@ -68,7 +68,7 @@ class LLMRequest:
             image_base64 (str): Base64-encoded string of the image
             image_format (str): image format (e.g. 'png', 'jpeg')
         Returns:
-            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, list of tool calls
+            (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, list of tool calls
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -104,18 +104,18 @@ class LLMRequest:
             request_type=self.request_type,
             endpoint="/chat/completions",
         )
-        return content, (
-            reasoning_content,
-            model_info.name,
-            self._convert_tool_calls(tool_calls) if tool_calls else None,
-        )
+        return content, (reasoning_content, model_info.name, tool_calls)

     async def generate_response_for_voice(self):
         pass

     async def generate_response_async(
-        self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
-    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
+        self,
+        prompt: str,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
         Generate a response asynchronously
         Args:
@@ -123,13 +123,13 @@ class LLMRequest:
             temperature (float, optional): temperature parameter
             max_tokens (int, optional): maximum number of tokens
         Returns:
-            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, list of tool calls
+            (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, list of tool calls
         """
         # Build the request payload
         message_builder = MessageBuilder()
         message_builder.add_text_content(prompt)
         messages = [message_builder.build()]
-
+        tool_built = self._build_tool_options(tools)
         # Model selection
         model_info, api_provider, client = self._select_model()

@@ -142,6 +142,7 @@ class LLMRequest:
             message_list=messages,
             temperature=temperature,
             max_tokens=max_tokens,
+            tool_options=tool_built,
         )
         content = response.content
         reasoning_content = response.reasoning_content or ""
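The two hunks above thread the new optional `tools` argument of `generate_response_async` through to the request as `tool_options`. A hedged usage sketch: the dict keys and the (name, type, description, required) tuple layout come from `_build_tool_options` at the end of this diff, while the concrete tool, the "string" type tag, and the `llm` instance are hypothetical.

tools = [
    {
        "name": "get_weather",                     # hypothetical tool
        "description": "Look up current weather",
        "parameters": [
            # (name, type, description, required), as _build_tool_options expects
            ("city", "string", "City to query", True),
        ],
    }
]
content, (reasoning, model_name, tool_calls) = await llm.generate_response_async(
    prompt="What is the weather in Paris?",
    tools=tools,
)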
@@ -161,11 +162,7 @@ class LLMRequest:
         if not content:
             raise RuntimeError("Failed to get LLM-generated content")

-        return content, (
-            reasoning_content,
-            model_info.name,
-            self._convert_tool_calls(tool_calls) if tool_calls else None,
-        )
+        return content, (reasoning_content, model_info.name, tool_calls)

     async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
         """Get an embedding vector
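With `_convert_tool_calls` gone (next hunk), callers of `generate_response_async` now receive `ToolCall` objects directly instead of converted dicts. A minimal consumer sketch, where `llm` is a hypothetical `LLMRequest` instance and `handle` a hypothetical dispatcher; `ToolCall`'s own fields are not shown in this diff:

content, (reasoning, model_name, tool_calls) = await llm.generate_response_async("hi")
for call in tool_calls or []:  # tool_calls is Optional[List[ToolCall]] after this change
    handle(call)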
@@ -214,10 +211,6 @@ class LLMRequest:
         client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
         return model_info, api_provider, client

-    def _convert_tool_calls(self, tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
-        """Convert ToolCall objects into a list of dicts"""
-        pass
-
     async def _execute_request(
         self,
         api_provider: APIProvider,
@@ -435,6 +428,35 @@ class LLMRequest:
         )
         return -1, None

+    def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
+        """Build the list of tool options"""
+        if not tools:
+            return None
+        tool_options: List[ToolOption] = []
+        for tool in tools:
+            tool_legal = True
+            tool_options_builder = ToolOptionBuilder()
+            tool_options_builder.set_name(tool.get("name", ""))
+            tool_options_builder.set_description(tool.get("description", ""))
+            parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", [])
+            for param in parameters:
+                try:
+                    tool_options_builder.add_param(
+                        name=param[0],
+                        param_type=ToolParamType(param[1]),
+                        description=param[2],
+                        required=param[3],
+                    )
+                except ValueError as ve:
+                    tool_legal = False
+                    logger.error(f"{param[1]} is not a valid parameter type: {str(ve)}")
+                except Exception as e:
+                    tool_legal = False
+                    logger.error(f"Failed to build tool parameter: {str(e)}")
+            if tool_legal:
+                tool_options.append(tool_options_builder.build())
+        return tool_options or None
+
     @staticmethod
     def _extract_reasoning(content: str) -> Tuple[str, str]:
         """CoT chain-of-thought extraction, backward-compatible"""
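Note the validation semantics of `_build_tool_options`: a single invalid parameter type flips `tool_legal` and drops the whole tool, and `tool_options or None` collapses an empty result to `None`, matching the `Optional` return annotation. A hedged sketch, assuming "string" is a valid `ToolParamType` value and "not_a_type" is not (the enum's members are not shown in this diff), with `llm` again a hypothetical instance:

tool = {
    "name": "search",
    "description": "Search the web",
    "parameters": [
        ("query", "string", "Search terms", True),      # valid -> accepted
        ("limit", "not_a_type", "Max results", False),  # ValueError -> whole tool dropped
    ],
}
assert llm._build_tool_options([tool]) is None  # [] or None -> None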