Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -1,12 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from collections.abc import Iterable
|
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
|
||||||
from typing import Callable, TypeVar, AsyncIterator, Optional, Coroutine, Any, List
|
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai.types import (
|
||||||
from google.genai.types import FunctionDeclaration, GenerateContentResponse
|
Content,
|
||||||
|
Part,
|
||||||
|
FunctionDeclaration,
|
||||||
|
GenerateContentResponse,
|
||||||
|
ContentListUnion,
|
||||||
|
ContentUnion,
|
||||||
|
ThinkingConfig,
|
||||||
|
Tool,
|
||||||
|
GenerateContentConfig,
|
||||||
|
EmbedContentResponse,
|
||||||
|
EmbedContentConfig,
|
||||||
|
)
|
||||||
from google.genai.errors import (
|
from google.genai.errors import (
|
||||||
ClientError,
|
ClientError,
|
||||||
ServerError,
|
ServerError,
|
||||||
@@ -28,19 +38,17 @@ from ..payload_content.message import Message, RoleType
|
|||||||
from ..payload_content.resp_format import RespFormat, RespFormatType
|
from ..payload_content.resp_format import RespFormat, RespFormatType
|
||||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(
|
def _convert_messages(
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
) -> tuple[list[types.Content], list[str] | None]:
|
) -> tuple[ContentListUnion, list[str] | None]:
|
||||||
"""
|
"""
|
||||||
转换消息格式 - 将消息转换为Gemini API所需的格式
|
转换消息格式 - 将消息转换为Gemini API所需的格式
|
||||||
:param messages: 消息列表
|
:param messages: 消息列表
|
||||||
:return: 转换后的消息列表(和可能存在的system消息)
|
:return: 转换后的消息列表(和可能存在的system消息)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _convert_message_item(message: Message) -> types.Content:
|
def _convert_message_item(message: Message) -> Content:
|
||||||
"""
|
"""
|
||||||
转换单个消息格式,除了system和tool类型的消息
|
转换单个消息格式,除了system和tool类型的消息
|
||||||
:param message: 消息对象
|
:param message: 消息对象
|
||||||
@@ -55,22 +63,22 @@ def _convert_messages(
|
|||||||
|
|
||||||
# 添加Content
|
# 添加Content
|
||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
content = [types.Part.from_text(text=message.content)]
|
content = [Part.from_text(text=message.content)]
|
||||||
elif isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
content: List[types.Part] = []
|
content: List[Part] = []
|
||||||
for item in message.content:
|
for item in message.content:
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
content.append(
|
content.append(
|
||||||
types.Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
|
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
|
||||||
)
|
)
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
content.append(types.Part.from_text(text=item))
|
content.append(Part.from_text(text=item))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
|
||||||
|
|
||||||
return types.Content(role=role, parts=content)
|
return Content(role=role, parts=content)
|
||||||
|
|
||||||
temp_list: list[types.Content] = []
|
temp_list: list[ContentUnion] = []
|
||||||
system_instructions: list[str] = []
|
system_instructions: list[str] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == RoleType.System:
|
if message.role == RoleType.System:
|
||||||
@@ -127,7 +135,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
|
|||||||
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
|
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
|
||||||
"required": [param.name for param in tool_option.params if param.required],
|
"required": [param.name for param in tool_option.params if param.required],
|
||||||
}
|
}
|
||||||
ret1 = types.FunctionDeclaration(**ret)
|
ret1 = FunctionDeclaration(**ret)
|
||||||
return ret1
|
return ret1
|
||||||
|
|
||||||
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
|
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
|
||||||
@@ -310,6 +318,7 @@ class GeminiClient(BaseClient):
|
|||||||
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
|
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
|
||||||
] = None,
|
] = None,
|
||||||
interrupt_flag: asyncio.Event | None = None,
|
interrupt_flag: asyncio.Event | None = None,
|
||||||
|
extra_params: dict[str, Any] | None = None,
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""
|
"""
|
||||||
获取对话响应
|
获取对话响应
|
||||||
@@ -343,11 +352,11 @@ class GeminiClient(BaseClient):
|
|||||||
}
|
}
|
||||||
if "2.5" in model_info.model_identifier.lower():
|
if "2.5" in model_info.model_identifier.lower():
|
||||||
# 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容
|
# 我偷个懒,在这里识别一下2.5然后开摆,反正现在只有2.5支持思维链,然后我测试之后发现它不返回思考内容,反正我也怕他有朝一日返回了,我决定干掉任何有关的思维内容
|
||||||
generation_config_dict["thinking_config"] = types.ThinkingConfig(
|
generation_config_dict["thinking_config"] = ThinkingConfig(
|
||||||
thinking_budget=thinking_budget, include_thoughts=False
|
thinking_budget=thinking_budget, include_thoughts=False
|
||||||
)
|
)
|
||||||
if tools:
|
if tools:
|
||||||
generation_config_dict["tools"] = types.Tool(tools)
|
generation_config_dict["tools"] = Tool(function_declarations=tools)
|
||||||
if messages[1]:
|
if messages[1]:
|
||||||
# 如果有system消息,则将其添加到配置中
|
# 如果有system消息,则将其添加到配置中
|
||||||
generation_config_dict["system_instructions"] = messages[1]
|
generation_config_dict["system_instructions"] = messages[1]
|
||||||
@@ -357,7 +366,7 @@ class GeminiClient(BaseClient):
|
|||||||
generation_config_dict["response_mime_type"] = "application/json"
|
generation_config_dict["response_mime_type"] = "application/json"
|
||||||
generation_config_dict["response_schema"] = response_format.to_dict()
|
generation_config_dict["response_schema"] = response_format.to_dict()
|
||||||
|
|
||||||
generation_config = types.GenerateContentConfig(**generation_config_dict)
|
generation_config = GenerateContentConfig(**generation_config_dict)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if model_info.force_stream_mode:
|
if model_info.force_stream_mode:
|
||||||
@@ -418,6 +427,7 @@ class GeminiClient(BaseClient):
|
|||||||
self,
|
self,
|
||||||
model_info: ModelInfo,
|
model_info: ModelInfo,
|
||||||
embedding_input: str,
|
embedding_input: str,
|
||||||
|
extra_params: dict[str, Any] | None = None,
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""
|
"""
|
||||||
获取文本嵌入
|
获取文本嵌入
|
||||||
@@ -426,10 +436,10 @@ class GeminiClient(BaseClient):
|
|||||||
:return: 嵌入响应
|
:return: 嵌入响应
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content(
|
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
contents=embedding_input,
|
contents=embedding_input,
|
||||||
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
|
config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
|
||||||
)
|
)
|
||||||
except (ClientError, ServerError) as e:
|
except (ClientError, ServerError) as e:
|
||||||
# 重封装ClientError和ServerError为RespNotOkException
|
# 重封装ClientError和ServerError为RespNotOkException
|
||||||
|
|||||||
Reference in New Issue
Block a user