diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index b67640641..48ef0c082 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -48,8 +48,10 @@ class LLMRequest: self.task_name = request_type self.model_for_task = model_set self.request_type = request_type - self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list} - """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整""" + self.model_usage: Dict[str, Tuple[int, int, int]] = { + model: (0, 0, 0) for model in self.model_for_task.model_list + } + """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" self.pri_in = 0 self.pri_out = 0 @@ -226,12 +228,15 @@ class LLMRequest: 根据总tokens和惩罚值选择的模型 """ least_used_model_name = min( - self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage, + key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, ) model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) logger.debug(f"选择请求模型: {model_info.name}") + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 return model_info, api_provider, client async def _execute_request( @@ -289,8 +294,8 @@ class LLMRequest: except Exception as e: logger.debug(f"请求失败: {str(e)}") # 处理异常 - total_tokens, penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1) + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) wait_interval, compressed_messages = self._default_exception_handler( e, @@ -309,6 +314,8 @@ class LLMRequest: finally: # 放在finally防止死循环 retry_remain -= 1 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") raise RuntimeError("请求失败,已达到最大重试次数")