feat(llm): 在负载均衡中引入延迟作为考量因素

为了更智能地选择模型,负载均衡算法现在会考虑模型的平均响应延迟。延迟较高的模型将受到惩罚,从而优先选择响应更快的模型。

- 使用 `namedtuple` (`ModelUsageStats`) 替代了原有的元组来存储模型使用统计信息,提高了代码的可读性和可维护性。
- 在模型选择的评分公式中增加了 `avg_latency` 权重,使算法能够动态适应模型的性能变化。
- 更新了 `LLMRequest` 类,以在每次成功请求后计算并更新模型的平均延迟。
This commit is contained in:
minecraft1024a
2025-10-07 20:29:09 +08:00
parent 4f9b31d188
commit f9c02520d0

View File

@@ -23,6 +23,7 @@ import random
import re import re
import string import string
import time import time
from collections import namedtuple
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from enum import Enum from enum import Enum
from typing import Any from typing import Any
@@ -133,21 +134,26 @@ class RequestType(Enum):
# Helper Classes for LLMRequest Refactoring # Helper Classes for LLMRequest Refactoring
# ============================================================================== # ==============================================================================
# 定义用于跟踪模型使用情况的具名元组
ModelUsageStats = namedtuple( # noqa: PYI024
"ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
)
class _ModelSelector: class _ModelSelector:
"""负责模型选择、负载均衡和动态故障切换的策略。""" """负责模型选择、负载均衡和动态故障切换的策略。"""
CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数
DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量
LATENCY_WEIGHT = 200 # 延迟权重
def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]): def __init__(self, model_list: list[str], model_usage: dict[str, ModelUsageStats]):
""" """
初始化模型选择器。 初始化模型选择器。
Args: Args:
model_list (List[str]): 可用模型名称列表。 model_list (List[str]): 可用模型名称列表。
model_usage (Dict[str, Tuple[int, int, int]]): 模型的初始使用情况 model_usage (Dict[str, ModelUsageStats]): 模型的初始使用情况
格式为 {model_name: (total_tokens, penalty, usage_penalty)}。
""" """
self.model_list = model_list self.model_list = model_list
self.model_usage = model_usage self.model_usage = model_usage
@@ -176,16 +182,18 @@ class _ModelSelector:
return None return None
# 核心负载均衡算法:选择一个综合得分最低的模型。 # 核心负载均衡算法:选择一个综合得分最低的模型。
# 公式: total_tokens + penalty * 300 + usage_penalty * 1000 # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + avg_latency * 200
# 设计思路: # 设计思路:
# - `total_tokens`: 基础成本优先使用累计token少的模型实现长期均衡。 # - `total_tokens`: 基础成本优先使用累计token少的模型实现长期均衡。
# - `penalty * 300`: 失败惩罚项。每次失败会增加penalty使其在短期内被选中的概率降低。权重300意味着一次失败大致相当于300个token的成本。 # - `penalty * 300`: 失败惩罚项。每次失败会增加penalty使其在短期内被选中的概率降低。权重300意味着一次失败大致相当于300个token的成本。
# - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。 # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。
# - `avg_latency * 200`: 延迟惩罚项。优先选择平均响应时间更快的模型。权重200意味着1秒的延迟约等于200个token的成本。
least_used_model_name = min( least_used_model_name = min(
candidate_models_usage, candidate_models_usage,
key=lambda k: candidate_models_usage[k][0] key=lambda k: candidate_models_usage[k].total_tokens
+ candidate_models_usage[k][1] * 300 + candidate_models_usage[k].penalty * 300
+ candidate_models_usage[k][2] * 1000, + candidate_models_usage[k].usage_penalty * 1000
+ candidate_models_usage[k].avg_latency * self.LATENCY_WEIGHT,
) )
model_info = model_config.get_model_info(least_used_model_name) model_info = model_config.get_model_info(least_used_model_name)
@@ -211,11 +219,11 @@ class _ModelSelector:
increase (bool): True表示增加惩罚值False表示减少。 increase (bool): True表示增加惩罚值False表示减少。
""" """
# 获取当前模型的统计数据 # 获取当前模型的统计数据
total_tokens, penalty, usage_penalty = self.model_usage[model_name] stats = self.model_usage[model_name]
# 根据操作是增加还是减少来确定调整量 # 根据操作是增加还是减少来确定调整量
adjustment = 1 if increase else -1 adjustment = 1 if increase else -1
# 更新模型的惩罚值 # 更新模型的惩罚值
self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment) self.model_usage[model_name] = stats._replace(usage_penalty=stats.usage_penalty + adjustment)
async def update_failure_penalty(self, model_name: str, e: Exception): async def update_failure_penalty(self, model_name: str, e: Exception):
""" """
@@ -223,7 +231,7 @@ class _ModelSelector:
关键错误(如网络连接、服务器错误)会获得更高的惩罚, 关键错误(如网络连接、服务器错误)会获得更高的惩罚,
促使负载均衡算法在下次选择时优先规避这些不可靠的模型。 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。
""" """
total_tokens, penalty, usage_penalty = self.model_usage[model_name] stats = self.model_usage[model_name]
penalty_increment = self.DEFAULT_PENALTY_INCREMENT penalty_increment = self.DEFAULT_PENALTY_INCREMENT
# 对严重错误施加更高的惩罚,以便快速将问题模型移出候选池 # 对严重错误施加更高的惩罚,以便快速将问题模型移出候选池
@@ -250,7 +258,7 @@ class _ModelSelector:
# 其他未知异常,给予基础惩罚 # 其他未知异常,给予基础惩罚
logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")
self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) self.model_usage[model_name] = stats._replace(penalty=stats.penalty + penalty_increment)
class _PromptProcessor: class _PromptProcessor:
@@ -789,8 +797,11 @@ class LLMRequest:
""" """
self.task_name = request_type self.task_name = request_type
self.model_for_task = model_set self.model_for_task = model_set
self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0)) self.model_usage: dict[str, ModelUsageStats] = {
"""模型使用量记录,(total_tokens, penalty, usage_penalty)""" model: ModelUsageStats(total_tokens=0, penalty=0, usage_penalty=0, avg_latency=0.0, request_count=0)
for model in self.model_for_task.model_list
}
"""模型使用量记录"""
# 初始化辅助类 # 初始化辅助类
self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage) self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
@@ -992,12 +1003,21 @@ class LLMRequest:
endpoint (str): 请求的API端点 (e.g., "/chat/completions")。 endpoint (str): 请求的API端点 (e.g., "/chat/completions")。
""" """
if usage: if usage:
# 步骤1: 更新内存中的token计数,用于负载均衡 # 步骤1: 更新内存中的计数,用于负载均衡
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] stats = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
# 计算新的平均延迟
new_request_count = stats.request_count + 1
new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count
self.model_usage[model_info.name] = stats._replace(
total_tokens=stats.total_tokens + usage.total_tokens,
avg_latency=new_avg_latency,
request_count=new_request_count,
)
# 步骤2: 创建一个后台任务,将用量数据异步写入数据库 # 步骤2: 创建一个后台任务,将用量数据异步写入数据库
asyncio.create_task( asyncio.create_task( # noqa: RUF006
llm_usage_recorder.record_usage_to_database( llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_info=model_info,
model_usage=usage, model_usage=usage,