From 9afa549aeebd4a0a143550eeedbd154d19abf077 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 3 Aug 2025 00:49:19 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=A9Gemini=E7=9A=84=E5=9B=BE=E5=83=8F?= =?UTF-8?q?=E5=8F=AF=E7=94=A8=EF=BC=8C=E4=BF=AE=E5=A4=8D=E9=83=A8=E5=88=86?= =?UTF-8?q?typing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/model_client/base_client.py | 38 +++++++----- src/llm_models/model_client/gemini_client.py | 65 ++++++++++++-------- src/llm_models/model_client/openai_client.py | 19 ++++-- src/llm_models/payload_content/message.py | 11 ++-- src/llm_models/utils_model.py | 12 ++-- 5 files changed, 88 insertions(+), 57 deletions(-) diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 3d56e4197..8e8affba6 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -1,9 +1,7 @@ import asyncio from dataclasses import dataclass -from typing import Callable, Any - -from openai import AsyncStream -from openai.types.chat import ChatCompletionChunk, ChatCompletion +from abc import ABC, abstractmethod +from typing import Callable, Any, Optional from src.config.api_ada_configs import ModelInfo, APIProvider from ..payload_content.message import Message @@ -58,7 +56,7 @@ class APIResponse: """响应原始数据""" -class BaseClient: +class BaseClient(ABC): """ 基础客户端 """ @@ -68,6 +66,7 @@ class BaseClient: def __init__(self, api_provider: APIProvider): self.api_provider = api_provider + @abstractmethod async def get_response( self, model_info: ModelInfo, @@ -76,12 +75,10 @@ class BaseClient: max_tokens: int = 1024, temperature: float = 0.7, response_format: RespFormat | None = None, - stream_response_handler: Callable[ - [AsyncStream[ChatCompletionChunk], asyncio.Event | None], - tuple[APIResponse, tuple[int, int, int]], - ] - | None = None, - async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None, + stream_response_handler: Optional[ + Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] + ] = None, + async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, interrupt_flag: asyncio.Event | None = None, extra_params: dict[str, Any] | None = None, ) -> APIResponse: @@ -98,8 +95,9 @@ class BaseClient: :param interrupt_flag: 中断信号量(可选,默认为None) :return: (响应文本, 推理文本, 工具调用, 其他数据) """ - raise RuntimeError("This method should be overridden in subclasses") + raise NotImplementedError("'get_response' method should be overridden in subclasses") + @abstractmethod async def get_embedding( self, model_info: ModelInfo, @@ -112,8 +110,9 @@ class BaseClient: :param embedding_input: 嵌入输入文本 :return: 嵌入响应 """ - raise RuntimeError("This method should be overridden in subclasses") + raise NotImplementedError("'get_embedding' method should be overridden in subclasses") + @abstractmethod async def get_audio_transcriptions( self, model_info: ModelInfo, @@ -127,7 +126,15 @@ class BaseClient: :extra_params: 附加的请求参数 :return: 音频转录响应 """ - raise RuntimeError("This method should be overridden in subclasses") + raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses") + + @abstractmethod + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses") class ClientRegistry: @@ -137,7 +144,8 @@ class ClientRegistry: def register_client_class(self, client_type: str): """ 注册API客户端类 - :param client_class: API客户端类 + Args: + client_class: API客户端类 """ def decorator(cls: type[BaseClient]) -> type[BaseClient]: diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index e04a327df..f30f464a9 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -1,8 +1,8 @@ -raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider") import asyncio import io +import base64 from collections.abc import Iterable -from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any +from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any, List from google import genai from google.genai import types @@ -17,7 +17,7 @@ from google.genai.errors import ( from src.config.api_ada_configs import ModelInfo, APIProvider -from .base_client import APIResponse, UsageRecord, BaseClient +from .base_client import APIResponse, UsageRecord, BaseClient, client_registry from ..exceptions import ( RespParseException, NetworkConnectionError, @@ -54,20 +54,21 @@ def _convert_messages( role = "user" # 添加Content - content: types.Part | list if isinstance(message.content, str): - content = types.Part.from_text(message.content) + content = [types.Part.from_text(text=message.content)] elif isinstance(message.content, list): - content = [] + content: List[types.Part] = [] for item in message.content: if isinstance(item, tuple): - content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}")) + content.append( + types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}") + ) elif isinstance(item, str): - content.append(types.Part.from_text(item)) + content.append(types.Part.from_text(text=item)) else: raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") - return types.Content(role=role, content=content) + return types.Content(role=role, parts=content) temp_list: list[types.Content] = [] system_instructions: list[str] = [] @@ -76,7 +77,7 @@ def _convert_messages( if isinstance(message.content, str): system_instructions.append(message.content) else: - raise RuntimeError("你tm怎么往system里面塞图片base64?") + raise ValueError("你tm怎么往system里面塞图片base64?") elif message.role == RoleType.Tool: if not message.tool_call_id: raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") @@ -135,9 +136,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar def _process_delta( delta: GenerateContentResponse, fc_delta_buffer: io.StringIO, - tool_calls_buffer: list[tuple[str, str, dict]], + tool_calls_buffer: list[tuple[str, str, dict[str, Any]]], ): - if not hasattr(delta, "candidates") or len(delta.candidates) == 0: + if not hasattr(delta, "candidates") or not delta.candidates: raise RespParseException(delta, "响应解析失败,缺失candidates字段") if delta.text: @@ -148,11 +149,13 @@ def _process_delta( try: if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了 raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型") + if not call.id or not call.name: + raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段") tool_calls_buffer.append( ( call.id, call.name, - call.args, + call.args or {}, # 如果args是None,则转换为一个空字典 ) ) except Exception as e: @@ -201,7 +204,7 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]: async def _default_stream_response_handler( - resp_stream: Iterator[GenerateContentResponse], + resp_stream: AsyncIterator[GenerateContentResponse], interrupt_flag: asyncio.Event | None, ) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: """ @@ -232,9 +235,9 @@ async def _default_stream_response_handler( if chunk.usage_metadata: # 如果有使用情况,则将其存储在APIResponse对象中 _usage_record = ( - chunk.usage_metadata.prompt_token_count, - chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count, - chunk.usage_metadata.total_token_count, + chunk.usage_metadata.prompt_token_count or 0, + (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0), + chunk.usage_metadata.total_token_count or 0, ) try: return _build_stream_api_resp( @@ -257,7 +260,7 @@ def _default_normal_response_parser( """ api_response = APIResponse() - if not hasattr(resp, "candidates") or len(resp.candidates) == 0: + if not hasattr(resp, "candidates") or not resp.candidates: raise RespParseException(resp, "响应解析失败,缺失candidates字段") if resp.text: @@ -269,15 +272,17 @@ def _default_normal_response_parser( try: if not isinstance(call.args, dict): raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型") - api_response.tool_calls.append(ToolCall(call.id, call.name, call.args)) + if not call.id or not call.name: + raise RespParseException(resp, "响应解析失败,工具调用缺失id或name字段") + api_response.tool_calls.append(ToolCall(call.id, call.name, call.args or {})) except Exception as e: raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e if resp.usage_metadata: _usage_record = ( - resp.usage_metadata.prompt_token_count, - resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count, - resp.usage_metadata.total_token_count, + resp.usage_metadata.prompt_token_count or 0, + (resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0), + resp.usage_metadata.total_token_count or 0, ) else: _usage_record = None @@ -287,6 +292,7 @@ def _default_normal_response_parser( return api_response, _usage_record +@client_registry.register_client_class("gemini") class GeminiClient(BaseClient): client: genai.Client @@ -307,7 +313,7 @@ class GeminiClient(BaseClient): response_format: RespFormat | None = None, stream_response_handler: Optional[ Callable[ - [Iterator[GenerateContentResponse], asyncio.Event | None], + [AsyncIterator[GenerateContentResponse], asyncio.Event | None], Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], ] ] = None, @@ -398,7 +404,7 @@ class GeminiClient(BaseClient): resp, usage_record = async_response_parser(req_task.result()) except (ClientError, ServerError) as e: # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.status_code, e.message) from None + raise RespNotOkException(e.code, e.message) from None except ( UnknownFunctionCallArgumentError, UnsupportedFunctionError, @@ -438,14 +444,14 @@ class GeminiClient(BaseClient): ) except (ClientError, ServerError) as e: # 重封装ClientError和ServerError为RespNotOkException - raise RespNotOkException(e.status_code) from None + raise RespNotOkException(e.code) from None except Exception as e: raise NetworkConnectionError() from e response = APIResponse() # 解析嵌入响应和使用情况 - if hasattr(raw_response, "embeddings"): + if hasattr(raw_response, "embeddings") and raw_response.embeddings: response.embedding = raw_response.embeddings[0].values else: raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段") @@ -459,3 +465,10 @@ class GeminiClient(BaseClient): ) return response + + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + return ["png", "jpg", "jpeg", "webp", "heic", "heif"] diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 6fe3582de..7f097e2c0 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -286,9 +286,9 @@ async def _default_stream_response_handler( if event.usage: # 如果有使用情况,则将其存储在APIResponse对象中 _usage_record = ( - event.usage.prompt_tokens, - event.usage.completion_tokens, - event.usage.total_tokens, + event.usage.prompt_tokens or 0, + event.usage.completion_tokens or 0, + event.usage.total_tokens or 0, ) try: @@ -356,9 +356,9 @@ def _default_normal_response_parser( # 提取Usage信息 if resp.usage: _usage_record = ( - resp.usage.prompt_tokens, - resp.usage.completion_tokens, - resp.usage.total_tokens, + resp.usage.prompt_tokens or 0, + resp.usage.completion_tokens or 0, + resp.usage.total_tokens or 0, ) else: _usage_record = None @@ -568,3 +568,10 @@ class OpenaiClient(BaseClient): "响应解析失败,缺失转录文本。", ) return response + + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + return ["jpg", "jpeg", "png", "webp", "gif"] diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py index e07f473b8..f70c3ded5 100644 --- a/src/llm_models/payload_content/message.py +++ b/src/llm_models/payload_content/message.py @@ -11,7 +11,7 @@ class RoleType(Enum): Tool = "tool" -SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] +SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式 class Message: @@ -53,9 +53,12 @@ class MessageBuilder: """ self.__content.append(text) return self - + def add_image_content( - self, image_format: str, image_base64: str + self, + image_format: str, + image_base64: str, + support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式 ) -> "MessageBuilder": """ 添加图片内容 @@ -63,7 +66,7 @@ class MessageBuilder: :param image_base64: 图片的base64编码 :return: MessageBuilder对象 """ - if image_format.lower() not in SUPPORTED_IMAGE_FORMATS: + if image_format.lower() not in support_formats: raise ValueError("不受支持的图片格式") if not image_base64: raise ValueError("图片的base64编码不能为空") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 329e8f0ba..ab1605dca 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -40,6 +40,7 @@ class RequestType(Enum): EMBEDDING = "embedding" AUDIO = "audio" + class LLMRequest: """LLM请求类""" @@ -70,15 +71,15 @@ class LLMRequest: Returns: (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ + # 模型选择 + model_info, api_provider, client = self._select_model() + # 请求体构建 message_builder = MessageBuilder() message_builder.add_text_content(prompt) - message_builder.add_image_content(image_base64=image_base64, image_format=image_format) + message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()) messages = [message_builder.build()] - # 模型选择 - model_info, api_provider, client = self._select_model() - # 请求并处理返回值 response = await self._execute_request( api_provider=api_provider, @@ -127,7 +128,6 @@ class LLMRequest: ) return response.content or None - async def generate_response_async( self, prompt: str, @@ -245,7 +245,7 @@ class LLMRequest: temperature: Optional[float] = None, max_tokens: Optional[int] = None, embedding_input: str = "", - audio_base64: str = "" + audio_base64: str = "", ) -> APIResponse: """ 实际执行请求的方法