93 lines
3.4 KiB
Python
93 lines
3.4 KiB
Python
import importlib
|
||
from typing import Dict
|
||
|
||
from src.config.config import model_config
|
||
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
|
||
from src.common.logger import get_logger
|
||
|
||
from .model_client import ModelRequestHandler, BaseClient
|
||
|
||
# Module-level logger for this file; the name string "模型管理器" means "Model Manager".
logger = get_logger("模型管理器")
|
||
|
||
class ModelManager:
    """Registry of model API clients and per-task request handlers.

    On construction, one API client is created for every provider in
    ``config.api_providers``. Indexing the manager with a task name
    (``manager[task_name]``) returns a ``ModelRequestHandler`` for a task
    registered in ``config.task_model_arg_map``. Handlers are cached per task.
    """

    # TODO: add a read-write lock to prevent data races when the configuration
    # is refreshed asynchronously.

    def __init__(
        self,
        config: ModuleConfig,
    ):
        """
        Build an API client for every configured provider.

        The client implementation is resolved dynamically. For a provider whose
        ``client_type`` is ``"<type>"``, the module ``.model_client.<type>_client``
        (relative to this package) is imported. The class named
        ``client_type.capitalize() + "Client"`` is then instantiated with the
        provider configuration.

        :param config: Module configuration holding providers and task settings.
        :raises ImportError: If a provider's client module cannot be imported
            (re-raised with the provider key in the message).
        :raises AttributeError: If the client module does not define the
            expected ``<Type>Client`` class.
        :raises TypeError: If the resolved class is not a ``BaseClient`` subclass.
        """
        # Configuration information.
        self.config: ModuleConfig = config

        # API client map: provider name -> instantiated client.
        self.api_client_map: Dict[str, BaseClient] = {}

        # Cache of ModelRequestHandler objects keyed by task name, so a handler
        # is not rebuilt on every lookup. Entries are evicted by __setitem__.
        self._request_handler_cache: Dict[str, ModelRequestHandler] = {}

        for provider_name, api_provider in self.config.api_providers.items():
            # Initialize the API client, loading the implementation named by
            # the provider's configuration.
            try:
                client_module = importlib.import_module(
                    f".model_client.{api_provider.client_type}_client", __package__
                )
                client_class = getattr(
                    client_module, f"{api_provider.client_type.capitalize()}Client"
                )
                if not issubclass(client_class, BaseClient):
                    raise TypeError(
                        f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
                    )
                # NOTE(review): the map is keyed by api_provider.name, while the
                # error below reports the dict key provider_name. Presumably
                # these are identical; confirm against the config loader.
                self.api_client_map[api_provider.name] = client_class(api_provider)
            except ImportError as e:
                logger.error(f"Failed to import client module: {e}")
                raise ImportError(
                    f"Failed to import client module for '{provider_name}': {e}"
                ) from e

    def __getitem__(self, task_name: str) -> ModelRequestHandler:
        """
        Get the (wrapped) model client for a task.

        Handlers are cached per task name, so repeated lookups reuse the same
        ``ModelRequestHandler`` instead of constructing a new one.

        :param task_name: Task name.
        :return: The model request handler for the task.
        :raises KeyError: If ``task_name`` has not been registered.
        """
        if task_name not in self.config.task_model_arg_map:
            raise KeyError(f"'{task_name}' not registered in ModelManager")

        # Reuse the cached handler if one exists (single dict lookup).
        handler = self._request_handler_cache.get(task_name)
        if handler is not None:
            logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
            return handler

        # Otherwise create a new handler and cache it.
        logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
        handler = ModelRequestHandler(
            task_name=task_name,
            config=self.config,
            api_client_map=self.api_client_map,
        )
        self._request_handler_cache[task_name] = handler
        return handler

    def __setitem__(self, task_name: str, value: ModelUsageArgConfig) -> None:
        """
        Register (or replace) the model usage configuration for a task.

        Any cached handler for the task is discarded. The next
        ``manager[task_name]`` then builds a handler against the updated
        configuration rather than returning one created before this call.

        :param task_name: Task name.
        :param value: Model usage configuration.
        """
        self.config.task_model_arg_map[task_name] = value
        # Evict the stale handler. It was constructed before this update and
        # may have captured the previous task configuration.
        self._request_handler_cache.pop(task_name, None)

    def __contains__(self, task_name: str) -> bool:
        """
        Check whether a task is registered.

        :param task_name: Task name.
        :return: True if the task has a model usage configuration.
        """
        return task_name in self.config.task_model_arg_map
|
||
|
||
|