From 23aec68cc010bf790ea3c256dda9b98f60a8ca38 Mon Sep 17 00:00:00 2001
From: Furina-1013-create <189647097+Furina-1013-create@users.noreply.github.com>
Date: Mon, 18 Aug 2025 16:18:21 +0800
Subject: [PATCH] Fix a potential import issue in the Gemini-API-specific gemini_client.py and add a fallback mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/llm_models/model_client/gemini_client.py | 156 +++++++++++--------
 1 file changed, 90 insertions(+), 66 deletions(-)

diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py
index 60c0c3901..9bda858ef 100644
--- a/src/llm_models/model_client/gemini_client.py
+++ b/src/llm_models/model_client/gemini_client.py
@@ -1,32 +1,54 @@
 import asyncio
 import io
 import base64
-from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
+from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union
 
-from google import genai
-from google.genai.types import (
-    Content,
-    Part,
-    FunctionDeclaration,
+import google.generativeai as genai
+from google.generativeai.types import (
     GenerateContentResponse,
-    ContentListUnion,
-    ContentUnion,
-    ThinkingConfig,
-    Tool,
-    GenerateContentConfig,
-    EmbedContentResponse,
-    EmbedContentConfig,
-    SafetySetting,
     HarmCategory,
     HarmBlockThreshold,
 )
-from google.genai.errors import (
-    ClientError,
-    ServerError,
-    UnknownFunctionCallArgumentError,
-    UnsupportedFunctionError,
-    FunctionInvocationError,
-)
+
+try:
+    # Try to import from the newer API first
+    from google.generativeai import configure
+    from google.generativeai.types import SafetySetting, GenerationConfig
+except ImportError:
+    # Fall back to basic types
+    SafetySetting = Dict
+    GenerationConfig = Dict
+
+# Define compatibility types
+ContentDict = Dict
+PartDict = Dict
+ToolDict = Dict
+FunctionDeclaration = Dict
+Tool = Dict
+ContentListUnion = List[Dict]
+ContentUnion = Dict
+Content = Dict
+Part = Dict
+ThinkingConfig = Dict
+GenerateContentConfig = Dict
+EmbedContentConfig = Dict
+EmbedContentResponse = Dict
+
+# Define exception types
+class ClientError(Exception):
+    pass
+
+class ServerError(Exception):
+    pass
+
+class UnknownFunctionCallArgumentError(Exception):
+    pass
+
+class UnsupportedFunctionError(Exception):
+    pass
+
+class FunctionInvocationError(Exception):
+    pass
 
 from src.config.api_ada_configs import ModelInfo, APIProvider
 from src.common.logger import get_logger
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
 
 logger = get_logger("Gemini客户端")
 
-gemini_safe_settings = [
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
-    SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
+SAFETY_SETTINGS = [
+    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
+    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
 ]
 
 
 def _convert_messages(
     messages: list[Message],
-) -> tuple[ContentListUnion, list[str] | None]:
+) -> tuple[List[Dict], list[str] | None]:
     """
     转换消息格式 - 将消息转换为Gemini API所需的格式
     :param messages: 消息列表
@@ -81,7 +102,7 @@ def _convert_messages(
             normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
             return f"image/{normalized_format}"
 
-    def _convert_message_item(message: Message) -> Content:
+    def _convert_message_item(message: Message) -> Dict:
         """
         转换单个消息格式,除了system和tool类型的消息
         :param message: 消息对象
@@ -96,22 +117,25 @@ def _convert_messages(
 
         # 添加Content
         if isinstance(message.content, str):
-            content = [Part.from_text(text=message.content)]
+            content = [{"text": message.content}]
         elif isinstance(message.content, list):
-            content: List[Part] = []
+            content = []
             for item in message.content:
                 if isinstance(item, tuple):
-                    content.append(
-                        Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0]))
-                    )
+                    content.append({
+                        "inline_data": {
+                            "mime_type": _get_correct_mime_type(item[0]),
+                            "data": item[1]
+                        }
+                    })
                 elif isinstance(item, str):
-                    content.append(Part.from_text(text=item))
+                    content.append({"text": item})
                 else:
                     raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
-        return Content(role=role, parts=content)
+        return {"role": role, "parts": content}
 
 
-    temp_list: list[ContentUnion] = []
+    temp_list: List[Dict] = []
     system_instructions: list[str] = []
     for message in messages:
         if message.role == RoleType.System:
@@ -338,13 +362,10 @@ def _default_normal_response_parser(
 
 @client_registry.register_client_class("gemini")
 class GeminiClient(BaseClient):
-    client: genai.Client
-
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
-        self.client = genai.Client(
-            api_key=api_provider.api_key,
-        ) # 这里和openai不一样,gemini会自己决定自己是否需要retry
+        # Configure Google Generative AI
+        genai.configure(api_key=api_provider.api_key)
 
     async def get_response(
         self,
@@ -396,18 +417,18 @@ class GeminiClient(BaseClient):
             "max_output_tokens": max_tokens,
             "temperature": temperature,
             "response_modalities": ["TEXT"],
-            "thinking_config": ThinkingConfig(
-                include_thoughts=True,
-                thinking_budget=(
+            "thinking_config": {
+                "include_thoughts": True,
+                "thinking_budget": (
                     extra_params["thinking_budget"]
                     if extra_params and "thinking_budget" in extra_params
                     else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复
                 ),
-            ),
-            "safety_settings": gemini_safe_settings, # 防止空回复问题
+            },
+            "safety_settings": SAFETY_SETTINGS, # 防止空回复问题
         }
         if tools:
-            generation_config_dict["tools"] = Tool(function_declarations=tools)
+            generation_config_dict["tools"] = {"function_declarations": tools}
         if messages[1]:
             # 如果有system消息,则将其添加到配置中
             generation_config_dict["system_instructions"] = messages[1]
@@ -417,15 +438,18 @@ class GeminiClient(BaseClient):
             generation_config_dict["response_mime_type"] = "application/json"
             generation_config_dict["response_schema"] = response_format.to_dict()
 
-        generation_config = GenerateContentConfig(**generation_config_dict)
+        generation_config = generation_config_dict
 
         try:
+            # Create the model instance
+            model = genai.GenerativeModel(model_info.model_identifier)
+
            if model_info.force_stream_mode:
                 req_task = asyncio.create_task(
-                    self.client.aio.models.generate_content_stream(
-                        model=model_info.model_identifier,
+                    model.generate_content_async(
                        contents=messages[0],
-                        config=generation_config,
+                        generation_config=generation_config,
+                        stream=True
                     )
                 )
                 while not req_task.done():
@@ -437,10 +461,9 @@ class GeminiClient(BaseClient):
                 resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
             else:
                 req_task = asyncio.create_task(
-                    self.client.aio.models.generate_content(
-                        model=model_info.model_identifier,
+                    model.generate_content_async(
                         contents=messages[0],
-                        config=generation_config,
+                        generation_config=generation_config
                     )
                 )
                 while not req_task.done():
@@ -451,17 +474,18 @@ class GeminiClient(BaseClient):
                     await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
                 resp, usage_record = async_response_parser(req_task.result())
 
-        except (ClientError, ServerError) as e:
-            # 重封装ClientError和ServerError为RespNotOkException
-            raise RespNotOkException(e.code, e.message) from None
-        except (
-            UnknownFunctionCallArgumentError,
-            UnsupportedFunctionError,
-            FunctionInvocationError,
-        ) as e:
-            raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
         except Exception as e:
-            raise NetworkConnectionError() from e
+            # Handle Google Generative AI exceptions
+            if "rate limit" in str(e).lower():
+                raise RespNotOkException(429, "请求频率过高,请稍后再试") from None
+            elif "quota" in str(e).lower():
+                raise RespNotOkException(429, "配额已用完") from None
+            elif "invalid" in str(e).lower() or "bad request" in str(e).lower():
+                raise RespNotOkException(400, f"请求无效:{str(e)}") from None
+            elif "permission" in str(e).lower() or "forbidden" in str(e).lower():
+                raise RespNotOkException(403, "权限不足") from None
+            else:
+                raise NetworkConnectionError() from e
 
         if usage_record:
             resp.usage = UsageRecord(
@@ -535,7 +559,7 @@ class GeminiClient(BaseClient):
                     extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
                 ),
             ),
-            "safety_settings": gemini_safe_settings,
+            "safety_settings": SAFETY_SETTINGS,
         }
        generate_content_config = GenerateContentConfig(**generation_config_dict)
        prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
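
Notes on the change:

The first hunk relies on a try/except ImportError fallback so the module still imports when the installed google-generativeai release does not export the typed config classes. Below is a minimal, self-contained sketch of that pattern; GenerationConfig is the only real SDK name used, and the helper name build_generation_config is illustrative, not part of the patch.

    from typing import Any, Dict

    try:
        # Available in recent google-generativeai releases.
        from google.generativeai.types import GenerationConfig
        HAS_TYPED_CONFIG = True
    except ImportError:
        # SDK missing or too old: fall back to a plain dict so the module still imports.
        GenerationConfig = Dict  # type: ignore[misc]
        HAS_TYPED_CONFIG = False

    def build_generation_config(max_tokens: int, temperature: float) -> Any:
        """Return a typed config when available, otherwise an equivalent dict."""
        params = {"max_output_tokens": max_tokens, "temperature": temperature}
        return GenerationConfig(**params) if HAS_TYPED_CONFIG else params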
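The patched get_response() path boils down to genai.configure() plus GenerativeModel.generate_content_async(). The sketch below shows that call pattern in isolation, assuming the legacy google-generativeai package is installed and a GEMINI_API_KEY environment variable is set; the model id and prompt are placeholders. Note that the SDK accepts safety_settings as its own keyword argument, so the sketch passes it outside generation_config.

    import asyncio
    import os

    import google.generativeai as genai
    from google.generativeai.types import HarmBlockThreshold, HarmCategory

    SAFETY_SETTINGS = [
        {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
    ]

    async def main() -> None:
        # Configure the client, then create a model instance and call it asynchronously.
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
        model = genai.GenerativeModel("gemini-1.5-flash")  # placeholder model id
        response = await model.generate_content_async(
            "Say hello in one short sentence.",
            generation_config={"max_output_tokens": 128, "temperature": 0.7},
            safety_settings=SAFETY_SETTINGS,
        )
        print(response.text)

    if __name__ == "__main__":
        asyncio.run(main())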
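The new except block maps SDK errors onto the project's RespNotOkException by matching keywords in the exception text. A standalone sketch of that mapping is below; RespNotOkException is redefined here only so the snippet runs on its own, and the keyword-to-status-code pairs mirror the ones in the hunk.

    class RespNotOkException(Exception):
        """Stand-in for the project's exception type (status code + message)."""

        def __init__(self, code: int, message: str) -> None:
            super().__init__(f"{code}: {message}")
            self.code = code
            self.message = message

    def map_genai_error(exc: Exception) -> Exception:
        """Translate an SDK exception into an HTTP-style error by keyword."""
        text = str(exc).lower()
        if "rate limit" in text or "quota" in text:
            return RespNotOkException(429, str(exc))
        if "invalid" in text or "bad request" in text:
            return RespNotOkException(400, str(exc))
        if "permission" in text or "forbidden" in text:
            return RespNotOkException(403, str(exc))
        return exc  # unknown: the caller re-raises this as a network error

    print(map_genai_error(RuntimeError("Rate limit exceeded")))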