refactor(client): optimize OpenaiClient's global cache with event loop detection
@@ -384,8 +384,8 @@ def _default_normal_response_parser(
 @client_registry.register_client_class("openai")
 class OpenaiClient(BaseClient):
     # Class-level global cache: shared by all OpenaiClient instances
-    _global_client_cache: ClassVar[dict[int, AsyncOpenAI]] = {}
-    """Global AsyncOpenAI client cache: config_hash -> AsyncOpenAI instance"""
+    _global_client_cache: ClassVar[dict[tuple[int, int | None], AsyncOpenAI]] = {}
+    """Global AsyncOpenAI client cache: (config_hash, loop_id) -> AsyncOpenAI instance"""
 
     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
@@ -401,20 +401,44 @@ class OpenaiClient(BaseClient):
         )
         return hash(config_tuple)
 
+    @staticmethod
+    def _get_current_loop_id() -> int | None:
+        """Return the ID of the current event loop."""
+        try:
+            loop = asyncio.get_running_loop()
+            return id(loop)
+        except RuntimeError:
+            # No running event loop
+            return None
+
     def _create_client(self) -> AsyncOpenAI:
         """
-        Get or create an OpenAI client instance (global cache).
+        Get or create an OpenAI client instance (global cache with event loop detection).
 
-        Multiple OpenaiClient instances with the same configuration (base_url + api_key + timeout)
+        Multiple OpenaiClient instances with the same configuration (base_url + api_key + timeout), running in the same event loop,
         will share a single AsyncOpenAI client instance, maximizing connection pool reuse.
+
+        When the event loop changes, a new client instance is created automatically.
         """
-        # Check the global cache
-        if self._config_hash in self._global_client_cache:
-            return self._global_client_cache[self._config_hash]
+        # Get the ID of the current event loop
+        current_loop_id = self._get_current_loop_id()
+        cache_key = (self._config_hash, current_loop_id)
 
+        # Evict stale entries cached under other event loops
+        keys_to_remove = [
+            key for key in self._global_client_cache.keys()
+            if key[0] == self._config_hash and key[1] != current_loop_id
+        ]
+        for key in keys_to_remove:
+            logger.debug(f"Evicting stale AsyncOpenAI client cache entry (loop_id={key[1]})")
+            del self._global_client_cache[key]
+
+        # Check the cache for the current event loop
+        if cache_key in self._global_client_cache:
+            return self._global_client_cache[cache_key]
+
         # Create a new AsyncOpenAI instance
         logger.debug(
-            f"Creating a new AsyncOpenAI client instance (base_url={self.api_provider.base_url}, config_hash={self._config_hash})"
+            f"Creating a new AsyncOpenAI client instance (base_url={self.api_provider.base_url}, config_hash={self._config_hash}, loop_id={current_loop_id})"
         )
 
         client = AsyncOpenAI(
@@ -424,8 +448,8 @@ class OpenaiClient(BaseClient):
             timeout=self.api_provider.timeout,
         )
 
-        # Store in the global cache
-        self._global_client_cache[self._config_hash] = client
+        # Store in the global cache (keyed with the event loop ID)
+        self._global_client_cache[cache_key] = client
 
         return client
 
@@ -434,7 +458,10 @@ class OpenaiClient(BaseClient):
         """Return global cache statistics."""
         return {
             "cached_openai_clients": len(cls._global_client_cache),
-            "config_hashes": list(cls._global_client_cache.keys()),
+            "cache_keys": [
+                {"config_hash": k[0], "loop_id": k[1]}
+                for k in cls._global_client_cache.keys()
+            ],
         }
 
     async def get_response(
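Background on the change: an asyncio-based HTTP client such as AsyncOpenAI typically ends up bound to the event loop that was running when it was first used (connection pool, locks), so a client cached only by config_hash can outlive its loop across separate asyncio.run() calls and fail on the next one. The standalone sketch below mirrors the (config_hash, loop_id) cache key and stale-entry eviction from the diff above; the names _FakeAsyncClient, _cache, and get_client are hypothetical stand-ins and not part of the codebase.

import asyncio

class _FakeAsyncClient:
    """Stand-in for AsyncOpenAI: remembers the loop it was created on."""
    def __init__(self) -> None:
        self.loop = asyncio.get_running_loop()

# (config_hash, loop_id) -> client, mirroring the cache key in the diff above
_cache: dict[tuple[int, int | None], _FakeAsyncClient] = {}

def get_client(config_hash: int) -> _FakeAsyncClient:
    try:
        loop_id: int | None = id(asyncio.get_running_loop())
    except RuntimeError:
        loop_id = None  # no running event loop
    key = (config_hash, loop_id)
    # Drop entries created under a different (likely already closed) event loop
    for stale in [k for k in _cache if k[0] == config_hash and k[1] != loop_id]:
        del _cache[stale]
    if key not in _cache:
        _cache[key] = _FakeAsyncClient()
    return _cache[key]

async def main() -> None:
    a = get_client(123)
    b = get_client(123)
    assert a is b  # same configuration, same loop -> the cached client is reused

# Each asyncio.run() call starts a fresh event loop; a cache keyed only by
# config_hash would keep returning the client bound to the previous, closed
# loop, while the loop-aware key creates a new client per loop.
asyncio.run(main())
asyncio.run(main())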