diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index e41270298..a74b466f1 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -16,6 +16,9 @@ from google.genai.types import ( GenerateContentConfig, EmbedContentResponse, EmbedContentConfig, + SafetySetting, + HarmCategory, + HarmBlockThreshold, ) from google.genai.errors import ( ClientError, @@ -41,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("Gemini客户端") +gemini_safe_settings = [ + SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE), +] + def _convert_messages( messages: list[Message], @@ -322,7 +333,7 @@ class GeminiClient(BaseClient): message_list: list[Message], tool_options: list[ToolOption] | None = None, max_tokens: int = 1024, - temperature: float = 0.7, + temperature: float = 0.4, response_format: RespFormat | None = None, stream_response_handler: Optional[ Callable[ @@ -369,9 +380,12 @@ class GeminiClient(BaseClient): "thinking_config": ThinkingConfig( include_thoughts=True, thinking_budget=( - extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None + extra_params["thinking_budget"] + if extra_params and "thinking_budget" in extra_params + else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复 ), ), + "safety_settings": gemini_safe_settings, # 防止空回复问题 } if tools: generation_config_dict["tools"] = Tool(function_declarations=tools) @@ -486,7 +500,57 @@ class GeminiClient(BaseClient): def get_audio_transcriptions( self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None ) -> APIResponse: - raise NotImplementedError("尚未实现音频转录功能") + """ + 获取音频转录 + :param model_info: 模型信息 + :param audio_base64: 音频文件的Base64编码字符串 + :param extra_params: 额外参数(可选) + :return: 转录响应 + """ + generation_config_dict = { + "max_output_tokens": 2048, + "response_modalities": ["TEXT"], + "thinking_config": ThinkingConfig( + include_thoughts=True, + thinking_budget=( + extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024 + ), + ), + "safety_settings": gemini_safe_settings, + } + generate_content_config = GenerateContentConfig(**generation_config_dict) + prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." + try: + raw_response: GenerateContentResponse = self.client.models.generate_content( + model=model_info.model_identifier, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text=prompt), + Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), + ], + ) + ], + config=generate_content_config, + ) + resp, usage_record = _default_normal_response_parser(raw_response) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.code) from None + except Exception as e: + raise NetworkConnectionError() from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + return resp def get_support_image_formats(self) -> list[str]: """