diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py index 7b997b680..4085903ac 100644 --- a/src/llm_models/model_client/aiohttp_gemini_client.py +++ b/src/llm_models/model_client/aiohttp_gemini_client.py @@ -120,9 +120,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: 转换工具选项格式 - 将工具选项转换为Gemini REST API所需的格式 """ - def _convert_tool_param(param: ToolParam) -> dict: + def _convert_tool_param(param: ToolParam) -> dict[str, Any]: """转换工具参数""" - result = { + result: dict[str, Any] = { "type": param.param_type.value, "description": param.description, } @@ -130,9 +130,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]: result["enum"] = param.enum_values return result - def _convert_tool_option_item(tool_option: ToolOption) -> dict: + def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: """转换单个工具选项""" - function_declaration = { + function_declaration: dict[str, Any] = { "name": tool_option.name, "description": tool_option.description, } @@ -341,7 +341,6 @@ class AiohttpGeminiClient(BaseClient): super().__init__(api_provider) self.base_url = "https://generativelanguage.googleapis.com/v1beta" self.session: aiohttp.ClientSession | None = None - self.api_key = api_provider.api_key # 如果提供了自定义base_url,使用它 if api_provider.base_url: @@ -388,11 +387,11 @@ class AiohttpGeminiClient(BaseClient): self, method: str, endpoint: str, data: dict | None = None, stream: bool = False ) -> aiohttp.ClientResponse: """发起HTTP请求(每次都用 with aiohttp.ClientSession() as session)""" - url = f"{self.base_url}/{endpoint}?key={self.api_key}" - timeout = aiohttp.ClientTimeout(total=300) + api_key = self.api_provider.get_api_key() + url = f"{self.base_url}/{endpoint}?key={api_key}" try: async with aiohttp.ClientSession( - timeout=timeout, + timeout=aiohttp.ClientTimeout(total=300), headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"}, ) as session: if method.upper() == "POST": @@ -500,7 +499,7 @@ class AiohttpGeminiClient(BaseClient): # 直接重抛项目定义的异常 raise except Exception as e: - logger.debug(e) + logger.debug(str(e)) # 其他异常转换为网络连接错误 raise NetworkConnectionError() from e diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 0ef79a89b..c8f8b96e0 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -376,11 +376,14 @@ def _default_normal_response_parser( class OpenaiClient(BaseClient): def __init__(self, api_provider: APIProvider): super().__init__(api_provider) - self.client: AsyncOpenAI = AsyncOpenAI( - base_url=api_provider.base_url, - api_key=api_provider.api_key, + + def _create_client(self) -> AsyncOpenAI: + """动态创建OpenAI客户端""" + return AsyncOpenAI( + base_url=self.api_provider.base_url, + api_key=self.api_provider.get_api_key(), max_retries=0, - timeout=api_provider.timeout, + timeout=self.api_provider.timeout, ) async def get_response( @@ -429,10 +432,11 @@ class OpenaiClient(BaseClient): # 将tool_options转换为OpenAI API所需的格式 tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore + client = self._create_client() try: if model_info.force_stream_mode: req_task = asyncio.create_task( - self.client.chat.completions.create( + client.chat.completions.create( model=model_info.model_identifier, messages=messages, tools=tools, @@ -455,7 +459,7 @@ class OpenaiClient(BaseClient): # 发送请求并获取响应 # start_time = time.time() req_task = asyncio.create_task( - self.client.chat.completions.create( + client.chat.completions.create( model=model_info.model_identifier, messages=messages, tools=tools, @@ -506,8 +510,9 @@ class OpenaiClient(BaseClient): :param embedding_input: 嵌入输入文本 :return: 嵌入响应 """ + client = self._create_client() try: - raw_response = await self.client.embeddings.create( + raw_response = await client.embeddings.create( model=model_info.model_identifier, input=embedding_input, extra_body=extra_params, @@ -564,8 +569,9 @@ class OpenaiClient(BaseClient): :extra_params: 附加的请求参数 :return: 音频转录响应 """ + client = self._create_client() try: - raw_response = await self.client.audio.transcriptions.create( + raw_response = await client.audio.transcriptions.create( model=model_info.model_identifier, file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))), extra_body=extra_params,