fix typing

This commit is contained in:
UnCLAS-Prommer
2025-08-03 11:19:41 +08:00
parent 1f53ecff10
commit d15bd95bb0

View File

@@ -1,12 +1,22 @@
import asyncio
import io
import base64
from collections.abc import Iterable
from typing import Callable, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from google import genai
from google.genai import types
from google.genai.types import FunctionDeclaration, GenerateContentResponse
from google.genai.types import (
Content,
Part,
FunctionDeclaration,
GenerateContentResponse,
ContentListUnion,
ContentUnion,
ThinkingConfig,
Tool,
GenerateContentConfig,
EmbedContentResponse,
EmbedContentConfig,
)
from google.genai.errors import (
ClientError,
ServerError,
@@ -28,19 +38,17 @@ from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
T = TypeVar("T")
def _convert_messages(
messages: list[Message],
) -> tuple[list[types.Content], list[str] | None]:
) -> tuple[ContentListUnion, list[str] | None]:
"""
转换消息格式 - 将消息转换为Gemini API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表(和可能存在的system消息)
"""
def _convert_message_item(message: Message) -> types.Content:
def _convert_message_item(message: Message) -> Content:
"""
转换单个消息格式除了system和tool类型的消息
:param message: 消息对象
@@ -55,22 +63,22 @@ def _convert_messages(
# 添加Content
if isinstance(message.content, str):
content = [types.Part.from_text(text=message.content)]
content = [Part.from_text(text=message.content)]
elif isinstance(message.content, list):
content: List[types.Part] = []
content: List[Part] = []
for item in message.content:
if isinstance(item, tuple):
content.append(
types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
)
elif isinstance(item, str):
content.append(types.Part.from_text(text=item))
content.append(Part.from_text(text=item))
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return types.Content(role=role, parts=content)
return Content(role=role, parts=content)
temp_list: list[types.Content] = []
temp_list: list[ContentUnion] = []
system_instructions: list[str] = []
for message in messages:
if message.role == RoleType.System:
@@ -127,7 +135,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
"required": [param.name for param in tool_option.params if param.required],
}
ret1 = types.FunctionDeclaration(**ret)
ret1 = FunctionDeclaration(**ret)
return ret1
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
@@ -310,6 +318,7 @@ class GeminiClient(BaseClient):
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
@@ -343,11 +352,11 @@ class GeminiClient(BaseClient):
}
if "2.5" in model_info.model_identifier.lower():
# 我偷个懒在这里识别一下2.5然后开摆反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容
generation_config_dict["thinking_config"] = types.ThinkingConfig(
generation_config_dict["thinking_config"] = ThinkingConfig(
thinking_budget=thinking_budget, include_thoughts=False
)
if tools:
generation_config_dict["tools"] = types.Tool(tools)
generation_config_dict["tools"] = Tool(function_declarations=tools)
if messages[1]:
# 如果有system消息则将其添加到配置中
generation_config_dict["system_instructions"] = messages[1]
@@ -357,7 +366,7 @@ class GeminiClient(BaseClient):
generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = response_format.to_dict()
generation_config = types.GenerateContentConfig(**generation_config_dict)
generation_config = GenerateContentConfig(**generation_config_dict)
try:
if model_info.force_stream_mode:
@@ -418,6 +427,7 @@ class GeminiClient(BaseClient):
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
@@ -426,10 +436,10 @@ class GeminiClient(BaseClient):
:return: 嵌入响应
"""
try:
raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content(
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException