增加缓存层提高性能

This commit is contained in:
UnCLAS-Prommer
2025-08-09 11:40:29 +08:00
parent 2ea4c75e9c
commit d65f90ee49
2 changed files with 14 additions and 9 deletions

View File

@@ -140,6 +140,9 @@ class BaseClient(ABC):
class ClientRegistry: class ClientRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self.client_registry: dict[str, type[BaseClient]] = {} self.client_registry: dict[str, type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
def register_client_class(self, client_type: str): def register_client_class(self, client_type: str):
""" """
@@ -156,17 +159,20 @@ class ClientRegistry:
return decorator return decorator
def get_client_class(self, client_type: str) -> type[BaseClient]: def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
""" """
获取注册的API客户端 获取注册的API客户端实例
Args: Args:
client_type: 客户端类型 api_provider: APIProvider实例
Returns: Returns:
type[BaseClient]: 注册的API客户端 BaseClient: 注册的API客户端实例
""" """
if client_type not in self.client_registry: if api_provider.name not in self.client_instance_cache:
raise KeyError(f"'{client_type}' 类型的 Client 未注册") if client_class := self.client_registry.get(api_provider.client_type):
return self.client_registry[client_type] self.client_instance_cache[api_provider.name] = client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
return self.client_instance_cache[api_provider.name]
client_registry = ClientRegistry() client_registry = ClientRegistry()

View File

@@ -1,5 +1,4 @@
import re import re
import copy
import asyncio import asyncio
import time import time
@@ -249,7 +248,7 @@ class LLMRequest:
) )
model_info = model_config.get_model_info(least_used_model_name) model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider) api_provider = model_config.get_provider(model_info.api_provider)
client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) client = client_registry.get_client_class_instance(api_provider)
logger.debug(f"选择请求模型: {model_info.name}") logger.debug(f"选择请求模型: {model_info.name}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用