增加缓存层提高性能
This commit is contained in:
@@ -140,6 +140,9 @@ class BaseClient(ABC):
|
|||||||
class ClientRegistry:
|
class ClientRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.client_registry: dict[str, type[BaseClient]] = {}
|
self.client_registry: dict[str, type[BaseClient]] = {}
|
||||||
|
"""APIProvider.type -> BaseClient的映射表"""
|
||||||
|
self.client_instance_cache: dict[str, BaseClient] = {}
|
||||||
|
"""APIProvider.name -> BaseClient的映射表"""
|
||||||
|
|
||||||
def register_client_class(self, client_type: str):
|
def register_client_class(self, client_type: str):
|
||||||
"""
|
"""
|
||||||
@@ -156,17 +159,20 @@ class ClientRegistry:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def get_client_class(self, client_type: str) -> type[BaseClient]:
|
def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
|
||||||
"""
|
"""
|
||||||
获取注册的API客户端类
|
获取注册的API客户端实例
|
||||||
Args:
|
Args:
|
||||||
client_type: 客户端类型
|
api_provider: APIProvider实例
|
||||||
Returns:
|
Returns:
|
||||||
type[BaseClient]: 注册的API客户端类
|
BaseClient: 注册的API客户端实例
|
||||||
"""
|
"""
|
||||||
if client_type not in self.client_registry:
|
if api_provider.name not in self.client_instance_cache:
|
||||||
raise KeyError(f"'{client_type}' 类型的 Client 未注册")
|
if client_class := self.client_registry.get(api_provider.client_type):
|
||||||
return self.client_registry[client_type]
|
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
||||||
|
else:
|
||||||
|
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||||
|
return self.client_instance_cache[api_provider.name]
|
||||||
|
|
||||||
|
|
||||||
client_registry = ClientRegistry()
|
client_registry = ClientRegistry()
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import re
|
import re
|
||||||
import copy
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -249,7 +248,7 @@ class LLMRequest:
|
|||||||
)
|
)
|
||||||
model_info = model_config.get_model_info(least_used_model_name)
|
model_info = model_config.get_model_info(least_used_model_name)
|
||||||
api_provider = model_config.get_provider(model_info.api_provider)
|
api_provider = model_config.get_provider(model_info.api_provider)
|
||||||
client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
|
client = client_registry.get_client_class_instance(api_provider)
|
||||||
logger.debug(f"选择请求模型: {model_info.name}")
|
logger.debug(f"选择请求模型: {model_info.name}")
|
||||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
|
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
|
||||||
|
|||||||
Reference in New Issue
Block a user