From d15bd95bb0b29b731277cca908498ee1fe49e681 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 3 Aug 2025 11:19:41 +0800 Subject: [PATCH] fix typing --- src/llm_models/model_client/gemini_client.py | 50 ++++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 6a89cc0af..9a74d490b 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -1,12 +1,22 @@ import asyncio import io import base64 -from collections.abc import Iterable -from typing import Callable, TypeVar, AsyncIterator, Optional, Coroutine, Any, List +from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List from google import genai -from google.genai import types -from google.genai.types import FunctionDeclaration, GenerateContentResponse +from google.genai.types import ( + Content, + Part, + FunctionDeclaration, + GenerateContentResponse, + ContentListUnion, + ContentUnion, + ThinkingConfig, + Tool, + GenerateContentConfig, + EmbedContentResponse, + EmbedContentConfig, +) from google.genai.errors import ( ClientError, ServerError, @@ -28,19 +38,17 @@ from ..payload_content.message import Message, RoleType from ..payload_content.resp_format import RespFormat, RespFormatType from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall -T = TypeVar("T") - def _convert_messages( messages: list[Message], -) -> tuple[list[types.Content], list[str] | None]: +) -> tuple[ContentListUnion, list[str] | None]: """ 转换消息格式 - 将消息转换为Gemini API所需的格式 :param messages: 消息列表 :return: 转换后的消息列表(和可能存在的system消息) """ - def _convert_message_item(message: Message) -> types.Content: + def _convert_message_item(message: Message) -> Content: """ 转换单个消息格式,除了system和tool类型的消息 :param message: 消息对象 @@ -55,22 +63,22 @@ def _convert_messages( # 添加Content if isinstance(message.content, str): - content = [types.Part.from_text(text=message.content)] + content = [Part.from_text(text=message.content)] elif isinstance(message.content, list): - content: List[types.Part] = [] + content: List[Part] = [] for item in message.content: if isinstance(item, tuple): content.append( - types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}") + Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}") ) elif isinstance(item, str): - content.append(types.Part.from_text(text=item)) + content.append(Part.from_text(text=item)) else: raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") - return types.Content(role=role, parts=content) + return Content(role=role, parts=content) - temp_list: list[types.Content] = [] + temp_list: list[ContentUnion] = [] system_instructions: list[str] = [] for message in messages: if message.role == RoleType.System: @@ -127,7 +135,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, "required": [param.name for param in tool_option.params if param.required], } - ret1 = types.FunctionDeclaration(**ret) + ret1 = FunctionDeclaration(**ret) return ret1 return [_convert_tool_option_item(tool_option) for tool_option in tool_options] @@ -310,6 +318,7 @@ class GeminiClient(BaseClient): Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]] ] = None, interrupt_flag: asyncio.Event | None = None, + extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ 获取对话响应 @@ -343,11 +352,11 @@ class GeminiClient(BaseClient): } if "2.5" in model_info.model_identifier.lower(): # 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容 - generation_config_dict["thinking_config"] = types.ThinkingConfig( + generation_config_dict["thinking_config"] = ThinkingConfig( thinking_budget=thinking_budget, include_thoughts=False ) if tools: - generation_config_dict["tools"] = types.Tool(tools) + generation_config_dict["tools"] = Tool(function_declarations=tools) if messages[1]: # 如果有system消息,则将其添加到配置中 generation_config_dict["system_instructions"] = messages[1] @@ -357,7 +366,7 @@ class GeminiClient(BaseClient): generation_config_dict["response_mime_type"] = "application/json" generation_config_dict["response_schema"] = response_format.to_dict() - generation_config = types.GenerateContentConfig(**generation_config_dict) + generation_config = GenerateContentConfig(**generation_config_dict) try: if model_info.force_stream_mode: @@ -418,6 +427,7 @@ class GeminiClient(BaseClient): self, model_info: ModelInfo, embedding_input: str, + extra_params: dict[str, Any] | None = None, ) -> APIResponse: """ 获取文本嵌入 @@ -426,10 +436,10 @@ class GeminiClient(BaseClient): :return: 嵌入响应 """ try: - raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content( + raw_response: EmbedContentResponse = await self.client.aio.models.embed_content( model=model_info.model_identifier, contents=embedding_input, - config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), + config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), ) except (ClientError, ServerError) as e: # 重封装ClientError和ServerError为RespNotOkException