diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index dc4374c06..b2b4ba928 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -23,6 +23,7 @@ import random
 import re
 import string
 import time
+from collections import namedtuple
 from collections.abc import Callable, Coroutine
 from enum import Enum
 from typing import Any
@@ -133,21 +134,26 @@ class RequestType(Enum):
 # Helper Classes for LLMRequest Refactoring
 # ==============================================================================
 
+# Named tuple used to track per-model usage statistics
+ModelUsageStats = namedtuple(  # noqa: PYI024
+    "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
+)
+
 
 class _ModelSelector:
     """Strategy for model selection, load balancing, and dynamic failover."""
 
     CRITICAL_PENALTY_MULTIPLIER = 5  # Penalty multiplier for critical errors
     DEFAULT_PENALTY_INCREMENT = 1  # Default penalty increment
+    LATENCY_WEIGHT = 200  # Latency weight
 
-    def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
+    def __init__(self, model_list: list[str], model_usage: dict[str, ModelUsageStats]):
         """
         Initialize the model selector.
 
         Args:
             model_list (List[str]): List of available model names.
-            model_usage (Dict[str, Tuple[int, int, int]]): Initial usage stats per model,
-                in the form {model_name: (total_tokens, penalty, usage_penalty)}.
+            model_usage (Dict[str, ModelUsageStats]): Initial usage stats per model.
         """
         self.model_list = model_list
         self.model_usage = model_usage
@@ -176,16 +182,18 @@ class _ModelSelector:
             return None
 
         # Core load-balancing algorithm: pick the model with the lowest composite score.
-        # Formula: total_tokens + penalty * 300 + usage_penalty * 1000
+        # Formula: total_tokens + penalty * 300 + usage_penalty * 1000 + avg_latency * 200
         # Design rationale:
         # - `total_tokens`: base cost; prefer models with fewer accumulated tokens, for long-term balance.
         # - `penalty * 300`: failure penalty. Each failure increments penalty, lowering the model's short-term selection probability. A weight of 300 means one failure costs roughly 300 tokens.
         # - `usage_penalty * 1000`: short-term usage penalty. Incremented when a model is selected, decremented on completion. The high weight keeps requests evenly distributed (round-robin) when all models are healthy.
+        # - `avg_latency * 200`: latency penalty. Prefers models with faster average response times. A weight of 200 means one second of latency costs roughly 200 tokens.
         least_used_model_name = min(
             candidate_models_usage,
-            key=lambda k: candidate_models_usage[k][0]
-            + candidate_models_usage[k][1] * 300
-            + candidate_models_usage[k][2] * 1000,
+            key=lambda k: candidate_models_usage[k].total_tokens
+            + candidate_models_usage[k].penalty * 300
+            + candidate_models_usage[k].usage_penalty * 1000
+            + candidate_models_usage[k].avg_latency * self.LATENCY_WEIGHT,
         )
 
         model_info = model_config.get_model_info(least_used_model_name)
@@ -211,11 +219,11 @@ class _ModelSelector:
             increase (bool): True to increase the penalty, False to decrease it.
         """
         # Fetch the model's current statistics
-        total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+        stats = self.model_usage[model_name]
        # Determine the adjustment direction (increase or decrease)
         adjustment = 1 if increase else -1
         # Update the model's penalty value
-        self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
+        self.model_usage[model_name] = stats._replace(usage_penalty=stats.usage_penalty + adjustment)
 
     async def update_failure_penalty(self, model_name: str, e: Exception):
         """
@@ -223,7 +231,7 @@ class _ModelSelector:
         Critical errors (e.g. network connection or server errors) receive a higher penalty,
         steering the load balancer away from these unreliable models on subsequent selections.
         """
-        total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+        stats = self.model_usage[model_name]
         penalty_increment = self.DEFAULT_PENALTY_INCREMENT
 
         # Apply a heavier penalty for critical errors so problem models leave the candidate pool quickly
@@ -250,7 +258,7 @@ class _ModelSelector:
             # Any other unknown exception gets the base penalty
             logger.warning(f"Model '{model_name}' raised an unknown exception: {type(e).__name__}; adding base penalty: {penalty_increment}")
 
-        self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
+        self.model_usage[model_name] = stats._replace(penalty=stats.penalty + penalty_increment)
 
 
 class _PromptProcessor:
@@ -789,8 +797,11 @@ class LLMRequest:
         """
         self.task_name = request_type
         self.model_for_task = model_set
-        self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
-        """Per-model usage record, (total_tokens, penalty, usage_penalty)"""
+        self.model_usage: dict[str, ModelUsageStats] = {
+            model: ModelUsageStats(total_tokens=0, penalty=0, usage_penalty=0, avg_latency=0.0, request_count=0)
+            for model in self.model_for_task.model_list
+        }
+        """Per-model usage record"""
 
         # Initialize helper classes
         self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
@@ -992,12 +1003,21 @@ class LLMRequest:
             endpoint (str): API endpoint of the request (e.g., "/chat/completions").
         """
         if usage:
-            # Step 1: update the in-memory token counts used for load balancing
-            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-            self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
+            # Step 1: update the in-memory statistics used for load balancing
+            stats = self.model_usage[model_info.name]
+
+            # Compute the new average latency
+            new_request_count = stats.request_count + 1
+            new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count
+
+            self.model_usage[model_info.name] = stats._replace(
+                total_tokens=stats.total_tokens + usage.total_tokens,
+                avg_latency=new_avg_latency,
+                request_count=new_request_count,
+            )
 
             # Step 2: spawn a background task to write the usage data to the database asynchronously
-            asyncio.create_task(
+            asyncio.create_task(  # noqa: RUF006
                 llm_usage_recorder.record_usage_to_database(
                     model_info=model_info,
                     model_usage=usage,
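
For reviewers who want to sanity-check the new weights, below is a minimal, self-contained sketch of the composite score, assuming the `ModelUsageStats` fields introduced in this diff. The `score` helper and the sample numbers are illustrative only, not part of the codebase:

```python
from collections import namedtuple

ModelUsageStats = namedtuple(
    "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
)

LATENCY_WEIGHT = 200  # one second of average latency ~ 200 tokens of cost

def score(stats: ModelUsageStats) -> float:
    """Composite cost used for selection: lower is better."""
    return (
        stats.total_tokens
        + stats.penalty * 300
        + stats.usage_penalty * 1000
        + stats.avg_latency * LATENCY_WEIGHT
    )

usage = {
    "model-a": ModelUsageStats(12_000, 0, 0, 0.8, 40),  # healthy, heavily used
    "model-b": ModelUsageStats(9_000, 2, 0, 0.5, 30),   # cheaper, but failed twice
    "model-c": ModelUsageStats(10_000, 0, 1, 0.4, 35),  # currently serving a request
}

# Scores: a = 12160, b = 9700, c = 11080 -> "model-b" still wins, because its
# two failures (600 token-equivalents) do not offset a 3000-token head start.
print(min(usage, key=lambda name: score(usage[name])))
```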
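
The latency bookkeeping in the last hunk is a standard incremental mean, new_avg = (old_avg * n + sample) / (n + 1), so no per-request history has to be stored. A small sketch under the same assumptions (the `record_latency` helper is hypothetical; `_replace` is how the immutable namedtuple gets "updated"):

```python
from collections import namedtuple

ModelUsageStats = namedtuple(
    "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
)

def record_latency(stats: ModelUsageStats, time_cost: float) -> ModelUsageStats:
    """Fold one observed request latency into the running average."""
    n = stats.request_count
    new_avg = (stats.avg_latency * n + time_cost) / (n + 1)
    # namedtuples are immutable; _replace returns an updated copy.
    return stats._replace(avg_latency=new_avg, request_count=n + 1)

stats = ModelUsageStats(0, 0, 0, 0.0, 0)
for sample in (1.0, 2.0, 6.0):
    stats = record_latency(stats, sample)
print(stats.avg_latency)  # 3.0 == (1.0 + 2.0 + 6.0) / 3
```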
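
On the `# noqa: RUF006` added at the `asyncio.create_task` call: the rule warns that a task referenced by nothing may be garbage-collected before it finishes. Suppressing it is reasonable for fire-and-forget usage recording; if stronger delivery guarantees are wanted later, the pattern recommended by the asyncio documentation is to hold a strong reference until the task completes. A sketch only (the `fake_record_usage` coroutine is a stand-in for the real database writer):

```python
import asyncio

background_tasks: set[asyncio.Task] = set()

async def fake_record_usage() -> None:
    await asyncio.sleep(0)  # stand-in for the async DB write

def spawn_recorder() -> None:
    task = asyncio.create_task(fake_record_usage())
    # Keep a strong reference so the task cannot be garbage-collected,
    # then drop it automatically once the task is done.
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)

async def main() -> None:
    spawn_recorder()
    # Only for the demo: a real fire-and-forget caller would not await here.
    await asyncio.gather(*background_tasks)

asyncio.run(main())
```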