Make image input usable with the Gemini client; fix some of the typing
base_client (BaseClient / ClientRegistry):

@@ -1,9 +1,7 @@
 import asyncio
 from dataclasses import dataclass
-from typing import Callable, Any
-
-from openai import AsyncStream
-from openai.types.chat import ChatCompletionChunk, ChatCompletion
+from abc import ABC, abstractmethod
+from typing import Callable, Any, Optional
 
 from src.config.api_ada_configs import ModelInfo, APIProvider
 from ..payload_content.message import Message
@@ -58,7 +56,7 @@ class APIResponse:
     """响应原始数据"""
 
 
-class BaseClient:
+class BaseClient(ABC):
     """
     基础客户端
     """
@@ -68,6 +66,7 @@ class BaseClient:
     def __init__(self, api_provider: APIProvider):
         self.api_provider = api_provider
 
+    @abstractmethod
     async def get_response(
         self,
         model_info: ModelInfo,
@@ -76,12 +75,10 @@ class BaseClient:
         max_tokens: int = 1024,
         temperature: float = 0.7,
         response_format: RespFormat | None = None,
-        stream_response_handler: Callable[
-            [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
-            tuple[APIResponse, tuple[int, int, int]],
-        ]
-        | None = None,
-        async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
+        stream_response_handler: Optional[
+            Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
+        ] = None,
+        async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
         interrupt_flag: asyncio.Event | None = None,
         extra_params: dict[str, Any] | None = None,
     ) -> APIResponse:
@@ -98,8 +95,9 @@ class BaseClient:
         :param interrupt_flag: 中断信号量(可选,默认为None)
         :return: (响应文本, 推理文本, 工具调用, 其他数据)
         """
-        raise RuntimeError("This method should be overridden in subclasses")
+        raise NotImplementedError("'get_response' method should be overridden in subclasses")
 
+    @abstractmethod
     async def get_embedding(
         self,
         model_info: ModelInfo,
@@ -112,8 +110,9 @@ class BaseClient:
         :param embedding_input: 嵌入输入文本
         :return: 嵌入响应
         """
-        raise RuntimeError("This method should be overridden in subclasses")
+        raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
 
+    @abstractmethod
     async def get_audio_transcriptions(
         self,
         model_info: ModelInfo,
@@ -127,7 +126,15 @@ class BaseClient:
         :extra_params: 附加的请求参数
         :return: 音频转录响应
         """
-        raise RuntimeError("This method should be overridden in subclasses")
+        raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
+
+    @abstractmethod
+    def get_support_image_formats(self) -> list[str]:
+        """
+        获取支持的图片格式
+        :return: 支持的图片格式列表
+        """
+        raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
 
 
 class ClientRegistry:
@@ -137,7 +144,8 @@ class ClientRegistry:
     def register_client_class(self, client_type: str):
        """
        注册API客户端类
-        :param client_class: API客户端类
+        Args:
+            client_class: API客户端类
        """
 
        def decorator(cls: type[BaseClient]) -> type[BaseClient]:
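Taken together, these base_client changes turn BaseClient into a real abstract interface: every capability, including the new get_support_image_formats, must be provided by a concrete client, and clients attach themselves to the registry through the decorator. A rough sketch of what that contract looks like for a provider (the class, provider key and method bodies here are hypothetical; only the decorator pattern and method names come from the diff, and the keyword parameters of the async methods are abbreviated):

# Hypothetical provider client; illustrates the contract, not part of this commit.
from .base_client import APIResponse, BaseClient, client_registry


@client_registry.register_client_class("example")  # same pattern GeminiClient uses below
class ExampleClient(BaseClient):
    async def get_response(self, model_info, **kwargs) -> APIResponse:
        ...  # call the provider's chat API and map the result into an APIResponse

    async def get_embedding(self, model_info, embedding_input, **kwargs) -> APIResponse:
        ...  # return an APIResponse carrying the embedding vector

    async def get_audio_transcriptions(self, model_info, audio_base64, **kwargs) -> APIResponse:
        ...  # return an APIResponse carrying the transcription text

    def get_support_image_formats(self) -> list[str]:
        # Each provider now advertises the image formats it accepts.
        return ["png", "jpg", "jpeg"]

Because the methods are now abstract, a client that misses one of them fails at instantiation time with a TypeError instead of surfacing later through the old RuntimeError placeholder.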
Gemini client (GeminiClient):

@@ -1,8 +1,8 @@
-raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider")
 import asyncio
 import io
+import base64
 from collections.abc import Iterable
-from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
+from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
 
 from google import genai
 from google.genai import types
@@ -17,7 +17,7 @@ from google.genai.errors import (
 
 from src.config.api_ada_configs import ModelInfo, APIProvider
 
-from .base_client import APIResponse, UsageRecord, BaseClient
+from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
 from ..exceptions import (
     RespParseException,
     NetworkConnectionError,
@@ -54,20 +54,21 @@ def _convert_messages(
             role = "user"
 
         # 添加Content
-        content: types.Part | list
         if isinstance(message.content, str):
-            content = types.Part.from_text(message.content)
+            content = [types.Part.from_text(text=message.content)]
         elif isinstance(message.content, list):
-            content = []
+            content: List[types.Part] = []
             for item in message.content:
                 if isinstance(item, tuple):
-                    content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
+                    content.append(
+                        types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
+                    )
                 elif isinstance(item, str):
-                    content.append(types.Part.from_text(item))
+                    content.append(types.Part.from_text(text=item))
                 else:
                     raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
 
-        return types.Content(role=role, content=content)
+        return types.Content(role=role, parts=content)
 
     temp_list: list[types.Content] = []
     system_instructions: list[str] = []
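The decisive fix for image input is the base64.b64decode call: MessageBuilder stores an image as an (image_format, image_base64) tuple of strings, while types.Part.from_bytes in the google-genai SDK expects raw bytes, so the old code handed the API a base64 string where bytes were required. In isolation the conversion looks like this (the sample data is made up):

import base64

from google.genai import types

# An image item as MessageBuilder stores it: (format, base64-encoded string).
image_format, image_base64 = "png", base64.b64encode(b"\x89PNG\r\n\x1a\n...fake...").decode()

# Decode back to raw bytes before handing the data to the SDK.
part = types.Part.from_bytes(
    data=base64.b64decode(image_base64),
    mime_type=f"image/{image_format.lower()}",
)

Passing text= to Part.from_text and parts= to types.Content are the other SDK-facing fixes in the same hunk.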
@@ -76,7 +77,7 @@ def _convert_messages(
             if isinstance(message.content, str):
                 system_instructions.append(message.content)
             else:
-                raise RuntimeError("你tm怎么往system里面塞图片base64?")
+                raise ValueError("你tm怎么往system里面塞图片base64?")
         elif message.role == RoleType.Tool:
             if not message.tool_call_id:
                 raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
@@ -135,9 +136,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
 def _process_delta(
     delta: GenerateContentResponse,
     fc_delta_buffer: io.StringIO,
-    tool_calls_buffer: list[tuple[str, str, dict]],
+    tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
 ):
-    if not hasattr(delta, "candidates") or len(delta.candidates) == 0:
+    if not hasattr(delta, "candidates") or not delta.candidates:
         raise RespParseException(delta, "响应解析失败,缺失candidates字段")
 
     if delta.text:
@@ -148,11 +149,13 @@ def _process_delta(
             try:
                 if not isinstance(call.args, dict):  # gemini返回的function call参数就是dict格式的了
                     raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
+                if not call.id or not call.name:
+                    raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段")
                 tool_calls_buffer.append(
                     (
                         call.id,
                         call.name,
-                        call.args,
+                        call.args or {},  # 如果args是None,则转换为一个空字典
                     )
                 )
             except Exception as e:
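The new guard reflects how the google-genai types are declared: FunctionCall.id, FunctionCall.name and FunctionCall.args are typed as optional in the SDK's model, so the parser now rejects calls without an id or name and normalizes a missing args to an empty dict (the non-streaming parser further down gets the same treatment). Stripped of the surrounding parsing code, the rule is roughly this (a sketch; the function name is illustrative):

from typing import Any, Optional


def normalize_function_call(
    call_id: Optional[str], name: Optional[str], args: Optional[dict[str, Any]]
) -> tuple[str, str, dict[str, Any]]:
    # id and name are mandatory for a usable tool call; args may legitimately be absent.
    if not call_id or not name:
        raise ValueError("工具调用缺失id或name字段")
    return call_id, name, args or {}


assert normalize_function_call("call_1", "get_weather", None) == ("call_1", "get_weather", {})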
@@ -201,7 +204,7 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
 
 
 async def _default_stream_response_handler(
-    resp_stream: Iterator[GenerateContentResponse],
+    resp_stream: AsyncIterator[GenerateContentResponse],
     interrupt_flag: asyncio.Event | None,
 ) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """
@@ -232,9 +235,9 @@ async def _default_stream_response_handler(
         if chunk.usage_metadata:
             # 如果有使用情况,则将其存储在APIResponse对象中
             _usage_record = (
-                chunk.usage_metadata.prompt_token_count,
-                chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
-                chunk.usage_metadata.total_token_count,
+                chunk.usage_metadata.prompt_token_count or 0,
+                (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
+                chunk.usage_metadata.total_token_count or 0,
             )
     try:
         return _build_stream_api_resp(
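The usage bookkeeping here and in the OpenAI client below follows one pattern: the SDKs type their token counters as optional integers (thoughts_token_count in particular is frequently None when the model produces no reasoning tokens), so every field is coerced with `or 0` before it is summed or stored. A toy illustration of the failure the old code allowed:

from typing import Optional


def output_token_total(candidates: Optional[int], thoughts: Optional[int]) -> int:
    # Old behaviour: candidates + thoughts raises TypeError when either side is None.
    # New behaviour: a missing counter simply counts as zero.
    return (candidates or 0) + (thoughts or 0)


assert output_token_total(128, None) == 128
assert output_token_total(None, None) == 0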
@@ -257,7 +260,7 @@ def _default_normal_response_parser(
     """
     api_response = APIResponse()
 
-    if not hasattr(resp, "candidates") or len(resp.candidates) == 0:
+    if not hasattr(resp, "candidates") or not resp.candidates:
         raise RespParseException(resp, "响应解析失败,缺失candidates字段")
 
     if resp.text:
@@ -269,15 +272,17 @@ def _default_normal_response_parser(
             try:
                 if not isinstance(call.args, dict):
                     raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
-                api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
+                if not call.id or not call.name:
+                    raise RespParseException(resp, "响应解析失败,工具调用缺失id或name字段")
+                api_response.tool_calls.append(ToolCall(call.id, call.name, call.args or {}))
             except Exception as e:
                 raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
 
     if resp.usage_metadata:
         _usage_record = (
-            resp.usage_metadata.prompt_token_count,
-            resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
-            resp.usage_metadata.total_token_count,
+            resp.usage_metadata.prompt_token_count or 0,
+            (resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
+            resp.usage_metadata.total_token_count or 0,
         )
     else:
         _usage_record = None
@@ -287,6 +292,7 @@ def _default_normal_response_parser(
     return api_response, _usage_record
 
 
+@client_registry.register_client_class("gemini")
 class GeminiClient(BaseClient):
     client: genai.Client
 
@@ -307,7 +313,7 @@ class GeminiClient(BaseClient):
         response_format: RespFormat | None = None,
         stream_response_handler: Optional[
             Callable[
-                [Iterator[GenerateContentResponse], asyncio.Event | None],
+                [AsyncIterator[GenerateContentResponse], asyncio.Event | None],
                 Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
             ]
         ] = None,
@@ -398,7 +404,7 @@ class GeminiClient(BaseClient):
             resp, usage_record = async_response_parser(req_task.result())
         except (ClientError, ServerError) as e:
             # 重封装ClientError和ServerError为RespNotOkException
-            raise RespNotOkException(e.status_code, e.message) from None
+            raise RespNotOkException(e.code, e.message) from None
         except (
             UnknownFunctionCallArgumentError,
             UnsupportedFunctionError,
@@ -438,14 +444,14 @@ class GeminiClient(BaseClient):
             )
         except (ClientError, ServerError) as e:
             # 重封装ClientError和ServerError为RespNotOkException
-            raise RespNotOkException(e.status_code) from None
+            raise RespNotOkException(e.code) from None
         except Exception as e:
             raise NetworkConnectionError() from e
 
         response = APIResponse()
 
         # 解析嵌入响应和使用情况
-        if hasattr(raw_response, "embeddings"):
+        if hasattr(raw_response, "embeddings") and raw_response.embeddings:
             response.embedding = raw_response.embeddings[0].values
         else:
             raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段")
@@ -459,3 +465,10 @@ class GeminiClient(BaseClient):
         )
 
         return response
+
+    def get_support_image_formats(self) -> list[str]:
+        """
+        获取支持的图片格式
+        :return: 支持的图片格式列表
+        """
+        return ["png", "jpg", "jpeg", "webp", "heic", "heif"]
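With GeminiClient now advertising its own list, the difference from the OpenAI list (added further down) becomes visible to callers, which is exactly what the per-client validation in MessageBuilder relies on. For reference, comparing the two literals from this commit:

# The two format lists added in this commit, compared as sets.
gemini_formats = {"png", "jpg", "jpeg", "webp", "heic", "heif"}  # GeminiClient.get_support_image_formats()
openai_formats = {"jpg", "jpeg", "png", "webp", "gif"}           # OpenaiClient.get_support_image_formats()

print(sorted(gemini_formats - openai_formats))  # ['heic', 'heif'] (accepted by Gemini only)
print(sorted(openai_formats - gemini_formats))  # ['gif'] (accepted by OpenAI only)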
OpenAI client (OpenaiClient):

@@ -286,9 +286,9 @@ async def _default_stream_response_handler(
         if event.usage:
             # 如果有使用情况,则将其存储在APIResponse对象中
             _usage_record = (
-                event.usage.prompt_tokens,
-                event.usage.completion_tokens,
-                event.usage.total_tokens,
+                event.usage.prompt_tokens or 0,
+                event.usage.completion_tokens or 0,
+                event.usage.total_tokens or 0,
             )
 
     try:
@@ -356,9 +356,9 @@ def _default_normal_response_parser(
     # 提取Usage信息
     if resp.usage:
         _usage_record = (
-            resp.usage.prompt_tokens,
-            resp.usage.completion_tokens,
-            resp.usage.total_tokens,
+            resp.usage.prompt_tokens or 0,
+            resp.usage.completion_tokens or 0,
+            resp.usage.total_tokens or 0,
         )
     else:
         _usage_record = None
@@ -568,3 +568,10 @@ class OpenaiClient(BaseClient):
                 "响应解析失败,缺失转录文本。",
             )
         return response
+
+    def get_support_image_formats(self) -> list[str]:
+        """
+        获取支持的图片格式
+        :return: 支持的图片格式列表
+        """
+        return ["jpg", "jpeg", "png", "webp", "gif"]

payload_content/message (Message / MessageBuilder):

@@ -11,7 +11,7 @@ class RoleType(Enum):
     Tool = "tool"
 
 
-SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
+SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]  # openai支持的图片格式
 
 
 class Message:
@@ -53,9 +53,12 @@ class MessageBuilder:
         """
         self.__content.append(text)
         return self
 
     def add_image_content(
-        self, image_format: str, image_base64: str
+        self,
+        image_format: str,
+        image_base64: str,
+        support_formats: list[str] = SUPPORTED_IMAGE_FORMATS,  # 默认支持格式
     ) -> "MessageBuilder":
         """
         添加图片内容
@@ -63,7 +66,7 @@ class MessageBuilder:
         :param image_base64: 图片的base64编码
         :return: MessageBuilder对象
         """
-        if image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
+        if image_format.lower() not in support_formats:
             raise ValueError("不受支持的图片格式")
         if not image_base64:
             raise ValueError("图片的base64编码不能为空")
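add_image_content keeps the old OpenAI-oriented list as its default, so existing call sites behave as before, but a caller can now pass the active client's list instead. A minimal usage sketch, assuming a base64-encoded HEIC image is already at hand (the import path mirrors how the module is referenced elsewhere in this diff):

from ..payload_content.message import MessageBuilder

heic_b64 = "..."  # base64-encoded HEIC data, elided

builder = MessageBuilder()
builder.add_text_content("描述这张图片")
# With the default SUPPORTED_IMAGE_FORMATS this would raise ValueError("不受支持的图片格式");
# passing the Gemini client's list lets HEIC through.
builder.add_image_content(
    image_format="heic",
    image_base64=heic_b64,
    support_formats=["png", "jpg", "jpeg", "webp", "heic", "heif"],
)
message = builder.build()

This is also why the LLM request code below selects the model before building the message: the concrete client has to exist before its get_support_image_formats() result can be passed to add_image_content.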
LLM request (LLMRequest):

@@ -40,6 +40,7 @@ class RequestType(Enum):
     EMBEDDING = "embedding"
     AUDIO = "audio"
 
+
 class LLMRequest:
     """LLM请求类"""
 
@@ -70,15 +71,15 @@ class LLMRequest:
         Returns:
             (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
         """
+        # 模型选择
+        model_info, api_provider, client = self._select_model()
+
         # 请求体构建
         message_builder = MessageBuilder()
         message_builder.add_text_content(prompt)
-        message_builder.add_image_content(image_base64=image_base64, image_format=image_format)
+        message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats())
         messages = [message_builder.build()]
 
-        # 模型选择
-        model_info, api_provider, client = self._select_model()
-
         # 请求并处理返回值
         response = await self._execute_request(
             api_provider=api_provider,
@@ -127,7 +128,6 @@ class LLMRequest:
             )
         return response.content or None
 
-
     async def generate_response_async(
         self,
         prompt: str,
@@ -245,7 +245,7 @@ class LLMRequest:
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         embedding_input: str = "",
-        audio_base64: str = ""
+        audio_base64: str = "",
     ) -> APIResponse:
         """
         实际执行请求的方法