From 82b5230df12e84fd56505e68747b5a37972ee60a Mon Sep 17 00:00:00 2001
From: UnCLAS-Prommer
Date: Thu, 31 Jul 2025 00:49:59 +0800
Subject: [PATCH] Fix lint issues in openai_client
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/llm_models/model_client/__init__.py      |   8 +
 src/llm_models/model_client/gemini_client.py | 195 ++++---------------
 src/llm_models/model_client/openai_client.py |  89 +++++----
 3 files changed, 97 insertions(+), 195 deletions(-)

diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py
index e69de29bb..80f7e115e 100644
--- a/src/llm_models/model_client/__init__.py
+++ b/src/llm_models/model_client/__init__.py
@@ -0,0 +1,8 @@
+from src.config.config import model_config
+
+used_client_types = {provider.client_type for provider in model_config.api_providers}
+
+if "openai" in used_client_types:
+    from . import openai_client  # noqa: F401
+if "gemini" in used_client_types:
+    from . import gemini_client  # noqa: F401
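
Note: the new __init__.py imports a client module only when some configured
provider actually uses that client_type, so an unused SDK never gets loaded.
A minimal sketch of the same pattern (the Provider class and the providers
list below are hypothetical stand-ins for model_config.api_providers):

    from dataclasses import dataclass

    @dataclass
    class Provider:
        client_type: str

    providers = [Provider("openai"), Provider("openai"), Provider("gemini")]

    # Deduplicate the configured client types, then import on demand;
    # each client module registers itself as an import side effect.
    used_client_types = {p.client_type for p in providers}
    print(used_client_types)  # {'openai', 'gemini'}
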
diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py
index a2c715a21..af144dde2 100644
--- a/src/llm_models/model_client/gemini_client.py
+++ b/src/llm_models/model_client/gemini_client.py
@@ -1,7 +1,7 @@
 import asyncio
 import io
 from collections.abc import Iterable
-from typing import Callable, Iterator, TypeVar, AsyncIterator
+from typing import Callable, Iterator, TypeVar, AsyncIterator, Optional, Coroutine, Any
 
 from google import genai
 from google.genai import types
@@ -14,11 +14,9 @@ from google.genai.errors import (
     FunctionInvocationError,
 )
 
-from .base_client import APIResponse, UsageRecord
 from src.config.api_ada_configs import ModelInfo, APIProvider
-from . import BaseClient
-from src.common.logger import get_logger
+from .base_client import APIResponse, UsageRecord, BaseClient
 from ..exceptions import (
     RespParseException,
     NetworkConnectionError,
@@ -29,7 +27,6 @@ from ..payload_content.message import Message, RoleType
 from ..payload_content.resp_format import RespFormat, RespFormatType
 from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
 
-logger = get_logger("Gemini客户端")
 
 T = TypeVar("T")
 
@@ -63,11 +60,7 @@ def _convert_messages(
         content = []
         for item in message.content:
             if isinstance(item, tuple):
-                content.append(
-                    types.Part.from_bytes(
-                        data=item[1], mime_type=f"image/{item[0].lower()}"
-                    )
-                )
+                content.append(types.Part.from_bytes(data=item[1], mime_type=f"image/{item[0].lower()}"))
             elif isinstance(item, str):
                 content.append(types.Part.from_text(item))
             else:
@@ -122,20 +115,15 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
     :param tool_option: Tool option object
    :return: The converted Gemini tool option object
     """
-    ret = {
+    ret: dict[str, Any] = {
         "name": tool_option.name,
         "description": tool_option.description,
     }
     if tool_option.params:
         ret["parameters"] = {
             "type": "object",
-            "properties": {
-                param.name: _convert_tool_param(param)
-                for param in tool_option.params
-            },
-            "required": [
-                param.name for param in tool_option.params if param.required
-            ],
+            "properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
+            "required": [param.name for param in tool_option.params if param.required],
         }
     ret1 = types.FunctionDeclaration(**ret)
     return ret1
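
Note: _convert_tool_options still emits the same JSON-Schema-style dict for
types.FunctionDeclaration; only the comprehension layout changed. For
reference, a sketch of the dict a hypothetical one-parameter tool would
produce (the tool name and parameter are invented for illustration):

    ret = {
        "name": "get_weather",
        "description": "Look up the current weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},  # from _convert_tool_param
            "required": ["city"],
        },
    }
    # declaration = types.FunctionDeclaration(**ret)
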
@@ -157,12 +145,8 @@ def _process_delta(
     if delta.function_calls:  # No hasattr check needed: this attribute always exists, even when empty
         for call in delta.function_calls:
             try:
-                if not isinstance(
-                    call.args, dict
-                ):  # Gemini already returns function call args as a dict
-                    raise RespParseException(
-                        delta, "响应解析失败,工具调用参数无法解析为字典类型"
-                    )
+                if not isinstance(call.args, dict):  # Gemini already returns function call args as a dict
+                    raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
                 tool_calls_buffer.append(
                     (
                         call.id,
@@ -178,6 +162,7 @@ def _build_stream_api_resp(
     _fc_delta_buffer: io.StringIO,
     _tool_calls_buffer: list[tuple[str, str, dict]],
 ) -> APIResponse:
+    # sourcery skip: simplify-len-comparison, use-assigned-variable
     resp = APIResponse()
 
     if _fc_delta_buffer.tell() > 0:
@@ -193,8 +178,7 @@ def _build_stream_api_resp(
             if not isinstance(arguments, dict):
                 raise RespParseException(
                     None,
-                    "响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n"
-                    f"{arguments_buffer}",
+                    f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}",
                 )
         else:
             arguments = None
@@ -218,16 +202,14 @@ async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
 async def _default_stream_response_handler(
     resp_stream: Iterator[GenerateContentResponse],
     interrupt_flag: asyncio.Event | None,
-) -> tuple[APIResponse, tuple[int, int, int]]:
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """
     Stream response handler - processes the Gemini API streaming response
     :param resp_stream: Stream response object; a mysterious iterator. I have no idea whether this thing actually runs, but it is empty after one pass; if it does not run at all, consider swapping it for something else
     :return: APIResponse object
     """
     _fc_delta_buffer = io.StringIO()  # Content buffer for the received main content
-    _tool_calls_buffer: list[
-        tuple[str, str, dict]
-    ] = []  # Tool call buffer for the received tool calls
+    _tool_calls_buffer: list[tuple[str, str, dict]] = []  # Tool call buffer for the received tool calls
     _usage_record = None  # Usage record
 
     def _insure_buffer_closed():
@@ -250,8 +232,7 @@ async def _default_stream_response_handler(
                 # If usage info is present, store it in the APIResponse object
                 _usage_record = (
                     chunk.usage_metadata.prompt_token_count,
-                    chunk.usage_metadata.candidates_token_count
-                    + chunk.usage_metadata.thoughts_token_count,
+                    chunk.usage_metadata.candidates_token_count + chunk.usage_metadata.thoughts_token_count,
                     chunk.usage_metadata.total_token_count,
                 )
     try:
@@ -267,7 +248,7 @@ async def _default_stream_response_handler(
 
 def _default_normal_response_parser(
     resp: GenerateContentResponse,
-) -> tuple[APIResponse, tuple[int, int, int]]:
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """
     Parse the chat completion response - converts the Gemini API response into an APIResponse object
     :param resp: Response object
@@ -286,20 +267,15 @@ def _default_normal_response_parser(
         for call in resp.function_calls:
             try:
                 if not isinstance(call.args, dict):
-                    raise RespParseException(
-                        resp, "响应解析失败,工具调用参数无法解析为字典类型"
-                    )
+                    raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
                 api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
             except Exception as e:
-                raise RespParseException(
-                    resp, "响应解析失败,无法解析工具调用参数"
-                ) from e
+                raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
 
     if resp.usage_metadata:
         _usage_record = (
             resp.usage_metadata.prompt_token_count,
-            resp.usage_metadata.candidates_token_count
-            + resp.usage_metadata.thoughts_token_count,
+            resp.usage_metadata.candidates_token_count + resp.usage_metadata.thoughts_token_count,
             resp.usage_metadata.total_token_count,
         )
     else:
@@ -311,55 +287,13 @@ def _default_normal_response_parser(
 
 
 class GeminiClient(BaseClient):
+    client: genai.Client
+
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
-        # No longer create a fixed client at init time; create one dynamically per request
-        self._clients_cache = {}  # cache mapping API key -> genai.Client
-
-    def _get_client(self, api_key: str = None) -> genai.Client:
-        """Get or create the client for the given API key"""
-        if api_key is None:
-            api_key = self.api_provider.get_current_api_key()
-
-        if not api_key:
-            raise ValueError(f"API Provider '{self.api_provider.name}' 没有可用的API Key")
-
-        # Cache clients to avoid creating them repeatedly
-        if api_key not in self._clients_cache:
-            self._clients_cache[api_key] = genai.Client(api_key=api_key)
-
-        return self._clients_cache[api_key]
-
-    async def _execute_with_fallback(self, func, *args, **kwargs):
-        """Execute a request, switching to the next API key on failure"""
-        current_api_key = self.api_provider.get_current_api_key()
-        max_attempts = len(self.api_provider.api_keys) if self.api_provider.api_keys else 1
-
-        for attempt in range(max_attempts):
-            try:
-                client = self._get_client(current_api_key)
-                result = await func(client, *args, **kwargs)
-                # Reset the failure count on success
-                self.api_provider.reset_key_failures(current_api_key)
-                return result
-
-            except (ClientError, ServerError) as e:
-                # Record the failure and try the next API key
-                logger.warning(f"API Key失败 (尝试 {attempt + 1}/{max_attempts}): {str(e)}")
-
-                if attempt < max_attempts - 1:  # retries remaining
-                    next_api_key = self.api_provider.mark_key_failed(current_api_key)
-                    if next_api_key and next_api_key != current_api_key:
-                        current_api_key = next_api_key
-                        logger.info(f"切换到下一个API Key: {current_api_key[:8]}***{current_api_key[-4:]}")
-                        continue
-
-                # All API keys failed; re-raise the exception
-                raise RespNotOkException(e.status_code, e.message) from e
-
-            except Exception as e:
-                # Re-raise other exceptions as-is
-                raise e
+        self.client = genai.Client(
+            api_key=api_provider.api_key,
+        )  # Unlike openai, gemini decides on its own whether it needs to retry
 
     async def get_response(
         self,
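
Note: the hunk above is a behavior change, not just a lint fix: the
per-request API-key rotation (_get_client / _execute_with_fallback) is gone,
and GeminiClient now builds a single genai.Client at construction time,
leaving retries to the SDK. A usage sketch under that assumption (the
APIProvider fields shown are illustrative):

    from src.config.api_ada_configs import APIProvider  # as imported in the diff

    provider = APIProvider(name="google", api_key="AIza...", client_type="gemini")  # illustrative fields
    client = GeminiClient(provider)  # one genai.Client, built once in __init__
    # response = await client.get_response(model_info, message_list)  # inside an event loop
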
@@ -370,12 +304,15 @@ class GeminiClient(BaseClient):
         temperature: float = 0.7,
         thinking_budget: int = 0,
         response_format: RespFormat | None = None,
-        stream_response_handler: Callable[
-            [Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
-        ]
-        | None = None,
-        async_response_parser: Callable[[GenerateContentResponse], APIResponse]
-        | None = None,
+        stream_response_handler: Optional[
+            Callable[
+                [Iterator[GenerateContentResponse], asyncio.Event | None],
+                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
+            ]
+        ] = None,
+        async_response_parser: Optional[
+            Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
+        ] = None,
         interrupt_flag: asyncio.Event | None = None,
     ) -> APIResponse:
         """
@@ -392,39 +329,6 @@ class GeminiClient(BaseClient):
         :param interrupt_flag: Interrupt flag (optional, defaults to None)
         :return: (response text, reasoning text, tool calls, other data)
         """
-        return await self._execute_with_fallback(
-            self._get_response_internal,
-            model_info,
-            message_list,
-            tool_options,
-            max_tokens,
-            temperature,
-            thinking_budget,
-            response_format,
-            stream_response_handler,
-            async_response_parser,
-            interrupt_flag,
-        )
-
-    async def _get_response_internal(
-        self,
-        client: genai.Client,
-        model_info: ModelInfo,
-        message_list: list[Message],
-        tool_options: list[ToolOption] | None = None,
-        max_tokens: int = 1024,
-        temperature: float = 0.7,
-        thinking_budget: int = 0,
-        response_format: RespFormat | None = None,
-        stream_response_handler: Callable[
-            [Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
-        ]
-        | None = None,
-        async_response_parser: Callable[[GenerateContentResponse], APIResponse]
-        | None = None,
-        interrupt_flag: asyncio.Event | None = None,
-    ) -> APIResponse:
-        """Internal method: perform the actual API call"""
         if stream_response_handler is None:
             stream_response_handler = _default_stream_response_handler
 
@@ -462,7 +366,7 @@ class GeminiClient(BaseClient):
         try:
             if model_info.force_stream_mode:
                 req_task = asyncio.create_task(
-                    client.aio.models.generate_content_stream(
+                    self.client.aio.models.generate_content_stream(
                         model=model_info.model_identifier,
                         contents=messages[0],
                         config=generation_config,
@@ -474,12 +378,10 @@ class GeminiClient(BaseClient):
                         req_task.cancel()
                         raise ReqAbortException("请求被外部信号中断")
                     await asyncio.sleep(0.1)  # Wait 0.1s, then check the task & interrupt flag again
-                resp, usage_record = await stream_response_handler(
-                    req_task.result(), interrupt_flag
-                )
+                resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
             else:
                 req_task = asyncio.create_task(
-                    client.aio.models.generate_content(
+                    self.client.aio.models.generate_content(
                         model=model_info.model_identifier,
                         contents=messages[0],
                         config=generation_config,
@@ -495,13 +397,13 @@ class GeminiClient(BaseClient):
             resp, usage_record = async_response_parser(req_task.result())
         except (ClientError, ServerError) as e:
             # Re-wrap ClientError and ServerError as RespNotOkException
-            raise RespNotOkException(e.status_code, e.message) from e
+            raise RespNotOkException(e.status_code, e.message) from None
         except (
             UnknownFunctionCallArgumentError,
             UnsupportedFunctionError,
             FunctionInvocationError,
         ) as e:
-            raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from e
+            raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
         except Exception as e:
             raise NetworkConnectionError() from e
@@ -527,30 +429,15 @@ class GeminiClient(BaseClient):
         :param embedding_input: Embedding input text
         :return: Embedding response
         """
-        return await self._execute_with_fallback(
-            self._get_embedding_internal,
-            model_info,
-            embedding_input,
-        )
-
-    async def _get_embedding_internal(
-        self,
-        client: genai.Client,
-        model_info: ModelInfo,
-        embedding_input: str,
-    ) -> APIResponse:
-        """Internal method: perform the actual embedding API call"""
         try:
-            raw_response: types.EmbedContentResponse = (
-                await client.aio.models.embed_content(
-                    model=model_info.model_identifier,
-                    contents=embedding_input,
-                    config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
-                )
-            )
+            raw_response: types.EmbedContentResponse = await self.client.aio.models.embed_content(
+                model=model_info.model_identifier,
+                contents=embedding_input,
+                config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
+            )
         except (ClientError, ServerError) as e:
             # Re-wrap ClientError and ServerError as RespNotOkException
-            raise RespNotOkException(e.status_code) from e
+            raise RespNotOkException(e.status_code) from None
         except Exception as e:
             raise NetworkConnectionError() from e
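
Note: the re-raised exceptions above switch from `raise ... from e` to
`raise ... from None`, which satisfies the linter's explicit-cause rule but
deliberately suppresses the original traceback; `from e` would have kept it
attached as __cause__. A minimal illustration of the difference:

    try:
        int("not a number")
    except ValueError:
        # `from None` hides the ValueError from the traceback;
        # `from err` (with `except ValueError as err:`) would chain it.
        raise RuntimeError("parse failed") from None
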
diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index 109fe7593..8fc234297 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -3,7 +3,8 @@ import io
 import json
 import re
 from collections.abc import Iterable
-from typing import Callable, Any
+from typing import Callable, Any, Coroutine, Optional
+from json_repair import repair_json
 
 from openai import (
     AsyncOpenAI,
@@ -20,11 +21,9 @@ from openai.types.chat import (
 )
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
 
-from .base_client import APIResponse, UsageRecord
 from src.config.api_ada_configs import ModelInfo, APIProvider
-from .base_client import BaseClient, client_registry
 from src.common.logger import get_logger
-
+from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
 from ..exceptions import (
     RespParseException,
     NetworkConnectionError,
@@ -82,7 +81,7 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
             raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
         ret["tool_call_id"] = message.tool_call_id
 
-        return ret
+        return ret  # type: ignore
 
     return [_convert_message_item(message) for message in messages]
 
@@ -143,10 +142,10 @@ def _process_delta(
     # Receive content
     if has_rc_attr_flag:
         # There is a separate reasoning-content block, so the content field needs no inspection
-        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+        if hasattr(delta, "reasoning_content") and delta.reasoning_content:  # type: ignore
             # If there is reasoning content, write it to the reasoning buffer
-            assert isinstance(delta.reasoning_content, str)
-            rc_delta_buffer.write(delta.reasoning_content)
+            assert isinstance(delta.reasoning_content, str)  # type: ignore
+            rc_delta_buffer.write(delta.reasoning_content)  # type: ignore
         elif delta.content:
             # If there is main content, write it to the content buffer
             fc_delta_buffer.write(delta.content)
@@ -173,15 +172,18 @@ def _process_delta(
 
         if tool_call_delta.index >= len(tool_calls_buffer):
             # The call index is >= the buffer length, so this is a new tool call
-            tool_calls_buffer.append(
-                (
-                    tool_call_delta.id,
-                    tool_call_delta.function.name,
-                    io.StringIO(),
+            if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
+                tool_calls_buffer.append(
+                    (
+                        tool_call_delta.id,
+                        tool_call_delta.function.name,
+                        io.StringIO(),
+                    )
                 )
-            )
+            else:
+                logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。")
 
-        if tool_call_delta.function.arguments:
+        if tool_call_delta.function and tool_call_delta.function.arguments:
             # If there are tool call arguments, append them to that tool call's argument buffer
             tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
@@ -212,7 +214,7 @@ def _build_stream_api_resp(
             raw_arg_data = arguments_buffer.getvalue()
             arguments_buffer.close()
             try:
-                arguments = json.loads(raw_arg_data)
+                arguments = json.loads(repair_json(raw_arg_data))
                 if not isinstance(arguments, dict):
                     raise RespParseException(
                         None,
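
Note: streamed tool-call arguments are often invalid JSON (truncated or
trailing-comma fragments), so the patch now routes them through the
json-repair package before json.loads. A small sketch of the intended
behavior (the exact repaired string can vary with the input):

    import json
    from json_repair import repair_json

    broken = '{"city": "Tokyo", "days": 3'   # truncated arguments fragment
    arguments = json.loads(repair_json(broken))
    print(arguments)  # expected: {'city': 'Tokyo', 'days': 3}
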
@@ -235,7 +237,7 @@ def _build_stream_api_resp(
 async def _default_stream_response_handler(
     resp_stream: AsyncStream[ChatCompletionChunk],
     interrupt_flag: asyncio.Event | None,
-) -> tuple[APIResponse, tuple[int, int, int]]:
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """
     Stream response handler - processes the OpenAI API streaming response
     :param resp_stream: Stream response object
@@ -309,7 +311,7 @@ pattern = re.compile(
 
 def _default_normal_response_parser(
     resp: ChatCompletion,
-) -> tuple[APIResponse, tuple[int, int, int]]:
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """
     Parse the chat completion response - converts the OpenAI API response into an APIResponse object
     :param resp: Response object
@@ -343,7 +345,7 @@ def _default_normal_response_parser(
             api_response.tool_calls = []
             for call in message_part.tool_calls:
                 try:
-                    arguments = json.loads(call.function.arguments)
+                    arguments = json.loads(repair_json(call.function.arguments))
                     if not isinstance(arguments, dict):
                         raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
                     api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
@@ -384,26 +386,31 @@ class OpenaiClient(BaseClient):
         max_tokens: int = 1024,
         temperature: float = 0.7,
         response_format: RespFormat | None = None,
-        stream_response_handler: Callable[
-            [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
-            tuple[APIResponse, tuple[int, int, int]],
-        ]
-        | None = None,
-        async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
+        stream_response_handler: Optional[
+            Callable[
+                [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
+                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
+            ]
+        ] = None,
+        async_response_parser: Optional[
+            Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
+        ] = None,
         interrupt_flag: asyncio.Event | None = None,
     ) -> APIResponse:
         """
         Get a chat response
-        :param model_info: Model info
-        :param message_list: Conversation body
-        :param tool_options: Tool options (optional, defaults to None)
-        :param max_tokens: Maximum number of tokens (optional, defaults to 1024)
-        :param temperature: Temperature (optional, defaults to 0.7)
-        :param response_format: Response format (optional, defaults to NotGiven)
-        :param stream_response_handler: Stream response handler (optional, defaults to default_stream_response_handler)
-        :param async_response_parser: Response parser (optional, defaults to default_response_parser)
-        :param interrupt_flag: Interrupt flag (optional, defaults to None)
-        :return: (response text, reasoning text, tool calls, other data)
+        Args:
+            model_info: Model info
+            message_list: Conversation body
+            tool_options: Tool options (optional, defaults to None)
+            max_tokens: Maximum number of tokens (optional, defaults to 1024)
+            temperature: Temperature (optional, defaults to 0.7)
+            response_format: Response format (optional, defaults to NotGiven)
+            stream_response_handler: Stream response handler (optional, defaults to default_stream_response_handler)
+            async_response_parser: Response parser (optional, defaults to default_response_parser)
+            interrupt_flag: Interrupt flag (optional, defaults to None)
+        Returns:
+            (response text, reasoning text, tool calls, other data)
         """
         if stream_response_handler is None:
             stream_response_handler = _default_stream_response_handler
@@ -414,7 +421,7 @@ class OpenaiClient(BaseClient):
         # Build messages in the format the OpenAI API expects
         messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
         # Convert tool_options to the format the OpenAI API expects
-        tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
+        tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN  # type: ignore
 
         try:
             if model_info.force_stream_mode:
@@ -426,7 +433,7 @@ class OpenaiClient(BaseClient):
                         temperature=temperature,
                         max_tokens=max_tokens,
                         stream=True,
-                        response_format=response_format.to_dict() if response_format else NOT_GIVEN,
+                        response_format=NOT_GIVEN,
                     )
                 )
                 while not req_task.done():
@@ -447,7 +454,7 @@ class OpenaiClient(BaseClient):
                         temperature=temperature,
                         max_tokens=max_tokens,
                         stream=False,
-                        response_format=response_format.to_dict() if response_format else NOT_GIVEN,
+                        response_format=NOT_GIVEN,
                     )
                 )
                 while not req_task.done():
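
Note: both request paths above now hard-code response_format=NOT_GIVEN, so
the response_format parameter is accepted but silently ignored. If structured
output is needed again, a guarded pass-through like the removed one could be
restored; a sketch assuming RespFormat.to_dict(), which the deleted lines used:

    kwargs = dict(
        model=model_info.model_identifier,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if response_format is not None:
        kwargs["response_format"] = response_format.to_dict()  # assumed helper
    # resp = await self.client.chat.completions.create(messages=messages, **kwargs)
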
@@ -514,9 +521,9 @@ class OpenaiClient(BaseClient):
         response.usage = UsageRecord(
             model_name=model_info.name,
             provider_name=model_info.api_provider,
-            prompt_tokens=raw_response.usage.prompt_tokens,
-            completion_tokens=raw_response.usage.completion_tokens,
-            total_tokens=raw_response.usage.total_tokens,
+            prompt_tokens=raw_response.usage.prompt_tokens or 0,
+            completion_tokens=raw_response.usage.completion_tokens or 0,  # type: ignore
+            total_tokens=raw_response.usage.total_tokens or 0,
         )
 
         return response
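
Note: the openai SDK types these usage fields as optional integers, so the
`or 0` fallback coerces a missing value to 0 before it reaches UsageRecord.
A one-line illustration:

    prompt_tokens = None              # e.g. a provider that omits usage details
    safe_tokens = prompt_tokens or 0  # -> 0; real counts pass through unchanged
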