修复Gemini api专属的那个gemini_client.py里面的一个潜在的导入问题并增加回退机制
This commit is contained in:
@@ -1,32 +1,54 @@
|
||||
import asyncio
|
||||
import io
|
||||
import base64
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union
|
||||
|
||||
from google import genai
|
||||
from google.genai.types import (
|
||||
Content,
|
||||
Part,
|
||||
FunctionDeclaration,
|
||||
import google.generativeai as genai
|
||||
from google.generativeai.types import (
|
||||
GenerateContentResponse,
|
||||
ContentListUnion,
|
||||
ContentUnion,
|
||||
ThinkingConfig,
|
||||
Tool,
|
||||
GenerateContentConfig,
|
||||
EmbedContentResponse,
|
||||
EmbedContentConfig,
|
||||
SafetySetting,
|
||||
HarmCategory,
|
||||
HarmBlockThreshold,
|
||||
)
|
||||
from google.genai.errors import (
|
||||
ClientError,
|
||||
ServerError,
|
||||
UnknownFunctionCallArgumentError,
|
||||
UnsupportedFunctionError,
|
||||
FunctionInvocationError,
|
||||
)
|
||||
|
||||
try:
|
||||
# 尝试从较新的API导入
|
||||
from google.generativeai import configure
|
||||
from google.generativeai.types import SafetySetting, GenerationConfig
|
||||
except ImportError:
|
||||
# 回退到基本类型
|
||||
SafetySetting = Dict
|
||||
GenerationConfig = Dict
|
||||
|
||||
# 定义兼容性类型
|
||||
ContentDict = Dict
|
||||
PartDict = Dict
|
||||
ToolDict = Dict
|
||||
FunctionDeclaration = Dict
|
||||
Tool = Dict
|
||||
ContentListUnion = List[Dict]
|
||||
ContentUnion = Dict
|
||||
Content = Dict
|
||||
Part = Dict
|
||||
ThinkingConfig = Dict
|
||||
GenerateContentConfig = Dict
|
||||
EmbedContentConfig = Dict
|
||||
EmbedContentResponse = Dict
|
||||
|
||||
# 定义异常类型
|
||||
class ClientError(Exception):
|
||||
pass
|
||||
|
||||
class ServerError(Exception):
|
||||
pass
|
||||
|
||||
class UnknownFunctionCallArgumentError(Exception):
|
||||
pass
|
||||
|
||||
class UnsupportedFunctionError(Exception):
|
||||
pass
|
||||
|
||||
class FunctionInvocationError(Exception):
|
||||
pass
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from src.common.logger import get_logger
|
||||
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||
|
||||
logger = get_logger("Gemini客户端")
|
||||
|
||||
gemini_safe_settings = [
|
||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||
SAFETY_SETTINGS = [
|
||||
{"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||
{"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||
{"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||
{"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||
]
|
||||
|
||||
|
||||
def _convert_messages(
|
||||
messages: list[Message],
|
||||
) -> tuple[ContentListUnion, list[str] | None]:
|
||||
) -> tuple[List[Dict], list[str] | None]:
|
||||
"""
|
||||
转换消息格式 - 将消息转换为Gemini API所需的格式
|
||||
:param messages: 消息列表
|
||||
@@ -81,7 +102,7 @@ def _convert_messages(
|
||||
normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
|
||||
return f"image/{normalized_format}"
|
||||
|
||||
def _convert_message_item(message: Message) -> Content:
|
||||
def _convert_message_item(message: Message) -> Dict:
|
||||
"""
|
||||
转换单个消息格式,除了system和tool类型的消息
|
||||
:param message: 消息对象
|
||||
@@ -96,22 +117,25 @@ def _convert_messages(
|
||||
|
||||
# 添加Content
|
||||
if isinstance(message.content, str):
|
||||
content = [Part.from_text(text=message.content)]
|
||||
content = [{"text": message.content}]
|
||||
elif isinstance(message.content, list):
|
||||
content: List[Part] = []
|
||||
content = []
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
content.append(
|
||||
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0]))
|
||||
)
|
||||
content.append({
|
||||
"inline_data": {
|
||||
"mime_type": _get_correct_mime_type(item[0]),
|
||||
"data": item[1]
|
||||
}
|
||||
})
|
||||
elif isinstance(item, str):
|
||||
content.append(Part.from_text(text=item))
|
||||
content.append({"text": item})
|
||||
else:
|
||||
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||
|
||||
return Content(role=role, parts=content)
|
||||
return {"role": role, "parts": content}
|
||||
|
||||
temp_list: list[ContentUnion] = []
|
||||
temp_list: List[Dict] = []
|
||||
system_instructions: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == RoleType.System:
|
||||
@@ -338,13 +362,10 @@ def _default_normal_response_parser(
|
||||
|
||||
@client_registry.register_client_class("gemini")
|
||||
class GeminiClient(BaseClient):
|
||||
client: genai.Client
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
self.client = genai.Client(
|
||||
api_key=api_provider.api_key,
|
||||
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
||||
# 配置 Google Generative AI
|
||||
genai.configure(api_key=api_provider.api_key)
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
@@ -396,18 +417,18 @@ class GeminiClient(BaseClient):
|
||||
"max_output_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"response_modalities": ["TEXT"],
|
||||
"thinking_config": ThinkingConfig(
|
||||
include_thoughts=True,
|
||||
thinking_budget=(
|
||||
"thinking_config": {
|
||||
"include_thoughts": True,
|
||||
"thinking_budget": (
|
||||
extra_params["thinking_budget"]
|
||||
if extra_params and "thinking_budget" in extra_params
|
||||
else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复
|
||||
),
|
||||
),
|
||||
"safety_settings": gemini_safe_settings, # 防止空回复问题
|
||||
},
|
||||
"safety_settings": SAFETY_SETTINGS, # 防止空回复问题
|
||||
}
|
||||
if tools:
|
||||
generation_config_dict["tools"] = Tool(function_declarations=tools)
|
||||
generation_config_dict["tools"] = {"function_declarations": tools}
|
||||
if messages[1]:
|
||||
# 如果有system消息,则将其添加到配置中
|
||||
generation_config_dict["system_instructions"] = messages[1]
|
||||
@@ -417,15 +438,18 @@ class GeminiClient(BaseClient):
|
||||
generation_config_dict["response_mime_type"] = "application/json"
|
||||
generation_config_dict["response_schema"] = response_format.to_dict()
|
||||
|
||||
generation_config = GenerateContentConfig(**generation_config_dict)
|
||||
generation_config = generation_config_dict
|
||||
|
||||
try:
|
||||
# 创建模型实例
|
||||
model = genai.GenerativeModel(model_info.model_identifier)
|
||||
|
||||
if model_info.force_stream_mode:
|
||||
req_task = asyncio.create_task(
|
||||
self.client.aio.models.generate_content_stream(
|
||||
model=model_info.model_identifier,
|
||||
model.generate_content_async(
|
||||
contents=messages[0],
|
||||
config=generation_config,
|
||||
generation_config=generation_config,
|
||||
stream=True
|
||||
)
|
||||
)
|
||||
while not req_task.done():
|
||||
@@ -437,10 +461,9 @@ class GeminiClient(BaseClient):
|
||||
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
|
||||
else:
|
||||
req_task = asyncio.create_task(
|
||||
self.client.aio.models.generate_content(
|
||||
model=model_info.model_identifier,
|
||||
model.generate_content_async(
|
||||
contents=messages[0],
|
||||
config=generation_config,
|
||||
generation_config=generation_config
|
||||
)
|
||||
)
|
||||
while not req_task.done():
|
||||
@@ -451,17 +474,18 @@ class GeminiClient(BaseClient):
|
||||
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
|
||||
|
||||
resp, usage_record = async_response_parser(req_task.result())
|
||||
except (ClientError, ServerError) as e:
|
||||
# 重封装ClientError和ServerError为RespNotOkException
|
||||
raise RespNotOkException(e.code, e.message) from None
|
||||
except (
|
||||
UnknownFunctionCallArgumentError,
|
||||
UnsupportedFunctionError,
|
||||
FunctionInvocationError,
|
||||
) as e:
|
||||
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
|
||||
except Exception as e:
|
||||
raise NetworkConnectionError() from e
|
||||
# 处理Google Generative AI异常
|
||||
if "rate limit" in str(e).lower():
|
||||
raise RespNotOkException(429, "请求频率过高,请稍后再试") from None
|
||||
elif "quota" in str(e).lower():
|
||||
raise RespNotOkException(429, "配额已用完") from None
|
||||
elif "invalid" in str(e).lower() or "bad request" in str(e).lower():
|
||||
raise RespNotOkException(400, f"请求无效:{str(e)}") from None
|
||||
elif "permission" in str(e).lower() or "forbidden" in str(e).lower():
|
||||
raise RespNotOkException(403, "权限不足") from None
|
||||
else:
|
||||
raise NetworkConnectionError() from e
|
||||
|
||||
if usage_record:
|
||||
resp.usage = UsageRecord(
|
||||
@@ -535,7 +559,7 @@ class GeminiClient(BaseClient):
|
||||
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
||||
),
|
||||
),
|
||||
"safety_settings": gemini_safe_settings,
|
||||
"safety_settings": SAFETY_SETTINGS,
|
||||
}
|
||||
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||
|
||||
Reference in New Issue
Block a user