From d65f90ee49fd761692df26784f0c7d60ebadc9a2 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 9 Aug 2025 11:40:29 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=BC=93=E5=AD=98=E5=B1=82?= =?UTF-8?q?=E6=8F=90=E9=AB=98=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/model_client/base_client.py | 20 +++++++++++++------- src/llm_models/utils_model.py | 3 +-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 8e8affba6..97c345466 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -140,6 +140,9 @@ class BaseClient(ABC): class ClientRegistry: def __init__(self) -> None: self.client_registry: dict[str, type[BaseClient]] = {} + """APIProvider.type -> BaseClient的映射表""" + self.client_instance_cache: dict[str, BaseClient] = {} + """APIProvider.name -> BaseClient的映射表""" def register_client_class(self, client_type: str): """ @@ -156,17 +159,20 @@ class ClientRegistry: return decorator - def get_client_class(self, client_type: str) -> type[BaseClient]: + def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient: """ - 获取注册的API客户端类 + 获取注册的API客户端实例 Args: - client_type: 客户端类型 + api_provider: APIProvider实例 Returns: - type[BaseClient]: 注册的API客户端类 + BaseClient: 注册的API客户端实例 """ - if client_type not in self.client_registry: - raise KeyError(f"'{client_type}' 类型的 Client 未注册") - return self.client_registry[client_type] + if api_provider.name not in self.client_instance_cache: + if client_class := self.client_registry.get(api_provider.client_type): + self.client_instance_cache[api_provider.name] = client_class(api_provider) + else: + raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") + return self.client_instance_cache[api_provider.name] client_registry = ClientRegistry() diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index b9986afce..8fd6ce7a9 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,5 +1,4 @@ import re -import copy import asyncio import time @@ -249,7 +248,7 @@ class LLMRequest: ) model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) - client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) + client = client_registry.get_client_class_instance(api_provider) logger.debug(f"选择请求模型: {model_info.name}") total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用