feat(model): 优化客户端缓存和事件循环检测机制

- 在 ClientRegistry 中添加事件循环变化检测,自动处理缓存失效
- 为 OpenaiClient 实现全局 AsyncOpenAI 客户端缓存,提升连接池复用效率
- 将 utils_model 中的同步方法改为异步,确保与事件循环兼容
- 移除 embedding 请求的特殊处理,现在所有请求都能享受缓存优势
- 添加缓存统计功能,便于监控和调试
This commit is contained in:
Windpicker-owo
2025-10-06 01:05:50 +08:00
parent a72012bf78
commit f59a31865c
4 changed files with 156 additions and 44 deletions

View File

@@ -4,12 +4,15 @@ from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.config.api_ada_configs import APIProvider, ModelInfo
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolCall, ToolOption
logger = get_logger("model_client.base_client")
@dataclass
class UsageRecord:
@@ -144,6 +147,10 @@ class ClientRegistry:
"""APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
self._event_loop_cache: dict[str, int | None] = {}
"""APIProvider.name -> event loop id的映射表用于检测事件循环变化"""
self._loop_change_count: int = 0
"""事件循环变化导致缓存失效的次数"""
def register_client_class(self, client_type: str):
"""
@@ -160,29 +167,91 @@ class ClientRegistry:
return decorator
def _get_current_loop_id(self) -> int | None:
"""
获取当前事件循环的ID
Returns:
int | None: 事件循环ID如果没有运行中的循环则返回None
"""
try:
loop = asyncio.get_running_loop()
return id(loop)
except RuntimeError:
# 没有运行中的事件循环
return None
def _is_event_loop_changed(self, provider_name: str) -> bool:
"""
检查事件循环是否发生变化
Args:
provider_name: Provider名称
Returns:
bool: 事件循环是否变化
"""
current_loop_id = self._get_current_loop_id()
# 如果没有缓存的循环ID说明是首次创建
if provider_name not in self._event_loop_cache:
return False
# 比较当前循环ID与缓存的循环ID
cached_loop_id = self._event_loop_cache[provider_name]
return current_loop_id != cached_loop_id
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
    """
    Get a registered API client instance (with event-loop change detection).

    Args:
        api_provider: APIProvider describing the target provider.
        force_new: Force creation of a fresh instance (normally unnecessary;
            event-loop changes are detected automatically).

    Returns:
        BaseClient: The registered API client instance.

    Raises:
        KeyError: If no client class is registered for the provider's type.
    """
    provider_name = api_provider.name
    # Forced creation bypasses the instance cache entirely.
    if force_new:
        if client_class := self.client_registry.get(api_provider.client_type):
            new_instance = client_class(api_provider)
            # Remember which loop this instance was created under.
            self._event_loop_cache[provider_name] = self._get_current_loop_id()
            return new_instance
        raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")

    # A changed event loop invalidates the cached instance: drop it so the
    # normal cache path below rebuilds it under the current loop.
    if self._is_event_loop_changed(provider_name):
        logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
        self._loop_change_count += 1
        self.client_instance_cache.pop(provider_name, None)

    # Normal cache path: create on first use, then reuse.
    if provider_name not in self.client_instance_cache:
        if client_class := self.client_registry.get(api_provider.client_type):
            self.client_instance_cache[provider_name] = client_class(api_provider)
            # Cache the current event loop id alongside the instance.
            self._event_loop_cache[provider_name] = self._get_current_loop_id()
        else:
            raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
    return self.client_instance_cache[provider_name]
def get_cache_stats(self) -> dict:
"""
获取缓存统计信息
Returns:
dict: 包含缓存统计的字典
"""
return {
"cached_instances": len(self.client_instance_cache),
"tracked_loops": len(self._event_loop_cache),
"loop_change_count": self._loop_change_count,
"cached_providers": list(self.client_instance_cache.keys()),
}
# Module-level singleton registry shared by the whole model-client layer.
client_registry = ClientRegistry()

View File

@@ -383,17 +383,61 @@ def _default_normal_response_parser(
@client_registry.register_client_class("openai")
class OpenaiClient(BaseClient):
    # Class-level global cache: shared by all OpenaiClient instances so that
    # providers with identical configuration reuse one connection pool.
    _global_client_cache: dict[int, AsyncOpenAI] = {}
    """Global AsyncOpenAI client cache: config_hash -> AsyncOpenAI instance."""

    def __init__(self, api_provider: APIProvider):
        super().__init__(api_provider)
        # Hash of (base_url, api_key, timeout) — the cache key for this provider.
        self._config_hash = self._calculate_config_hash()
        """Configuration hash of the current provider."""
def _calculate_config_hash(self) -> int:
    """
    Hash the provider settings that identify a unique connection config.

    Returns:
        int: Hash of (base_url, api_key, timeout) for this provider.
    """
    provider = self.api_provider
    return hash((provider.base_url, provider.get_api_key(), provider.timeout))
def _create_client(self) -> AsyncOpenAI:
    """
    Get or create the AsyncOpenAI client for this provider (globally cached).

    Multiple OpenaiClient instances with identical configuration
    (base_url + api_key + timeout) share a single AsyncOpenAI instance,
    maximizing connection-pool reuse.

    Returns:
        AsyncOpenAI: The cached or newly created client.
    """
    # Fast path: a client with the same configuration already exists.
    if self._config_hash in self._global_client_cache:
        return self._global_client_cache[self._config_hash]

    logger.debug(
        f"创建新的 AsyncOpenAI 客户端实例 "
        f"(base_url={self.api_provider.base_url}, "
        f"config_hash={self._config_hash})"
    )
    # max_retries=0: retry policy is handled by the caller, not the SDK.
    client = AsyncOpenAI(
        base_url=self.api_provider.base_url,
        api_key=self.api_provider.get_api_key(),
        max_retries=0,
        timeout=self.api_provider.timeout,
    )
    # Publish into the global cache for subsequent instances.
    self._global_client_cache[self._config_hash] = client
    return client
@classmethod
def get_cache_stats(cls) -> dict:
    """Return statistics about the shared AsyncOpenAI client cache."""
    cache = cls._global_client_cache
    return {
        "cached_openai_clients": len(cache),
        "config_hashes": list(cache),
    }
async def get_response(
self,