From 49f376dc1ce12522f133664d625f632eb86a3778 Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Sat, 1 Nov 2025 00:23:25 +0800
Subject: [PATCH] refactor(client): optimize OpenaiClient global cache,
 support event-loop detection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/llm_models/model_client/openai_client.py | 49 +++++++++++++++-----
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index ad5acb0c0..a6f92fd17 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -376,8 +376,8 @@ def _default_normal_response_parser(
 @client_registry.register_client_class("openai")
 class OpenaiClient(BaseClient):
     # Class-level global cache: shared by all OpenaiClient instances
-    _global_client_cache: ClassVar[dict[int, AsyncOpenAI]] = {}
-    """Global AsyncOpenAI client cache: config_hash -> AsyncOpenAI instance"""
+    _global_client_cache: ClassVar[dict[tuple[int, int | None], AsyncOpenAI]] = {}
+    """Global AsyncOpenAI client cache: (config_hash, loop_id) -> AsyncOpenAI instance"""
 
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
@@ -393,20 +393,44 @@ class OpenaiClient(BaseClient):
         )
         return hash(config_tuple)
 
+    @staticmethod
+    def _get_current_loop_id() -> int | None:
+        """Get the ID of the current event loop."""
+        try:
+            loop = asyncio.get_running_loop()
+            return id(loop)
+        except RuntimeError:
+            # No running event loop
+            return None
+
     def _create_client(self) -> AsyncOpenAI:
         """
-        Get or create an OpenAI client instance (global cache).
+        Get or create an OpenAI client instance (global cache with event-loop detection).
 
-        Multiple OpenaiClient instances with the same configuration (base_url + api_key + timeout)
+        Multiple OpenaiClient instances with the same configuration (base_url + api_key + timeout), on the same event loop,
         will share a single AsyncOpenAI client instance, maximizing connection-pool reuse.
+        When the event loop changes, a new client instance is created automatically.
         """
-        # Check the global cache
-        if self._config_hash in self._global_client_cache:
-            return self._global_client_cache[self._config_hash]
+        # Get the current event loop ID
+        current_loop_id = self._get_current_loop_id()
+        cache_key = (self._config_hash, current_loop_id)
+
+        # Evict stale cache entries from other event loops
+        keys_to_remove = [
+            key for key in self._global_client_cache.keys()
+            if key[0] == self._config_hash and key[1] != current_loop_id
+        ]
+        for key in keys_to_remove:
+            logger.debug(f"Evicting stale AsyncOpenAI client cache entry (loop_id={key[1]})")
+            del self._global_client_cache[key]
+
+        # Check the cache for the current event loop
+        if cache_key in self._global_client_cache:
+            return self._global_client_cache[cache_key]
 
         # Create a new AsyncOpenAI instance
         logger.debug(
-            f"Creating new AsyncOpenAI client instance (base_url={self.api_provider.base_url}, config_hash={self._config_hash})"
+            f"Creating new AsyncOpenAI client instance (base_url={self.api_provider.base_url}, config_hash={self._config_hash}, loop_id={current_loop_id})"
         )
 
         client = AsyncOpenAI(
@@ -416,8 +440,8 @@ class OpenaiClient(BaseClient):
             timeout=self.api_provider.timeout,
         )
 
-        # Store in the global cache
-        self._global_client_cache[self._config_hash] = client
+        # Store in the global cache (keyed by event loop ID)
+        self._global_client_cache[cache_key] = client
 
         return client
 
@@ -426,7 +450,10 @@ class OpenaiClient(BaseClient):
         """Get global cache statistics."""
         return {
             "cached_openai_clients": len(cls._global_client_cache),
-            "config_hashes": list(cls._global_client_cache.keys()),
+            "cache_keys": [
+                {"config_hash": k[0], "loop_id": k[1]}
+                for k in cls._global_client_cache.keys()
+            ],
         }
 
     async def get_response(
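
[Editor's note, not part of the patch: the rationale for keying the cache by
(config_hash, loop_id) is that an AsyncOpenAI client's underlying async
connection pool is effectively bound to the event loop it was created on, so
a client cached under one loop breaks once that loop is closed, e.g. across
separate asyncio.run() calls. Below is a minimal, dependency-free sketch of
the same pattern; LoopAwareCache and factory are illustrative names, not
identifiers from the patched module, and a plain object stands in for
AsyncOpenAI so the snippet runs as-is.]

    import asyncio
    from typing import Any, Callable, ClassVar


    class LoopAwareCache:
        """Cache keyed by (config_hash, id of the running event loop)."""

        _cache: ClassVar[dict[tuple[int, int | None], Any]] = {}

        @staticmethod
        def _loop_id() -> int | None:
            try:
                return id(asyncio.get_running_loop())
            except RuntimeError:
                return None  # called outside any event loop

        @classmethod
        def get(cls, config_hash: int, factory: Callable[[], Any]) -> Any:
            loop_id = cls._loop_id()
            key = (config_hash, loop_id)
            # Mirror the patch's eviction: drop entries created under other
            # loops so a dead loop does not leak one client per configuration.
            for stale in [k for k in cls._cache if k[0] == config_hash and k[1] != loop_id]:
                del cls._cache[stale]
            if key not in cls._cache:
                cls._cache[key] = factory()
            return cls._cache[key]


    async def main() -> None:
        first = LoopAwareCache.get(42, object)
        # Same loop -> same cached instance is returned
        assert first is LoopAwareCache.get(42, object)


    # Each asyncio.run() call creates a fresh loop, so the second run gets a
    # new instance instead of one bound to the already-closed first loop.
    asyncio.run(main())
    asyncio.run(main())

[Without the eviction step the cache would still be correct, but every new
event loop would strand one unusable client per configuration in the dict.]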