让Gemini的图像可用,修复部分typing

This commit is contained in:
UnCLAS-Prommer
2025-08-03 00:49:19 +08:00
parent 38930b0ceb
commit 9afa549aee
5 changed files with 88 additions and 57 deletions

View File

@@ -1,9 +1,7 @@
import asyncio
from dataclasses import dataclass
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
@@ -58,7 +56,7 @@ class APIResponse:
"""响应原始数据"""
class BaseClient:
class BaseClient(ABC):
"""
基础客户端
"""
@@ -68,6 +66,7 @@ class BaseClient:
def __init__(self, api_provider: APIProvider):
self.api_provider = api_provider
@abstractmethod
async def get_response(
self,
model_info: ModelInfo,
@@ -76,12 +75,10 @@ class BaseClient:
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -98,8 +95,9 @@ class BaseClient:
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
raise RuntimeError("This method should be overridden in subclasses")
raise NotImplementedError("'get_response' method should be overridden in subclasses")
@abstractmethod
async def get_embedding(
self,
model_info: ModelInfo,
@@ -112,8 +110,9 @@ class BaseClient:
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
raise RuntimeError("This method should be overridden in subclasses")
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
@abstractmethod
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
@@ -127,7 +126,15 @@ class BaseClient:
:extra_params: 附加的请求参数
:return: 音频转录响应
"""
raise RuntimeError("This method should be overridden in subclasses")
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
@abstractmethod
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
@@ -137,7 +144,8 @@ class ClientRegistry:
def register_client_class(self, client_type: str):
"""
注册API客户端类
:param client_class: API客户端类
Args:
client_class: API客户端类
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]:

View File

@@ -1,8 +1,8 @@
raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider")
import asyncio
import io
import base64
from collections.abc import Iterable
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
from google import genai
from google.genai import types
@@ -17,7 +17,7 @@ from google.genai.errors import (
from src.config.api_ada_configs import ModelInfo, APIProvider
from .base_client import APIResponse, UsageRecord, BaseClient
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
@@ -54,20 +54,21 @@ def _convert_messages(
role = "user"
# 添加Content
content: types.Part | list
if isinstance(message.content, str):
content = types.Part.from_text(message.content)
content = [types.Part.from_text(text=message.content)]
elif isinstance(message.content, list):
content = []
content: List[types.Part] = []
for item in message.content:
if isinstance(item, tuple):
content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
content.append(
types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
)
elif isinstance(item, str):
content.append(types.Part.from_text(item))
content.append(types.Part.from_text(text=item))
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return types.Content(role=role, content=content)
return types.Content(role=role, parts=content)
temp_list: list[types.Content] = []
system_instructions: list[str] = []
@@ -76,7 +77,7 @@ def _convert_messages(
if isinstance(message.content, str):
system_instructions.append(message.content)
else:
raise RuntimeError("你tm怎么往system里面塞图片base64")
raise ValueError("你tm怎么往system里面塞图片base64")
elif message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
@@ -135,9 +136,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
def _process_delta(
delta: GenerateContentResponse,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, dict]],
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
):
if not hasattr(delta, "candidates") or len(delta.candidates) == 0:
if not hasattr(delta, "candidates") or not delta.candidates:
raise RespParseException(delta, "响应解析失败缺失candidates字段")
if delta.text:
@@ -148,11 +149,13 @@ def _process_delta(
try:
if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
if not call.id or not call.name:
raise RespParseException(delta, "响应解析失败工具调用缺失id或name字段")
tool_calls_buffer.append(
(
call.id,
call.name,
call.args,
call.args or {}, # 如果args是None则转换为一个空字典
)
)
except Exception as e:
@@ -201,7 +204,7 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
async def _default_stream_response_handler(
resp_stream: Iterator[GenerateContentResponse],
resp_stream: AsyncIterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
@@ -232,9 +235,9 @@ async def _default_stream_response_handler(
if chunk.usage_metadata:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
chunk.usage_metadata.prompt_token_count,
chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.total_token_count,
chunk.usage_metadata.prompt_token_count or 0,
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
chunk.usage_metadata.total_token_count or 0,
)
try:
return _build_stream_api_resp(
@@ -257,7 +260,7 @@ def _default_normal_response_parser(
"""
api_response = APIResponse()
if not hasattr(resp, "candidates") or len(resp.candidates) == 0:
if not hasattr(resp, "candidates") or not resp.candidates:
raise RespParseException(resp, "响应解析失败缺失candidates字段")
if resp.text:
@@ -269,15 +272,17 @@ def _default_normal_response_parser(
try:
if not isinstance(call.args, dict):
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
if not call.id or not call.name:
raise RespParseException(resp, "响应解析失败工具调用缺失id或name字段")
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args or {}))
except Exception as e:
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count,
resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.total_token_count,
resp.usage_metadata.prompt_token_count or 0,
(resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
resp.usage_metadata.total_token_count or 0,
)
else:
_usage_record = None
@@ -287,6 +292,7 @@ def _default_normal_response_parser(
return api_response, _usage_record
@client_registry.register_client_class("gemini")
class GeminiClient(BaseClient):
client: genai.Client
@@ -307,7 +313,7 @@ class GeminiClient(BaseClient):
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None],
[AsyncIterator[GenerateContentResponse], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
@@ -398,7 +404,7 @@ class GeminiClient(BaseClient):
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code, e.message) from None
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
@@ -438,14 +444,14 @@ class GeminiClient(BaseClient):
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code) from None
raise RespNotOkException(e.code) from None
except Exception as e:
raise NetworkConnectionError() from e
response = APIResponse()
# 解析嵌入响应和使用情况
if hasattr(raw_response, "embeddings"):
if hasattr(raw_response, "embeddings") and raw_response.embeddings:
response.embedding = raw_response.embeddings[0].values
else:
raise RespParseException(raw_response, "响应解析失败缺失embeddings字段")
@@ -459,3 +465,10 @@ class GeminiClient(BaseClient):
)
return response
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
return ["png", "jpg", "jpeg", "webp", "heic", "heif"]

View File

@@ -286,9 +286,9 @@ async def _default_stream_response_handler(
if event.usage:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
event.usage.prompt_tokens,
event.usage.completion_tokens,
event.usage.total_tokens,
event.usage.prompt_tokens or 0,
event.usage.completion_tokens or 0,
event.usage.total_tokens or 0,
)
try:
@@ -356,9 +356,9 @@ def _default_normal_response_parser(
# 提取Usage信息
if resp.usage:
_usage_record = (
resp.usage.prompt_tokens,
resp.usage.completion_tokens,
resp.usage.total_tokens,
resp.usage.prompt_tokens or 0,
resp.usage.completion_tokens or 0,
resp.usage.total_tokens or 0,
)
else:
_usage_record = None
@@ -568,3 +568,10 @@ class OpenaiClient(BaseClient):
"响应解析失败,缺失转录文本。",
)
return response
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
return ["jpg", "jpeg", "png", "webp", "gif"]

View File

@@ -11,7 +11,7 @@ class RoleType(Enum):
Tool = "tool"
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
class Message:
@@ -53,9 +53,12 @@ class MessageBuilder:
"""
self.__content.append(text)
return self
def add_image_content(
self, image_format: str, image_base64: str
self,
image_format: str,
image_base64: str,
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
) -> "MessageBuilder":
"""
添加图片内容
@@ -63,7 +66,7 @@ class MessageBuilder:
:param image_base64: 图片的base64编码
:return: MessageBuilder对象
"""
if image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
if image_format.lower() not in support_formats:
raise ValueError("不受支持的图片格式")
if not image_base64:
raise ValueError("图片的base64编码不能为空")

View File

@@ -40,6 +40,7 @@ class RequestType(Enum):
EMBEDDING = "embedding"
AUDIO = "audio"
class LLMRequest:
"""LLM请求类"""
@@ -70,15 +71,15 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
# 模型选择
model_info, api_provider, client = self._select_model()
# 请求体构建
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(image_base64=image_base64, image_format=image_format)
message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats())
messages = [message_builder.build()]
# 模型选择
model_info, api_provider, client = self._select_model()
# 请求并处理返回值
response = await self._execute_request(
api_provider=api_provider,
@@ -127,7 +128,6 @@ class LLMRequest:
)
return response.content or None
async def generate_response_async(
self,
prompt: str,
@@ -245,7 +245,7 @@ class LLMRequest:
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
embedding_input: str = "",
audio_base64: str = ""
audio_base64: str = "",
) -> APIResponse:
"""
实际执行请求的方法