Gemini音频转录功能,以及尝试防止空回复
This commit is contained in:
@@ -16,6 +16,9 @@ from google.genai.types import (
|
|||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
EmbedContentResponse,
|
EmbedContentResponse,
|
||||||
EmbedContentConfig,
|
EmbedContentConfig,
|
||||||
|
SafetySetting,
|
||||||
|
HarmCategory,
|
||||||
|
HarmBlockThreshold,
|
||||||
)
|
)
|
||||||
from google.genai.errors import (
|
from google.genai.errors import (
|
||||||
ClientError,
|
ClientError,
|
||||||
@@ -41,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
|||||||
|
|
||||||
logger = get_logger("Gemini客户端")
|
logger = get_logger("Gemini客户端")
|
||||||
|
|
||||||
|
gemini_safe_settings = [
|
||||||
|
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||||
|
SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||||
|
SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||||
|
SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||||
|
SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(
|
def _convert_messages(
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
@@ -322,7 +333,7 @@ class GeminiClient(BaseClient):
|
|||||||
message_list: list[Message],
|
message_list: list[Message],
|
||||||
tool_options: list[ToolOption] | None = None,
|
tool_options: list[ToolOption] | None = None,
|
||||||
max_tokens: int = 1024,
|
max_tokens: int = 1024,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.4,
|
||||||
response_format: RespFormat | None = None,
|
response_format: RespFormat | None = None,
|
||||||
stream_response_handler: Optional[
|
stream_response_handler: Optional[
|
||||||
Callable[
|
Callable[
|
||||||
@@ -369,9 +380,12 @@ class GeminiClient(BaseClient):
|
|||||||
"thinking_config": ThinkingConfig(
|
"thinking_config": ThinkingConfig(
|
||||||
include_thoughts=True,
|
include_thoughts=True,
|
||||||
thinking_budget=(
|
thinking_budget=(
|
||||||
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None
|
extra_params["thinking_budget"]
|
||||||
|
if extra_params and "thinking_budget" in extra_params
|
||||||
|
else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
"safety_settings": gemini_safe_settings, # 防止空回复问题
|
||||||
}
|
}
|
||||||
if tools:
|
if tools:
|
||||||
generation_config_dict["tools"] = Tool(function_declarations=tools)
|
generation_config_dict["tools"] = Tool(function_declarations=tools)
|
||||||
@@ -486,7 +500,57 @@ class GeminiClient(BaseClient):
|
|||||||
def get_audio_transcriptions(
|
def get_audio_transcriptions(
|
||||||
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
|
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
raise NotImplementedError("尚未实现音频转录功能")
|
"""
|
||||||
|
获取音频转录
|
||||||
|
:param model_info: 模型信息
|
||||||
|
:param audio_base64: 音频文件的Base64编码字符串
|
||||||
|
:param extra_params: 额外参数(可选)
|
||||||
|
:return: 转录响应
|
||||||
|
"""
|
||||||
|
generation_config_dict = {
|
||||||
|
"max_output_tokens": 2048,
|
||||||
|
"response_modalities": ["TEXT"],
|
||||||
|
"thinking_config": ThinkingConfig(
|
||||||
|
include_thoughts=True,
|
||||||
|
thinking_budget=(
|
||||||
|
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"safety_settings": gemini_safe_settings,
|
||||||
|
}
|
||||||
|
generate_content_config = GenerateContentConfig(**generation_config_dict)
|
||||||
|
prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
|
||||||
|
try:
|
||||||
|
raw_response: GenerateContentResponse = self.client.models.generate_content(
|
||||||
|
model=model_info.model_identifier,
|
||||||
|
contents=[
|
||||||
|
Content(
|
||||||
|
role="user",
|
||||||
|
parts=[
|
||||||
|
Part.from_text(text=prompt),
|
||||||
|
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
config=generate_content_config,
|
||||||
|
)
|
||||||
|
resp, usage_record = _default_normal_response_parser(raw_response)
|
||||||
|
except (ClientError, ServerError) as e:
|
||||||
|
# 重封装ClientError和ServerError为RespNotOkException
|
||||||
|
raise RespNotOkException(e.code) from None
|
||||||
|
except Exception as e:
|
||||||
|
raise NetworkConnectionError() from e
|
||||||
|
|
||||||
|
if usage_record:
|
||||||
|
resp.usage = UsageRecord(
|
||||||
|
model_name=model_info.name,
|
||||||
|
provider_name=model_info.api_provider,
|
||||||
|
prompt_tokens=usage_record[0],
|
||||||
|
completion_tokens=usage_record[1],
|
||||||
|
total_tokens=usage_record[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
def get_support_image_formats(self) -> list[str]:
|
def get_support_image_formats(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user