feat(model): 优化客户端缓存和事件循环检测机制
- 在 ClientRegistry 中添加事件循环变化检测,自动处理缓存失效 - 为 OpenaiClient 实现全局 AsyncOpenAI 客户端缓存,提升连接池复用效率 - 将 utils_model 中的同步方法改为异步,确保与事件循环兼容 - 移除 embedding 请求的特殊处理,现在所有请求都能享受缓存优势 - 添加缓存统计功能,便于监控和调试
This commit is contained in:
@@ -4,12 +4,15 @@ from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolCall, ToolOption
|
||||
|
||||
logger = get_logger("model_client.base_client")
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageRecord:
|
||||
@@ -144,6 +147,10 @@ class ClientRegistry:
|
||||
"""APIProvider.type -> BaseClient的映射表"""
|
||||
self.client_instance_cache: dict[str, BaseClient] = {}
|
||||
"""APIProvider.name -> BaseClient的映射表"""
|
||||
self._event_loop_cache: dict[str, int | None] = {}
|
||||
"""APIProvider.name -> event loop id的映射表,用于检测事件循环变化"""
|
||||
self._loop_change_count: int = 0
|
||||
"""事件循环变化导致缓存失效的次数"""
|
||||
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
@@ -160,29 +167,91 @@ class ClientRegistry:
|
||||
|
||||
return decorator
|
||||
|
||||
def _get_current_loop_id(self) -> int | None:
|
||||
"""
|
||||
获取当前事件循环的ID
|
||||
Returns:
|
||||
int | None: 事件循环ID,如果没有运行中的循环则返回None
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return id(loop)
|
||||
except RuntimeError:
|
||||
# 没有运行中的事件循环
|
||||
return None
|
||||
|
||||
def _is_event_loop_changed(self, provider_name: str) -> bool:
|
||||
"""
|
||||
检查事件循环是否发生变化
|
||||
Args:
|
||||
provider_name: Provider名称
|
||||
Returns:
|
||||
bool: 事件循环是否变化
|
||||
"""
|
||||
current_loop_id = self._get_current_loop_id()
|
||||
|
||||
# 如果没有缓存的循环ID,说明是首次创建
|
||||
if provider_name not in self._event_loop_cache:
|
||||
return False
|
||||
|
||||
# 比较当前循环ID与缓存的循环ID
|
||||
cached_loop_id = self._event_loop_cache[provider_name]
|
||||
return current_loop_id != cached_loop_id
|
||||
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
|
||||
"""
|
||||
获取注册的API客户端实例
|
||||
获取注册的API客户端实例(带事件循环检测)
|
||||
Args:
|
||||
api_provider: APIProvider实例
|
||||
force_new: 是否强制创建新实例(用于解决事件循环问题)
|
||||
force_new: 是否强制创建新实例(通常不需要,会自动检测事件循环变化)
|
||||
Returns:
|
||||
BaseClient: 注册的API客户端实例
|
||||
"""
|
||||
provider_name = api_provider.name
|
||||
|
||||
# 如果强制创建新实例,直接创建不使用缓存
|
||||
if force_new:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
return client_class(api_provider)
|
||||
new_instance = client_class(api_provider)
|
||||
# 更新事件循环缓存
|
||||
self._event_loop_cache[provider_name] = self._get_current_loop_id()
|
||||
return new_instance
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
# 检查事件循环是否变化
|
||||
if self._is_event_loop_changed(provider_name):
|
||||
# 事件循环已变化,需要重新创建实例
|
||||
logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
|
||||
self._loop_change_count += 1
|
||||
|
||||
# 移除旧实例
|
||||
if provider_name in self.client_instance_cache:
|
||||
del self.client_instance_cache[provider_name]
|
||||
|
||||
# 正常的缓存逻辑
|
||||
if api_provider.name not in self.client_instance_cache:
|
||||
if provider_name not in self.client_instance_cache:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
||||
self.client_instance_cache[provider_name] = client_class(api_provider)
|
||||
# 缓存当前事件循环ID
|
||||
self._event_loop_cache[provider_name] = self._get_current_loop_id()
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
return self.client_instance_cache[provider_name]
|
||||
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""
|
||||
获取缓存统计信息
|
||||
Returns:
|
||||
dict: 包含缓存统计的字典
|
||||
"""
|
||||
return {
|
||||
"cached_instances": len(self.client_instance_cache),
|
||||
"tracked_loops": len(self._event_loop_cache),
|
||||
"loop_change_count": self._loop_change_count,
|
||||
"cached_providers": list(self.client_instance_cache.keys()),
|
||||
}
|
||||
|
||||
|
||||
client_registry = ClientRegistry()
|
||||
|
||||
@@ -383,17 +383,61 @@ def _default_normal_response_parser(
|
||||
|
||||
@client_registry.register_client_class("openai")
|
||||
class OpenaiClient(BaseClient):
|
||||
# 类级别的全局缓存:所有 OpenaiClient 实例共享
|
||||
_global_client_cache: dict[int, AsyncOpenAI] = {}
|
||||
"""全局 AsyncOpenAI 客户端缓存:config_hash -> AsyncOpenAI 实例"""
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
self._config_hash = self._calculate_config_hash()
|
||||
"""当前 provider 的配置哈希值"""
|
||||
|
||||
def _calculate_config_hash(self) -> int:
|
||||
"""计算当前配置的哈希值"""
|
||||
config_tuple = (
|
||||
self.api_provider.base_url,
|
||||
self.api_provider.get_api_key(),
|
||||
self.api_provider.timeout,
|
||||
)
|
||||
return hash(config_tuple)
|
||||
|
||||
def _create_client(self) -> AsyncOpenAI:
|
||||
"""动态创建OpenAI客户端"""
|
||||
return AsyncOpenAI(
|
||||
"""
|
||||
获取或创建 OpenAI 客户端实例(全局缓存)
|
||||
|
||||
多个 OpenaiClient 实例如果配置相同(base_url + api_key + timeout),
|
||||
将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。
|
||||
"""
|
||||
# 检查全局缓存
|
||||
if self._config_hash in self._global_client_cache:
|
||||
return self._global_client_cache[self._config_hash]
|
||||
|
||||
# 创建新的 AsyncOpenAI 实例
|
||||
logger.debug(
|
||||
f"创建新的 AsyncOpenAI 客户端实例 "
|
||||
f"(base_url={self.api_provider.base_url}, "
|
||||
f"config_hash={self._config_hash})"
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(
|
||||
base_url=self.api_provider.base_url,
|
||||
api_key=self.api_provider.get_api_key(),
|
||||
max_retries=0,
|
||||
timeout=self.api_provider.timeout,
|
||||
)
|
||||
|
||||
# 存入全局缓存
|
||||
self._global_client_cache[self._config_hash] = client
|
||||
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls) -> dict:
|
||||
"""获取全局缓存统计信息"""
|
||||
return {
|
||||
"cached_openai_clients": len(cls._global_client_cache),
|
||||
"config_hashes": list(cls._global_client_cache.keys()),
|
||||
}
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user