Fix lint issues in openai_client

UnCLAS-Prommer
2025-07-31 00:49:59 +08:00
parent 5413c41a01
commit 82b5230df1
3 changed files with 97 additions and 195 deletions

File: client package __init__.py (new file)

@@ -0,0 +1,8 @@
from src.config.config import model_config
used_client_types = {provider.client_type for provider in model_config.api_providers}
if "openai" in used_client_types:
from . import openai_client # noqa: F401
if "gemini" in used_client_types:
from . import gemini_client # noqa: F401
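A minimal, self-contained sketch of this conditional-import pattern (all names below are illustrative stand-ins, not the project's real API): importing a client module only when some configured provider actually uses its client_type keeps unused SDKs out of startup.

from typing import Callable

client_registry: dict[str, Callable[..., object]] = {}

def register(client_type: str):
    def deco(cls):
        client_registry[client_type] = cls
        return cls
    return deco

def load_clients(used_client_types: set[str]) -> None:
    # In the real package each branch is a "from . import ..._client" whose
    # import side effects register the class; here we register inline instead.
    if "openai" in used_client_types:
        @register("openai")
        class OpenaiClient: ...
    if "gemini" in used_client_types:
        @register("gemini")
        class GeminiClient: ...

load_clients({"openai"})
print(list(client_registry))  # ['openai']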

File: gemini_client.py

@@ -1,7 +1,7 @@
import asyncio
import io
from collections.abc import Iterable
from typing import Callable, Iterator, TypeVar, AsyncIterator
from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
from google import genai
from google.genai import types
@@ -14,11 +14,9 @@ from google.genai.errors import (
FunctionInvocationError,
)
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient
from ..exceptions import (
RespParseException,
NetworkConnectionError,
@@ -29,7 +27,6 @@ from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")
T = TypeVar("T")
@@ -63,11 +60,7 @@ def _convert_messages(
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
types.Part.from_bytes(
data=item[1], mime_type=f"image/{item[0].lower()}"
)
)
content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
elif isinstance(item, str):
content.append(types.Part.from_text(item))
else:
@@ -122,20 +115,15 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
:param tool_options: list of tool option objects
:return: the converted Gemini tool option (FunctionDeclaration) objects
"""
ret = {
ret: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
"required": [param.name for param in tool_option.params if param.required],
}
ret1 = types.FunctionDeclaration(**ret)
return ret1
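For reference, a sketch (with hypothetical inputs) of the parameters dict built above: a JSON-Schema-style object whose properties come from the tool params and whose required list names the mandatory ones.

from typing import Any

tool_schema: dict[str, Any] = {
    "name": "get_weather",
    "description": "Look up the current weather",
}
# (name, JSON-Schema fragment, required?) triples standing in for ToolParam objects
params = [("city", {"type": "string"}, True), ("unit", {"type": "string"}, False)]
tool_schema["parameters"] = {
    "type": "object",
    "properties": {name: schema for name, schema, _ in params},
    "required": [name for name, _, required in params if required],
}
print(tool_schema["parameters"]["required"])  # ['city']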
@@ -157,12 +145,8 @@ def _process_delta(
if delta.function_calls:  # no hasattr check needed: this attribute always exists, even when empty
for call in delta.function_calls:
try:
if not isinstance(
call.args, dict
):  # Gemini already returns function-call args as a dict
raise RespParseException(
delta, "Response parsing failed: tool call arguments could not be parsed as a dict"
)
if not isinstance(call.args, dict):  # Gemini already returns function-call args as a dict
raise RespParseException(delta, "Response parsing failed: tool call arguments could not be parsed as a dict")
tool_calls_buffer.append(
(
call.id,
@@ -178,6 +162,7 @@ def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
# sourcery skip: simplify-len-comparison, use-assigned-variable
resp = APIResponse()
if _fc_delta_buffer.tell() > 0:
@@ -193,8 +178,7 @@ def _build_stream_api_resp(
if not isinstance(arguments, dict):
raise RespParseException(
None,
"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n"
f"{arguments_buffer}",
f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}",
)
else:
arguments = None
@@ -218,16 +202,14 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
async def _default_stream_response_handler(
resp_stream: Iterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
Stream response handler: processes the Gemini API's streaming response.
:param resp_stream: streaming response object; a one-shot iterator that is exhausted after a single traversal, so consider switching to another type if it proves unworkable
:return: APIResponse object
"""
_fc_delta_buffer = io.StringIO()  # content buffer, stores the received body content
_tool_calls_buffer: list[
tuple[str, str, dict]
] = []  # tool-call buffer, stores the received tool calls
_tool_calls_buffer: list[tuple[str, str, dict]] = []  # tool-call buffer, stores the received tool calls
_usage_record = None  # usage record
def _insure_buffer_closed():
@@ -250,8 +232,7 @@ async def _default_stream_response_handler(
# if usage info is present, store it in the APIResponse object
_usage_record = (
chunk.usage_metadata.prompt_token_count,
chunk.usage_metadata.candidates_token_count
+ chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.total_token_count,
)
try:
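The StringIO buffers above follow the usual accumulate-then-join pattern for streamed text; a minimal sketch with made-up deltas:

import io

fc_delta_buffer = io.StringIO()
for delta in ["Hel", "lo, ", "world"]:  # stand-ins for streamed content deltas
    fc_delta_buffer.write(delta)  # cheap append, unlike repeated str concatenation
text = fc_delta_buffer.getvalue()
fc_delta_buffer.close()
print(text)  # Hello, world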
@@ -267,7 +248,7 @@ async def _default_stream_response_handler(
def _default_normal_response_parser(
resp: GenerateContentResponse,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
Parse a chat-completion response: converts the Gemini API response into an APIResponse object.
:param resp: response object
@@ -286,20 +267,15 @@ def _default_normal_response_parser(
for call in resp.function_calls:
try:
if not isinstance(call.args, dict):
raise RespParseException(
resp, "Response parsing failed: tool call arguments could not be parsed as a dict"
)
raise RespParseException(resp, "Response parsing failed: tool call arguments could not be parsed as a dict")
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
except Exception as e:
raise RespParseException(
resp, "Response parsing failed: unable to parse tool-call arguments"
) from e
raise RespParseException(resp, "Response parsing failed: unable to parse tool-call arguments") from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count,
resp.usage_metadata.candidates_token_count
+ resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.total_token_count,
)
else:
@@ -311,55 +287,13 @@ def _default_normal_response_parser(
class GeminiClient(BaseClient):
client: genai.Client
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
# no longer create a fixed client at init; create one dynamically per request
self._clients_cache = {}  # cache mapping API key -> genai.Client
def _get_client(self, api_key: str = None) -> genai.Client:
"""获取或创建对应API Key的客户端"""
if api_key is None:
api_key = self.api_provider.get_current_api_key()
if not api_key:
raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key")
# use a cache to avoid recreating clients
if api_key not in self._clients_cache:
self._clients_cache[api_key] = genai.Client(api_key=api_key)
return self._clients_cache[api_key]
async def _execute_with_fallback(self, func, *args, **kwargs):
"""执行请求并在失败时切换API Key"""
current_api_key = self.api_provider.get_current_api_key()
max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1
for attempt in range(max_attempts):
try:
client = self._get_client(current_api_key)
result = await func(client, *args, **kwargs)
# reset the failure count on success
self.api_provider.reset_key_failures(current_api_key)
return result
except (ClientError, ServerError) as e:
# record the failure and try the next API key
logger.warning(f"API key failed (attempt {attempt + 1}/{max_attempts}): {str(e)}")
if attempt < max_attempts - 1:  # retries remain
next_api_key = self.api_provider.mark_key_failed(current_api_key)
if next_api_key and next_api_key != current_api_key:
current_api_key = next_api_key
logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}")
continue
# all API keys failed; re-raise the exception
raise RespNotOkException(e.status_code, e.message) from e
except Exception as e:
# re-raise other exceptions directly
raise e
self.client = genai.Client(
api_key=api_provider.api_key,
)  # unlike the OpenAI client, Gemini decides on its own whether to retry
async def get_response(
self,
@@ -370,12 +304,15 @@ class GeminiClient(BaseClient):
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
stream_response_handler: Optional[
Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
] = None,
async_response_parser: Optional[
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
@@ -392,39 +329,6 @@ class GeminiClient(BaseClient):
:param interrupt_flag: interrupt event (optional, default None)
:return: (response text, reasoning text, tool calls, other data)
"""
return await self._execute_with_fallback(
self._get_response_internal,
model_info,
message_list,
tool_options,
max_tokens,
temperature,
thinking_budget,
response_format,
stream_response_handler,
async_response_parser,
interrupt_flag,
)
async def _get_response_internal(
self,
client: genai.Client,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""内部方法执行实际的API调用"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
@@ -462,7 +366,7 @@ class GeminiClient(BaseClient):
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
client.aio.models.generate_content_stream(
self.client.aio.models.generate_content_stream(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
@@ -474,12 +378,10 @@ class GeminiClient(BaseClient):
req_task.cancel()
raise ReqAbortException("Request aborted by external signal")
await asyncio.sleep(0.1)  # wait 0.1 s before re-checking the task & interrupt flag
resp, usage_record = await stream_response_handler(
req_task.result(), interrupt_flag
)
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else:
req_task = asyncio.create_task(
client.aio.models.generate_content(
self.client.aio.models.generate_content(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
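The cancel-on-flag loop wrapped around the request task works like this self-contained sketch (ReqAbortException replaced by a plain return for brevity):

import asyncio

async def slow_request() -> str:
    await asyncio.sleep(5)
    return "response"

async def main() -> str:
    interrupt_flag = asyncio.Event()
    req_task = asyncio.create_task(slow_request())
    # simulate an external component setting the flag 0.3 s into the request
    asyncio.get_running_loop().call_later(0.3, interrupt_flag.set)
    while not req_task.done():
        if interrupt_flag.is_set():
            req_task.cancel()
            return "aborted"  # the real code raises ReqAbortException here
        await asyncio.sleep(0.1)  # poll task & flag every 100 ms
    return await req_task

print(asyncio.run(main()))  # aborted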
@@ -495,13 +397,13 @@ class GeminiClient(BaseClient):
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# re-wrap ClientError and ServerError as RespNotOkException
raise RespNotOkException(e.status_code, e.message) from e
raise RespNotOkException(e.status_code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from e
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
except Exception as e:
raise NetworkConnectionError() from e
@@ -527,30 +429,15 @@ class GeminiClient(BaseClient):
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
return await self._execute_with_fallback(
self._get_embedding_internal,
model_info,
embedding_input,
)
async def _get_embedding_internal(
self,
client: genai.Client,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""内部方法执行实际的嵌入API调用"""
try:
raw_response: types.EmbedContentResponse = (
await client.aio.models.embed_content(
raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
)
except (ClientError, ServerError) as e:
# re-wrap ClientError and ServerError as RespNotOkException
raise RespNotOkException(e.status_code) from e
raise RespNotOkException(e.status_code) from None
except Exception as e:
raise NetworkConnectionError() from e

File: openai_client.py

@@ -3,7 +3,8 @@ import io
import json
import re
from collections.abc import Iterable
from typing import Callable, Any
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
from openai import (
AsyncOpenAI,
@@ -20,11 +21,9 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from .base_client import BaseClient, client_registry
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
@@ -82,7 +81,7 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret["tool_call_id"] = message.tool_call_id
return ret
return ret # type: ignore
return [_convert_message_item(message) for message in messages]
@@ -143,10 +142,10 @@ def _process_delta(
# receive content
if has_rc_attr_flag:
# a separate reasoning-content block exists, so the content field needs no classification
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
# if there is reasoning content, write it to the reasoning-content buffer
assert isinstance(delta.reasoning_content, str)
rc_delta_buffer.write(delta.reasoning_content)
assert isinstance(delta.reasoning_content, str) # type: ignore
rc_delta_buffer.write(delta.reasoning_content) # type: ignore
elif delta.content:
# if there is body content, write it to the content buffer
fc_delta_buffer.write(delta.content)
@@ -173,6 +172,7 @@ def _process_delta(
if tool_call_delta.index >= len(tool_calls_buffer):
# an index at or beyond the buffer length means a new tool call
if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
tool_calls_buffer.append(
(
tool_call_delta.id,
@@ -180,8 +180,10 @@ def _process_delta(
io.StringIO(),
)
)
else:
logger.warning("工具调用索引号大于等于缓冲区长度但缺少ID或函数信息。")
if tool_call_delta.function.arguments:
if tool_call_delta.function and tool_call_delta.function.arguments:
# if there are tool-call arguments, append them to that tool call's argument-string buffer
tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
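A sketch of the index-keyed buffering above, with hypothetical deltas: a delta whose index is past the end of the buffer opens a new entry, and later argument fragments are appended to that entry's StringIO.

import io

tool_calls_buffer: list[tuple[str, str, io.StringIO]] = []
deltas = [  # (index, id, function name, argument fragment) stand-ins
    (0, "call_1", "get_weather", '{"city": '),
    (0, None, None, '"Tokyo"}'),
]
for index, call_id, name, fragment in deltas:
    if index >= len(tool_calls_buffer) and call_id and name:
        tool_calls_buffer.append((call_id, name, io.StringIO()))
    if fragment:
        tool_calls_buffer[index][2].write(fragment)
print(tool_calls_buffer[0][2].getvalue())  # {"city": "Tokyo"}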
@@ -212,7 +214,7 @@ def _build_stream_api_resp(
raw_arg_data = arguments_buffer.getvalue()
arguments_buffer.close()
try:
arguments = json.loads(raw_arg_data)
arguments = json.loads(repair_json(raw_arg_data))
if not isinstance(arguments, dict):
raise RespParseException(
None,
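Why repair_json was added here (assuming the json_repair package's documented behavior): streamed tool-call arguments frequently arrive truncated or slightly malformed, and repair_json patches them into parseable JSON before json.loads.

import json
from json_repair import repair_json

raw_arg_data = '{"city": "Tokyo", "unit": "celsius"'  # missing closing brace
arguments = json.loads(repair_json(raw_arg_data))
print(arguments)  # {'city': 'Tokyo', 'unit': 'celsius'}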
@@ -235,7 +237,7 @@ def _build_stream_api_resp(
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
Stream response handler: processes the OpenAI API's streaming response.
:param resp_stream: streaming response object
@@ -309,7 +311,7 @@ pattern = re.compile(
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, tuple[int, int, int]]:
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
Parse a chat-completion response: converts the OpenAI API response into an APIResponse object.
:param resp: response object
@@ -343,7 +345,7 @@ def _default_normal_response_parser(
api_response.tool_calls = []
for call in message_part.tool_calls:
try:
arguments = json.loads(call.function.arguments)
arguments = json.loads(repair_json(call.function.arguments))
if not isinstance(arguments, dict):
raise RespParseException(resp, "Response parsing failed: tool call arguments could not be parsed as a dict")
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
@@ -384,26 +386,31 @@ class OpenaiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
stream_response_handler: Optional[
Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
] = None,
async_response_parser: Optional[
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
Get a chat response.
:param model_info: model info
:param message_list: conversation body
:param tool_options: tool options (optional, default None)
:param max_tokens: maximum number of tokens (optional, default 1024)
:param temperature: temperature (optional, default 0.7)
:param response_format: response format (optional, default NotGiven)
:param stream_response_handler: stream response handler (optional, default default_stream_response_handler)
:param async_response_parser: response parser (optional, default default_response_parser)
:param interrupt_flag: interrupt event (optional, default None)
:return: (response text, reasoning text, tool calls, other data)
Args:
model_info: model info
message_list: conversation body
tool_options: tool options (optional, default None)
max_tokens: maximum number of tokens (optional, default 1024)
temperature: temperature (optional, default 0.7)
response_format: response format (optional, default NotGiven)
stream_response_handler: stream response handler (optional, default default_stream_response_handler)
async_response_parser: response parser (optional, default default_response_parser)
interrupt_flag: interrupt event (optional, default None)
Returns:
(response text, reasoning text, tool calls, other data)
"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
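The reworked annotations reflect that an async callback is, to the type checker, a callable returning a coroutine; a minimal sketch of the same shape, with hypothetical names:

import asyncio
from typing import Any, Callable, Coroutine, Optional

Handler = Callable[[int], Coroutine[Any, Any, tuple[str, Optional[int]]]]

async def default_handler(x: int) -> tuple[str, Optional[int]]:
    return str(x), (x if x >= 0 else None)

async def run(handler: Optional[Handler] = None) -> tuple[str, Optional[int]]:
    if handler is None:
        handler = default_handler  # same default substitution as above
    return await handler(42)

print(asyncio.run(run()))  # ('42', 42)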
@@ -414,7 +421,7 @@ class OpenaiClient(BaseClient):
# build messages into the format required by the OpenAI API
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
# convert tool_options into the format required by the OpenAI API
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
try:
if model_info.force_stream_mode:
@@ -426,7 +433,7 @@ class OpenaiClient(BaseClient):
temperature=temperature,
max_tokens=max_tokens,
stream=True,
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
response_format=NOT_GIVEN,
)
)
while not req_task.done():
@@ -447,7 +454,7 @@ class OpenaiClient(BaseClient):
temperature=temperature,
max_tokens=max_tokens,
stream=False,
response_format=response_format.to_dict() if response_format else NOT_GIVEN,
response_format=NOT_GIVEN,
)
)
while not req_task.done():
@@ -514,9 +521,9 @@ class OpenaiClient(BaseClient):
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=raw_response.usage.prompt_tokens,
completion_tokens=raw_response.usage.completion_tokens,
total_tokens=raw_response.usage.total_tokens,
prompt_tokens=raw_response.usage.prompt_tokens or 0,
completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
total_tokens=raw_response.usage.total_tokens or 0,
)
return response
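The 'or 0' guards exist because the SDK's usage counts are Optional ints; a tiny sketch of the failure they prevent, with UsageRecord simplified:

from dataclasses import dataclass

@dataclass
class UsageRecord:  # simplified stand-in for the project's UsageRecord
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

usage = {"prompt_tokens": 12, "completion_tokens": None, "total_tokens": 12}
record = UsageRecord(
    prompt_tokens=usage["prompt_tokens"] or 0,
    completion_tokens=usage["completion_tokens"] or 0,  # None becomes 0
    total_tokens=usage["total_tokens"] or 0,
)
print(record)  # UsageRecord(prompt_tokens=12, completion_tokens=0, total_tokens=12)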