Gemini音频转录功能,以及尝试防止空回复

This commit is contained in:
UnCLAS-Prommer
2025-08-04 20:12:24 +08:00
parent 998eed4a43
commit cbe244d8f6

View File

@@ -16,6 +16,9 @@ from google.genai.types import (
GenerateContentConfig,
EmbedContentResponse,
EmbedContentConfig,
SafetySetting,
HarmCategory,
HarmBlockThreshold,
)
from google.genai.errors import (
ClientError,
@@ -41,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")

# Safety settings shared by the Gemini requests in this module: every listed
# harm category gets the BLOCK_NONE threshold. The list is passed as
# "safety_settings" in the generation config, where it is intended to avoid
# empty replies.
gemini_safe_settings = [
    SafetySetting(category=harm_category, threshold=HarmBlockThreshold.BLOCK_NONE)
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
]
def _convert_messages(
messages: list[Message],
@@ -322,7 +333,7 @@ class GeminiClient(BaseClient):
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
temperature: float = 0.4,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
@@ -369,9 +380,12 @@ class GeminiClient(BaseClient):
"thinking_config": ThinkingConfig(
include_thoughts=True,
thinking_budget=(
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None
extra_params["thinking_budget"]
if extra_params and "thinking_budget" in extra_params
else int(max_tokens / 2) # 默认思考预算为最大token数的一半防止空回复
),
),
"safety_settings": gemini_safe_settings, # 防止空回复问题
}
if tools:
generation_config_dict["tools"] = Tool(function_declarations=tools)
@@ -486,7 +500,57 @@ class GeminiClient(BaseClient):
def get_audio_transcriptions(
    self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
) -> APIResponse:
    """
    Transcribe speech audio to text with a Gemini model.

    :param model_info: Model information; ``model_identifier`` selects the model, while
        ``name`` and ``api_provider`` label the usage record
    :param audio_base64: Base64-encoded audio data; it is decoded and sent as ``audio/wav``
    :param extra_params: Optional extra parameters; only ``thinking_budget`` is read
        (1024 is used when it is absent)
    :return: The parsed transcription response, with ``usage`` set when the parser
        returns a usage record
    :raises RespNotOkException: When the SDK raises ``ClientError`` or ``ServerError``
    :raises NetworkConnectionError: When any other exception occurs while sending the
        request or parsing the response
    """
    thinking_budget = (extra_params or {}).get("thinking_budget", 1024)
    generate_content_config = GenerateContentConfig(
        max_output_tokens=2048,
        response_modalities=["TEXT"],
        thinking_config=ThinkingConfig(
            include_thoughts=True,
            thinking_budget=thinking_budget,
        ),
        safety_settings=gemini_safe_settings,
    )
    prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."

    try:
        raw_response: GenerateContentResponse = self.client.models.generate_content(
            model=model_info.model_identifier,
            contents=[
                Content(
                    role="user",
                    parts=[
                        Part.from_text(text=prompt),
                        Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
                    ],
                )
            ],
            config=generate_content_config,
        )
        resp, usage_record = _default_normal_response_parser(raw_response)
    except (ClientError, ServerError) as e:
        # Re-wrap the SDK's ClientError/ServerError as RespNotOkException,
        # deliberately dropping the original traceback context.
        raise RespNotOkException(e.code) from None
    except Exception as e:
        # NOTE(review): this also catches failures from base64 decoding and from
        # response parsing, so they surface as NetworkConnectionError too.
        # Confirm that is intended before narrowing the try block.
        raise NetworkConnectionError() from e

    if usage_record:
        # usage_record is indexed as (prompt, completion, total) token counts.
        resp.usage = UsageRecord(
            model_name=model_info.name,
            provider_name=model_info.api_provider,
            prompt_tokens=usage_record[0],
            completion_tokens=usage_record[1],
            total_tokens=usage_record[2],
        )

    return resp
def get_support_image_formats(self) -> list[str]:
"""