Fix a potential import issue in the Gemini-specific gemini_client.py and add a fallback mechanism
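
The fallback works by attempting the richer SDK imports first and, when they are unavailable, binding the missing names to plain dict aliases (and locally defined stand-ins for the SDK's exception classes) so the module still imports cleanly. A minimal sketch of the pattern, not the full set of names the file actually covers; GenerationConfig is one of the types the new import path provides, and the dict alias is only the degraded fallback:

    from typing import Any, Dict

    try:
        # Prefer the SDK's own config type when the installed package provides it.
        from google.generativeai.types import GenerationConfig
    except ImportError:
        # Fall back to a plain dict alias so annotations and config construction
        # keep working even without the newer SDK installed.
        GenerationConfig = Dict[str, Any]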

Furina-1013-create
2025-08-18 16:18:21 +08:00
parent 9205edf8ca
commit 23aec68cc0


@@ -1,32 +1,54 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union
from google import genai
from google.genai.types import (
Content,
Part,
FunctionDeclaration,
import google.generativeai as genai
from google.generativeai.types import (
GenerateContentResponse,
ContentListUnion,
ContentUnion,
ThinkingConfig,
Tool,
GenerateContentConfig,
EmbedContentResponse,
EmbedContentConfig,
SafetySetting,
HarmCategory,
HarmBlockThreshold,
)
from google.genai.errors import (
ClientError,
ServerError,
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
)
try:
    # 尝试从较新的API导入
    from google.generativeai import configure
    from google.generativeai.types import SafetySetting, GenerationConfig
except ImportError:
    # 回退到基本类型
    SafetySetting = Dict
    GenerationConfig = Dict
    # 定义兼容性类型
    ContentDict = Dict
    PartDict = Dict
    ToolDict = Dict
    FunctionDeclaration = Dict
    Tool = Dict
    ContentListUnion = List[Dict]
    ContentUnion = Dict
    Content = Dict
    Part = Dict
    ThinkingConfig = Dict
    GenerateContentConfig = Dict
    EmbedContentConfig = Dict
    EmbedContentResponse = Dict
    # 定义异常类型
    class ClientError(Exception):
        pass
    class ServerError(Exception):
        pass
    class UnknownFunctionCallArgumentError(Exception):
        pass
    class UnsupportedFunctionError(Exception):
        pass
    class FunctionInvocationError(Exception):
        pass
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")
gemini_safe_settings = [
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
SAFETY_SETTINGS = [
{"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
{"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
{"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
{"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
]
def _convert_messages(
messages: list[Message],
) -> tuple[ContentListUnion, list[str] | None]:
) -> tuple[List[Dict], list[str] | None]:
"""
转换消息格式 - 将消息转换为Gemini API所需的格式
:param messages: 消息列表
@@ -81,7 +102,7 @@ def _convert_messages(
normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
return f"image/{normalized_format}"
def _convert_message_item(message: Message) -> Content:
def _convert_message_item(message: Message) -> Dict:
"""
转换单个消息格式除了system和tool类型的消息
:param message: 消息对象
@@ -96,22 +117,25 @@ def _convert_messages(
# 添加Content
if isinstance(message.content, str):
content = [Part.from_text(text=message.content)]
content = [{"text": message.content}]
elif isinstance(message.content, list):
content: List[Part] = []
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0]))
)
content.append({
"inline_data": {
"mime_type": _get_correct_mime_type(item[0]),
"data": item[1]
}
})
elif isinstance(item, str):
content.append(Part.from_text(text=item))
content.append({"text": item})
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return Content(role=role, parts=content)
return {"role": role, "parts": content}
temp_list: list[ContentUnion] = []
temp_list: List[Dict] = []
system_instructions: list[str] = []
for message in messages:
if message.role == RoleType.System:
@@ -338,13 +362,10 @@ def _default_normal_response_parser(
@client_registry.register_client_class("gemini")
class GeminiClient(BaseClient):
client: genai.Client
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
self.client = genai.Client(
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
# 配置 Google Generative AI
genai.configure(api_key=api_provider.api_key)
async def get_response(
self,
@@ -396,18 +417,18 @@ class GeminiClient(BaseClient):
"max_output_tokens": max_tokens,
"temperature": temperature,
"response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig(
include_thoughts=True,
thinking_budget=(
"thinking_config": {
"include_thoughts": True,
"thinking_budget": (
extra_params["thinking_budget"]
if extra_params and "thinking_budget" in extra_params
else int(max_tokens / 2) # 默认思考预算为最大token数的一半防止空回复
),
),
"safety_settings": gemini_safe_settings, # 防止空回复问题
},
"safety_settings": SAFETY_SETTINGS, # 防止空回复问题
}
if tools:
generation_config_dict["tools"] = Tool(function_declarations=tools)
generation_config_dict["tools"] = {"function_declarations": tools}
if messages[1]:
# 如果有system消息则将其添加到配置中
generation_config_dict["system_instructions"] = messages[1]
@@ -417,15 +438,18 @@ class GeminiClient(BaseClient):
generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = response_format.to_dict()
generation_config = GenerateContentConfig(**generation_config_dict)
generation_config = generation_config_dict
try:
# 创建模型实例
model = genai.GenerativeModel(model_info.model_identifier)
if model_info.force_stream_mode:
req_task = asyncio.create_task(
self.client.aio.models.generate_content_stream(
model=model_info.model_identifier,
model.generate_content_async(
contents=messages[0],
config=generation_config,
generation_config=generation_config,
stream=True
)
)
while not req_task.done():
@@ -437,10 +461,9 @@ class GeminiClient(BaseClient):
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else:
req_task = asyncio.create_task(
self.client.aio.models.generate_content(
model=model_info.model_identifier,
model.generate_content_async(
contents=messages[0],
config=generation_config,
generation_config=generation_config
)
)
while not req_task.done():
@@ -451,17 +474,18 @@ class GeminiClient(BaseClient):
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
except Exception as e:
raise NetworkConnectionError() from e
# 处理Google Generative AI异常
if "rate limit" in str(e).lower():
raise RespNotOkException(429, "请求频率过高,请稍后再试") from None
elif "quota" in str(e).lower():
raise RespNotOkException(429, "配额已用完") from None
elif "invalid" in str(e).lower() or "bad request" in str(e).lower():
raise RespNotOkException(400, f"请求无效:{str(e)}") from None
elif "permission" in str(e).lower() or "forbidden" in str(e).lower():
raise RespNotOkException(403, "权限不足") from None
else:
raise NetworkConnectionError() from e
if usage_record:
resp.usage = UsageRecord(
@@ -535,7 +559,7 @@ class GeminiClient(BaseClient):
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
),
),
"safety_settings": gemini_safe_settings,
"safety_settings": SAFETY_SETTINGS,
}
generate_content_config = GenerateContentConfig(**generation_config_dict)
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."