让Gemini的图像可用,修复部分typing
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Any
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletion
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
@@ -58,7 +56,7 @@ class APIResponse:
|
||||
"""响应原始数据"""
|
||||
|
||||
|
||||
class BaseClient:
|
||||
class BaseClient(ABC):
|
||||
"""
|
||||
基础客户端
|
||||
"""
|
||||
@@ -68,6 +66,7 @@ class BaseClient:
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
self.api_provider = api_provider
|
||||
|
||||
@abstractmethod
|
||||
async def get_response(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
@@ -76,12 +75,10 @@ class BaseClient:
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Callable[
|
||||
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
|
||||
tuple[APIResponse, tuple[int, int, int]],
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
] = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
@@ -98,8 +95,9 @@ class BaseClient:
|
||||
:param interrupt_flag: 中断信号量(可选,默认为None)
|
||||
:return: (响应文本, 推理文本, 工具调用, 其他数据)
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
raise NotImplementedError("'get_response' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_embedding(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
@@ -112,8 +110,9 @@ class BaseClient:
|
||||
:param embedding_input: 嵌入输入文本
|
||||
:return: 嵌入响应
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
@@ -127,7 +126,15 @@ class BaseClient:
|
||||
:extra_params: 附加的请求参数
|
||||
:return: 音频转录响应
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
"""
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
@@ -137,7 +144,8 @@ class ClientRegistry:
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
注册API客户端类
|
||||
:param client_class: API客户端类
|
||||
Args:
|
||||
client_class: API客户端类
|
||||
"""
|
||||
|
||||
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider")
|
||||
import asyncio
|
||||
import io
|
||||
import base64
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
|
||||
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -17,7 +17,7 @@ from google.genai.errors import (
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
|
||||
from .base_client import APIResponse, UsageRecord, BaseClient
|
||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||
from ..exceptions import (
|
||||
RespParseException,
|
||||
NetworkConnectionError,
|
||||
@@ -54,20 +54,21 @@ def _convert_messages(
|
||||
role = "user"
|
||||
|
||||
# 添加Content
|
||||
content: types.Part | list
|
||||
if isinstance(message.content, str):
|
||||
content = types.Part.from_text(message.content)
|
||||
content = [types.Part.from_text(text=message.content)]
|
||||
elif isinstance(message.content, list):
|
||||
content = []
|
||||
content: List[types.Part] = []
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
|
||||
content.append(
|
||||
types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
content.append(types.Part.from_text(item))
|
||||
content.append(types.Part.from_text(text=item))
|
||||
else:
|
||||
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||
|
||||
return types.Content(role=role, content=content)
|
||||
return types.Content(role=role, parts=content)
|
||||
|
||||
temp_list: list[types.Content] = []
|
||||
system_instructions: list[str] = []
|
||||
@@ -76,7 +77,7 @@ def _convert_messages(
|
||||
if isinstance(message.content, str):
|
||||
system_instructions.append(message.content)
|
||||
else:
|
||||
raise RuntimeError("你tm怎么往system里面塞图片base64?")
|
||||
raise ValueError("你tm怎么往system里面塞图片base64?")
|
||||
elif message.role == RoleType.Tool:
|
||||
if not message.tool_call_id:
|
||||
raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||
@@ -135,9 +136,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
|
||||
def _process_delta(
|
||||
delta: GenerateContentResponse,
|
||||
fc_delta_buffer: io.StringIO,
|
||||
tool_calls_buffer: list[tuple[str, str, dict]],
|
||||
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
|
||||
):
|
||||
if not hasattr(delta, "candidates") or len(delta.candidates) == 0:
|
||||
if not hasattr(delta, "candidates") or not delta.candidates:
|
||||
raise RespParseException(delta, "响应解析失败,缺失candidates字段")
|
||||
|
||||
if delta.text:
|
||||
@@ -148,11 +149,13 @@ def _process_delta(
|
||||
try:
|
||||
if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
|
||||
raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
|
||||
if not call.id or not call.name:
|
||||
raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段")
|
||||
tool_calls_buffer.append(
|
||||
(
|
||||
call.id,
|
||||
call.name,
|
||||
call.args,
|
||||
call.args or {}, # 如果args是None,则转换为一个空字典
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -201,7 +204,7 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
|
||||
|
||||
|
||||
async def _default_stream_response_handler(
|
||||
resp_stream: Iterator[GenerateContentResponse],
|
||||
resp_stream: AsyncIterator[GenerateContentResponse],
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
"""
|
||||
@@ -232,9 +235,9 @@ async def _default_stream_response_handler(
|
||||
if chunk.usage_metadata:
|
||||
# 如果有使用情况,则将其存储在APIResponse对象中
|
||||
_usage_record = (
|
||||
chunk.usage_metadata.prompt_token_count,
|
||||
chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
|
||||
chunk.usage_metadata.total_token_count,
|
||||
chunk.usage_metadata.prompt_token_count or 0,
|
||||
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
|
||||
chunk.usage_metadata.total_token_count or 0,
|
||||
)
|
||||
try:
|
||||
return _build_stream_api_resp(
|
||||
@@ -257,7 +260,7 @@ def _default_normal_response_parser(
|
||||
"""
|
||||
api_response = APIResponse()
|
||||
|
||||
if not hasattr(resp, "candidates") or len(resp.candidates) == 0:
|
||||
if not hasattr(resp, "candidates") or not resp.candidates:
|
||||
raise RespParseException(resp, "响应解析失败,缺失candidates字段")
|
||||
|
||||
if resp.text:
|
||||
@@ -269,15 +272,17 @@ def _default_normal_response_parser(
|
||||
try:
|
||||
if not isinstance(call.args, dict):
|
||||
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
|
||||
if not call.id or not call.name:
|
||||
raise RespParseException(resp, "响应解析失败,工具调用缺失id或name字段")
|
||||
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args or {}))
|
||||
except Exception as e:
|
||||
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
|
||||
|
||||
if resp.usage_metadata:
|
||||
_usage_record = (
|
||||
resp.usage_metadata.prompt_token_count,
|
||||
resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
|
||||
resp.usage_metadata.total_token_count,
|
||||
resp.usage_metadata.prompt_token_count or 0,
|
||||
(resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
|
||||
resp.usage_metadata.total_token_count or 0,
|
||||
)
|
||||
else:
|
||||
_usage_record = None
|
||||
@@ -287,6 +292,7 @@ def _default_normal_response_parser(
|
||||
return api_response, _usage_record
|
||||
|
||||
|
||||
@client_registry.register_client_class("gemini")
|
||||
class GeminiClient(BaseClient):
|
||||
client: genai.Client
|
||||
|
||||
@@ -307,7 +313,7 @@ class GeminiClient(BaseClient):
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
[Iterator[GenerateContentResponse], asyncio.Event | None],
|
||||
[AsyncIterator[GenerateContentResponse], asyncio.Event | None],
|
||||
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
|
||||
]
|
||||
] = None,
|
||||
@@ -398,7 +404,7 @@ class GeminiClient(BaseClient):
|
||||
resp, usage_record = async_response_parser(req_task.result())
|
||||
except (ClientError, ServerError) as e:
|
||||
# 重封装ClientError和ServerError为RespNotOkException
|
||||
raise RespNotOkException(e.status_code, e.message) from None
|
||||
raise RespNotOkException(e.code, e.message) from None
|
||||
except (
|
||||
UnknownFunctionCallArgumentError,
|
||||
UnsupportedFunctionError,
|
||||
@@ -438,14 +444,14 @@ class GeminiClient(BaseClient):
|
||||
)
|
||||
except (ClientError, ServerError) as e:
|
||||
# 重封装ClientError和ServerError为RespNotOkException
|
||||
raise RespNotOkException(e.status_code) from None
|
||||
raise RespNotOkException(e.code) from None
|
||||
except Exception as e:
|
||||
raise NetworkConnectionError() from e
|
||||
|
||||
response = APIResponse()
|
||||
|
||||
# 解析嵌入响应和使用情况
|
||||
if hasattr(raw_response, "embeddings"):
|
||||
if hasattr(raw_response, "embeddings") and raw_response.embeddings:
|
||||
response.embedding = raw_response.embeddings[0].values
|
||||
else:
|
||||
raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段")
|
||||
@@ -459,3 +465,10 @@ class GeminiClient(BaseClient):
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
"""
|
||||
return ["png", "jpg", "jpeg", "webp", "heic", "heif"]
|
||||
|
||||
@@ -286,9 +286,9 @@ async def _default_stream_response_handler(
|
||||
if event.usage:
|
||||
# 如果有使用情况,则将其存储在APIResponse对象中
|
||||
_usage_record = (
|
||||
event.usage.prompt_tokens,
|
||||
event.usage.completion_tokens,
|
||||
event.usage.total_tokens,
|
||||
event.usage.prompt_tokens or 0,
|
||||
event.usage.completion_tokens or 0,
|
||||
event.usage.total_tokens or 0,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -356,9 +356,9 @@ def _default_normal_response_parser(
|
||||
# 提取Usage信息
|
||||
if resp.usage:
|
||||
_usage_record = (
|
||||
resp.usage.prompt_tokens,
|
||||
resp.usage.completion_tokens,
|
||||
resp.usage.total_tokens,
|
||||
resp.usage.prompt_tokens or 0,
|
||||
resp.usage.completion_tokens or 0,
|
||||
resp.usage.total_tokens or 0,
|
||||
)
|
||||
else:
|
||||
_usage_record = None
|
||||
@@ -568,3 +568,10 @@ class OpenaiClient(BaseClient):
|
||||
"响应解析失败,缺失转录文本。",
|
||||
)
|
||||
return response
|
||||
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
"""
|
||||
return ["jpg", "jpeg", "png", "webp", "gif"]
|
||||
|
||||
@@ -11,7 +11,7 @@ class RoleType(Enum):
|
||||
Tool = "tool"
|
||||
|
||||
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
|
||||
|
||||
|
||||
class Message:
|
||||
@@ -53,9 +53,12 @@ class MessageBuilder:
|
||||
"""
|
||||
self.__content.append(text)
|
||||
return self
|
||||
|
||||
|
||||
def add_image_content(
|
||||
self, image_format: str, image_base64: str
|
||||
self,
|
||||
image_format: str,
|
||||
image_base64: str,
|
||||
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
|
||||
) -> "MessageBuilder":
|
||||
"""
|
||||
添加图片内容
|
||||
@@ -63,7 +66,7 @@ class MessageBuilder:
|
||||
:param image_base64: 图片的base64编码
|
||||
:return: MessageBuilder对象
|
||||
"""
|
||||
if image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
|
||||
if image_format.lower() not in support_formats:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
if not image_base64:
|
||||
raise ValueError("图片的base64编码不能为空")
|
||||
|
||||
@@ -40,6 +40,7 @@ class RequestType(Enum):
|
||||
EMBEDDING = "embedding"
|
||||
AUDIO = "audio"
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
"""LLM请求类"""
|
||||
|
||||
@@ -70,15 +71,15 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
# 模型选择
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
# 请求体构建
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
message_builder.add_image_content(image_base64=image_base64, image_format=image_format)
|
||||
message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats())
|
||||
messages = [message_builder.build()]
|
||||
|
||||
# 模型选择
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
# 请求并处理返回值
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
@@ -127,7 +128,6 @@ class LLMRequest:
|
||||
)
|
||||
return response.content or None
|
||||
|
||||
|
||||
async def generate_response_async(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -245,7 +245,7 @@ class LLMRequest:
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
embedding_input: str = "",
|
||||
audio_base64: str = ""
|
||||
audio_base64: str = "",
|
||||
) -> APIResponse:
|
||||
"""
|
||||
实际执行请求的方法
|
||||
|
||||
Reference in New Issue
Block a user