Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -1,12 +1,22 @@
|
||||
import asyncio
|
||||
import io
|
||||
import base64
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
|
||||
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.types import FunctionDeclaration, GenerateContentResponse
|
||||
from google.genai.types import (
|
||||
Content,
|
||||
Part,
|
||||
FunctionDeclaration,
|
||||
GenerateContentResponse,
|
||||
ContentListUnion,
|
||||
ContentUnion,
|
||||
ThinkingConfig,
|
||||
Tool,
|
||||
GenerateContentConfig,
|
||||
EmbedContentResponse,
|
||||
EmbedContentConfig,
|
||||
)
|
||||
from google.genai.errors import (
|
||||
ClientError,
|
||||
ServerError,
|
||||
@@ -28,19 +38,17 @@ from ..payload_content.message import Message, RoleType
|
||||
from ..payload_content.resp_format import RespFormat, RespFormatType
|
||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _convert_messages(
|
||||
messages: list[Message],
|
||||
) -> tuple[list[types.Content], list[str] | None]:
|
||||
) -> tuple[ContentListUnion, list[str] | None]:
|
||||
"""
|
||||
转换消息格式 - 将消息转换为Gemini API所需的格式
|
||||
:param messages: 消息列表
|
||||
:return: 转换后的消息列表(和可能存在的system消息)
|
||||
"""
|
||||
|
||||
def _convert_message_item(message: Message) -> types.Content:
|
||||
def _convert_message_item(message: Message) -> Content:
|
||||
"""
|
||||
转换单个消息格式,除了system和tool类型的消息
|
||||
:param message: 消息对象
|
||||
@@ -55,22 +63,22 @@ def _convert_messages(
|
||||
|
||||
# 添加Content
|
||||
if isinstance(message.content, str):
|
||||
content = [types.Part.from_text(text=message.content)]
|
||||
content = [Part.from_text(text=message.content)]
|
||||
elif isinstance(message.content, list):
|
||||
content: List[types.Part] = []
|
||||
content: List[Part] = []
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
content.append(
|
||||
types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
|
||||
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
content.append(types.Part.from_text(text=item))
|
||||
content.append(Part.from_text(text=item))
|
||||
else:
|
||||
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||
|
||||
return types.Content(role=role, parts=content)
|
||||
return Content(role=role, parts=content)
|
||||
|
||||
temp_list: list[types.Content] = []
|
||||
temp_list: list[ContentUnion] = []
|
||||
system_instructions: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == RoleType.System:
|
||||
@@ -127,7 +135,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
|
||||
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
|
||||
"required": [param.name for param in tool_option.params if param.required],
|
||||
}
|
||||
ret1 = types.FunctionDeclaration(**ret)
|
||||
ret1 = FunctionDeclaration(**ret)
|
||||
return ret1
|
||||
|
||||
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
|
||||
@@ -310,6 +318,7 @@ class GeminiClient(BaseClient):
|
||||
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
|
||||
] = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取对话响应
|
||||
@@ -343,11 +352,11 @@ class GeminiClient(BaseClient):
|
||||
}
|
||||
if "2.5" in model_info.model_identifier.lower():
|
||||
# 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容
|
||||
generation_config_dict["thinking_config"] = types.ThinkingConfig(
|
||||
generation_config_dict["thinking_config"] = ThinkingConfig(
|
||||
thinking_budget=thinking_budget, include_thoughts=False
|
||||
)
|
||||
if tools:
|
||||
generation_config_dict["tools"] = types.Tool(tools)
|
||||
generation_config_dict["tools"] = Tool(function_declarations=tools)
|
||||
if messages[1]:
|
||||
# 如果有system消息,则将其添加到配置中
|
||||
generation_config_dict["system_instructions"] = messages[1]
|
||||
@@ -357,7 +366,7 @@ class GeminiClient(BaseClient):
|
||||
generation_config_dict["response_mime_type"] = "application/json"
|
||||
generation_config_dict["response_schema"] = response_format.to_dict()
|
||||
|
||||
generation_config = types.GenerateContentConfig(**generation_config_dict)
|
||||
generation_config = GenerateContentConfig(**generation_config_dict)
|
||||
|
||||
try:
|
||||
if model_info.force_stream_mode:
|
||||
@@ -418,6 +427,7 @@ class GeminiClient(BaseClient):
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
embedding_input: str,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取文本嵌入
|
||||
@@ -426,10 +436,10 @@ class GeminiClient(BaseClient):
|
||||
:return: 嵌入响应
|
||||
"""
|
||||
try:
|
||||
raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content(
|
||||
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
|
||||
model=model_info.model_identifier,
|
||||
contents=embedding_input,
|
||||
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
|
||||
config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
|
||||
)
|
||||
except (ClientError, ServerError) as e:
|
||||
# 重封装ClientError和ServerError为RespNotOkException
|
||||
|
||||
Reference in New Issue
Block a user