修复Gemini api专属的那个gemini_client.py里面的一个潜在的导入问题并增加回退机制
This commit is contained in:
committed by
Windpicker-owo
parent
f7b99cc546
commit
6d231c4036
@@ -1,32 +1,54 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List, Dict, Union
|
||||||
|
|
||||||
from google import genai
|
import google.generativeai as genai
|
||||||
from google.genai.types import (
|
from google.generativeai.types import (
|
||||||
Content,
|
|
||||||
Part,
|
|
||||||
FunctionDeclaration,
|
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
ContentListUnion,
|
|
||||||
ContentUnion,
|
|
||||||
ThinkingConfig,
|
|
||||||
Tool,
|
|
||||||
GenerateContentConfig,
|
|
||||||
EmbedContentResponse,
|
|
||||||
EmbedContentConfig,
|
|
||||||
SafetySetting,
|
|
||||||
HarmCategory,
|
HarmCategory,
|
||||||
HarmBlockThreshold,
|
HarmBlockThreshold,
|
||||||
)
|
)
|
||||||
from google.genai.errors import (
|
|
||||||
ClientError,
|
try:
|
||||||
ServerError,
|
# 尝试从较新的API导入
|
||||||
UnknownFunctionCallArgumentError,
|
from google.generativeai import configure
|
||||||
UnsupportedFunctionError,
|
from google.generativeai.types import SafetySetting, GenerationConfig
|
||||||
FunctionInvocationError,
|
except ImportError:
|
||||||
)
|
# 回退到基本类型
|
||||||
|
SafetySetting = Dict
|
||||||
|
GenerationConfig = Dict
|
||||||
|
|
||||||
|
# 定义兼容性类型
|
||||||
|
ContentDict = Dict
|
||||||
|
PartDict = Dict
|
||||||
|
ToolDict = Dict
|
||||||
|
FunctionDeclaration = Dict
|
||||||
|
Tool = Dict
|
||||||
|
ContentListUnion = List[Dict]
|
||||||
|
ContentUnion = Dict
|
||||||
|
Content = Dict
|
||||||
|
Part = Dict
|
||||||
|
ThinkingConfig = Dict
|
||||||
|
GenerateContentConfig = Dict
|
||||||
|
EmbedContentConfig = Dict
|
||||||
|
EmbedContentResponse = Dict
|
||||||
|
|
||||||
|
# 定义异常类型
|
||||||
|
class ClientError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ServerError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UnknownFunctionCallArgumentError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UnsupportedFunctionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FunctionInvocationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -44,18 +66,17 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
|||||||
|
|
||||||
logger = get_logger("Gemini客户端")
|
logger = get_logger("Gemini客户端")
|
||||||
|
|
||||||
gemini_safe_settings = [
|
SAFETY_SETTINGS = [
|
||||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
|
{"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
{"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
{"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
{"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
|
||||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(
|
def _convert_messages(
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
) -> tuple[ContentListUnion, list[str] | None]:
|
) -> tuple[List[Dict], list[str] | None]:
|
||||||
"""
|
"""
|
||||||
转换消息格式 - 将消息转换为Gemini API所需的格式
|
转换消息格式 - 将消息转换为Gemini API所需的格式
|
||||||
:param messages: 消息列表
|
:param messages: 消息列表
|
||||||
@@ -81,7 +102,7 @@ def _convert_messages(
|
|||||||
normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
|
normalized_format = format_mapping.get(image_format.lower(), image_format.lower())
|
||||||
return f"image/{normalized_format}"
|
return f"image/{normalized_format}"
|
||||||
|
|
||||||
def _convert_message_item(message: Message) -> Content:
|
def _convert_message_item(message: Message) -> Dict:
|
||||||
"""
|
"""
|
||||||
转换单个消息格式,除了system和tool类型的消息
|
转换单个消息格式,除了system和tool类型的消息
|
||||||
:param message: 消息对象
|
:param message: 消息对象
|
||||||
@@ -96,23 +117,25 @@ def _convert_messages(
|
|||||||
|
|
||||||
# 添加Content
|
# 添加Content
|
||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
content = [Part.from_text(text=message.content)]
|
content = [{"text": message.content}]
|
||||||
elif isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
content: List[Part] = []
|
content = []
|
||||||
for item in message.content:
|
for item in message.content:
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
|
content.append({
|
||||||
content.append(
|
"inline_data": {
|
||||||
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=_get_correct_mime_type(item[0]))
|
"mime_type": _get_correct_mime_type(item[0]),
|
||||||
)
|
"data": item[1]
|
||||||
|
}
|
||||||
|
})
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
content.append(Part.from_text(text=item))
|
content.append({"text": item})
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||||
|
|
||||||
return Content(role=role, parts=content)
|
return {"role": role, "parts": content}
|
||||||
|
|
||||||
temp_list: list[ContentUnion] = []
|
temp_list: List[Dict] = []
|
||||||
system_instructions: list[str] = []
|
system_instructions: list[str] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == RoleType.System:
|
if message.role == RoleType.System:
|
||||||
@@ -339,13 +362,10 @@ def _default_normal_response_parser(
|
|||||||
|
|
||||||
@client_registry.register_client_class("gemini")
|
@client_registry.register_client_class("gemini")
|
||||||
class GeminiClient(BaseClient):
|
class GeminiClient(BaseClient):
|
||||||
client: genai.Client
|
|
||||||
|
|
||||||
def __init__(self, api_provider: APIProvider):
|
def __init__(self, api_provider: APIProvider):
|
||||||
super().__init__(api_provider)
|
super().__init__(api_provider)
|
||||||
self.client = genai.Client(
|
# 配置 Google Generative AI
|
||||||
api_key=api_provider.api_key,
|
genai.configure(api_key=api_provider.api_key)
|
||||||
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
self,
|
self,
|
||||||
@@ -397,18 +417,18 @@ class GeminiClient(BaseClient):
|
|||||||
"max_output_tokens": max_tokens,
|
"max_output_tokens": max_tokens,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"response_modalities": ["TEXT"],
|
"response_modalities": ["TEXT"],
|
||||||
"thinking_config": ThinkingConfig(
|
"thinking_config": {
|
||||||
include_thoughts=True,
|
"include_thoughts": True,
|
||||||
thinking_budget=(
|
"thinking_budget": (
|
||||||
extra_params["thinking_budget"]
|
extra_params["thinking_budget"]
|
||||||
if extra_params and "thinking_budget" in extra_params
|
if extra_params and "thinking_budget" in extra_params
|
||||||
else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复
|
else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复
|
||||||
),
|
),
|
||||||
),
|
},
|
||||||
"safety_settings": gemini_safe_settings, # 防止空回复问题
|
"safety_settings": SAFETY_SETTINGS, # 防止空回复问题
|
||||||
}
|
}
|
||||||
if tools:
|
if tools:
|
||||||
generation_config_dict["tools"] = Tool(function_declarations=tools)
|
generation_config_dict["tools"] = {"function_declarations": tools}
|
||||||
if messages[1]:
|
if messages[1]:
|
||||||
# 如果有system消息,则将其添加到配置中
|
# 如果有system消息,则将其添加到配置中
|
||||||
generation_config_dict["system_instructions"] = messages[1]
|
generation_config_dict["system_instructions"] = messages[1]
|
||||||
@@ -418,15 +438,18 @@ class GeminiClient(BaseClient):
|
|||||||
generation_config_dict["response_mime_type"] = "application/json"
|
generation_config_dict["response_mime_type"] = "application/json"
|
||||||
generation_config_dict["response_schema"] = response_format.to_dict()
|
generation_config_dict["response_schema"] = response_format.to_dict()
|
||||||
|
|
||||||
generation_config = GenerateContentConfig(**generation_config_dict)
|
generation_config = generation_config_dict
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 创建模型实例
|
||||||
|
model = genai.GenerativeModel(model_info.model_identifier)
|
||||||
|
|
||||||
if model_info.force_stream_mode:
|
if model_info.force_stream_mode:
|
||||||
req_task = asyncio.create_task(
|
req_task = asyncio.create_task(
|
||||||
self.client.aio.models.generate_content_stream(
|
model.generate_content_async(
|
||||||
model=model_info.model_identifier,
|
|
||||||
contents=messages[0],
|
contents=messages[0],
|
||||||
config=generation_config,
|
generation_config=generation_config,
|
||||||
|
stream=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
while not req_task.done():
|
while not req_task.done():
|
||||||
@@ -438,10 +461,9 @@ class GeminiClient(BaseClient):
|
|||||||
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
|
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
|
||||||
else:
|
else:
|
||||||
req_task = asyncio.create_task(
|
req_task = asyncio.create_task(
|
||||||
self.client.aio.models.generate_content(
|
model.generate_content_async(
|
||||||
model=model_info.model_identifier,
|
|
||||||
contents=messages[0],
|
contents=messages[0],
|
||||||
config=generation_config,
|
generation_config=generation_config
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
while not req_task.done():
|
while not req_task.done():
|
||||||
@@ -452,16 +474,17 @@ class GeminiClient(BaseClient):
|
|||||||
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
|
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
|
||||||
|
|
||||||
resp, usage_record = async_response_parser(req_task.result())
|
resp, usage_record = async_response_parser(req_task.result())
|
||||||
except (ClientError, ServerError) as e:
|
|
||||||
# 重封装ClientError和ServerError为RespNotOkException
|
|
||||||
raise RespNotOkException(e.code, e.message) from None
|
|
||||||
except (
|
|
||||||
UnknownFunctionCallArgumentError,
|
|
||||||
UnsupportedFunctionError,
|
|
||||||
FunctionInvocationError,
|
|
||||||
) as e:
|
|
||||||
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# 处理Google Generative AI异常
|
||||||
|
if "rate limit" in str(e).lower():
|
||||||
|
raise RespNotOkException(429, "请求频率过高,请稍后再试") from None
|
||||||
|
elif "quota" in str(e).lower():
|
||||||
|
raise RespNotOkException(429, "配额已用完") from None
|
||||||
|
elif "invalid" in str(e).lower() or "bad request" in str(e).lower():
|
||||||
|
raise RespNotOkException(400, f"请求无效:{str(e)}") from None
|
||||||
|
elif "permission" in str(e).lower() or "forbidden" in str(e).lower():
|
||||||
|
raise RespNotOkException(403, "权限不足") from None
|
||||||
|
else:
|
||||||
raise NetworkConnectionError() from e
|
raise NetworkConnectionError() from e
|
||||||
|
|
||||||
if usage_record:
|
if usage_record:
|
||||||
@@ -536,7 +559,7 @@ class GeminiClient(BaseClient):
|
|||||||
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
"safety_settings": gemini_safe_settings,
|
"safety_settings": SAFETY_SETTINGS,
|
||||||
}
|
}
|
||||||
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
||||||
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||||
|
|||||||
Reference in New Issue
Block a user