feat(llm): introduce latency as a factor in load balancing

To make model selection smarter, the load-balancing algorithm now takes each model's average response latency into account. Higher-latency models are penalized, so faster-responding models are preferred.

- Replaced the plain tuple used to store model usage statistics with a `namedtuple` (`ModelUsageStats`), improving code readability and maintainability.
- Added an `avg_latency` weight to the model-selection scoring formula, letting the algorithm adapt dynamically to changes in model performance (see the worked example below).
- Updated the `LLMRequest` class to compute and update each model's average latency after every successful request.
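
As a quick intuition for the new weighting, here is a small standalone sketch; the `score` helper and the numbers are illustrative only, not part of the module itself:

def score(total_tokens: int, penalty: int, usage_penalty: int, avg_latency: float) -> float:
    # Same weights as the selection formula in the diff below
    return total_tokens + penalty * 300 + usage_penalty * 1000 + avg_latency * 200

# Model A: fewer tokens used, but slow; model B: more tokens, but fast.
print(score(1000, 0, 0, 4.0))  # 1800.0
print(score(1200, 0, 0, 0.5))  # 1300.0
# B wins: its 200 extra tokens are outweighed by A's 3.5 s of extra latency,
# which costs the equivalent of 700 tokens.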
minecraft1024a
2025-10-07 20:29:09 +08:00
parent 4f9b31d188
commit f9c02520d0


@@ -23,6 +23,7 @@ import random
 import re
 import string
 import time
+from collections import namedtuple
 from collections.abc import Callable, Coroutine
 from enum import Enum
 from typing import Any
@@ -133,21 +134,26 @@ class RequestType(Enum):
 # Helper Classes for LLMRequest Refactoring
 # ==============================================================================
+# Named tuple used to track per-model usage statistics
+ModelUsageStats = namedtuple(  # noqa: PYI024
+    "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
+)
 class _ModelSelector:
     """Strategy for model selection, load balancing, and dynamic failover."""
     CRITICAL_PENALTY_MULTIPLIER = 5  # Penalty multiplier for critical errors
     DEFAULT_PENALTY_INCREMENT = 1  # Default penalty increment
+    LATENCY_WEIGHT = 200  # Latency weight
-    def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
+    def __init__(self, model_list: list[str], model_usage: dict[str, ModelUsageStats]):
         """
         Initialize the model selector.
         Args:
             model_list (List[str]): List of available model names.
-            model_usage (Dict[str, Tuple[int, int, int]]): Initial per-model usage stats,
-                in the form {model_name: (total_tokens, penalty, usage_penalty)}.
+            model_usage (Dict[str, ModelUsageStats]): Initial per-model usage stats.
         """
         self.model_list = model_list
         self.model_usage = model_usage
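
Because a namedtuple is immutable, the selector updates stats via `_replace`, which returns a new tuple with the named fields changed. A minimal illustration (the values are made up):

from collections import namedtuple

ModelUsageStats = namedtuple(
    "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
)

stats = ModelUsageStats(total_tokens=0, penalty=0, usage_penalty=0, avg_latency=0.0, request_count=0)
stats = stats._replace(total_tokens=stats.total_tokens + 512)  # fields are accessed by name, not index
print(stats.total_tokens)  # 512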
@@ -176,16 +182,18 @@ class _ModelSelector:
             return None
         # Core load-balancing algorithm: select the model with the lowest composite score.
-        # Formula: total_tokens + penalty * 300 + usage_penalty * 1000
+        # Formula: total_tokens + penalty * 300 + usage_penalty * 1000 + avg_latency * 200
         # Design rationale:
         # - `total_tokens`: base cost; preferring models with fewer accumulated tokens yields long-term balance.
         # - `penalty * 300`: failure penalty. Each failure raises the penalty, lowering the model's short-term selection probability. A weight of 300 means one failure costs roughly 300 tokens.
         # - `usage_penalty * 1000`: short-term usage penalty. Incremented when the model is selected, decremented on completion. The high weight keeps requests evenly distributed (round-robin) while all models are healthy.
+        # - `avg_latency * 200`: latency penalty. Prefers models with a faster average response time. A weight of 200 means one second of latency costs roughly 200 tokens.
         least_used_model_name = min(
             candidate_models_usage,
-            key=lambda k: candidate_models_usage[k][0]
-            + candidate_models_usage[k][1] * 300
-            + candidate_models_usage[k][2] * 1000,
+            key=lambda k: candidate_models_usage[k].total_tokens
+            + candidate_models_usage[k].penalty * 300
+            + candidate_models_usage[k].usage_penalty * 1000
+            + candidate_models_usage[k].avg_latency * self.LATENCY_WEIGHT,
         )
         model_info = model_config.get_model_info(least_used_model_name)
@@ -211,11 +219,11 @@ class _ModelSelector:
             increase (bool): True to increase the penalty, False to decrease it.
         """
         # Fetch the current model's statistics
-        total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+        stats = self.model_usage[model_name]
         # Determine the adjustment depending on whether this is an increase or a decrease
         adjustment = 1 if increase else -1
         # Update the model's penalty value
-        self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
+        self.model_usage[model_name] = stats._replace(usage_penalty=stats.usage_penalty + adjustment)
     async def update_failure_penalty(self, model_name: str, e: Exception):
         """
@@ -223,7 +231,7 @@ class _ModelSelector:
         Critical errors (e.g. network connection or server errors) receive a higher penalty,
         prompting the load-balancing algorithm to avoid these unreliable models in subsequent selections.
         """
-        total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+        stats = self.model_usage[model_name]
         penalty_increment = self.DEFAULT_PENALTY_INCREMENT
         # Apply a heavier penalty for critical errors so problematic models are quickly moved out of the candidate pool
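
The classification branches themselves fall in the elided part of the diff; as a rough sketch of the escalation idea (which exception types count as critical is an assumption here):

def penalty_increment_for(e: Exception) -> int:
    # Assumed classification: connection/timeout failures are treated as critical
    base = 1  # DEFAULT_PENALTY_INCREMENT
    if isinstance(e, (ConnectionError, TimeoutError)):
        return base * 5  # CRITICAL_PENALTY_MULTIPLIER: evict unreliable models quickly
    return base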
@@ -250,7 +258,7 @@ class _ModelSelector:
             # Other unknown exceptions get the base penalty
             logger.warning(f"Model '{model_name}' raised an unknown exception: {type(e).__name__}, adding base penalty: {penalty_increment}")
-        self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
+        self.model_usage[model_name] = stats._replace(penalty=stats.penalty + penalty_increment)
 class _PromptProcessor:
@@ -789,8 +797,11 @@ class LLMRequest:
         """
         self.task_name = request_type
         self.model_for_task = model_set
-        self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
-        """Per-model usage record, (total_tokens, penalty, usage_penalty)"""
+        self.model_usage: dict[str, ModelUsageStats] = {
+            model: ModelUsageStats(total_tokens=0, penalty=0, usage_penalty=0, avg_latency=0.0, request_count=0)
+            for model in self.model_for_task.model_list
+        }
+        """Per-model usage record"""
         # Initialize helper classes
         self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
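
The next hunk maintains `avg_latency` as a running mean, so no individual latency samples need to be stored; the update rule is new_avg = (old_avg * n + sample) / (n + 1). A standalone sketch of the same arithmetic:

def update_avg(avg: float, count: int, sample: float) -> tuple[float, int]:
    # Same update rule as in the hunk below
    return (avg * count + sample) / (count + 1), count + 1

avg, n = 0.0, 0
for latency in (1.2, 0.8, 1.0):
    avg, n = update_avg(avg, n, latency)
print(avg)  # 1.0 == (1.2 + 0.8 + 1.0) / 3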
@@ -992,12 +1003,21 @@ class LLMRequest:
             endpoint (str): The API endpoint of the request (e.g., "/chat/completions").
         """
         if usage:
-            # Step 1: update the in-memory token count used for load balancing
-            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-            self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
+            # Step 1: update the in-memory counters used for load balancing
+            stats = self.model_usage[model_info.name]
+            # Compute the new average latency
+            new_request_count = stats.request_count + 1
+            new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count
+            self.model_usage[model_info.name] = stats._replace(
+                total_tokens=stats.total_tokens + usage.total_tokens,
+                avg_latency=new_avg_latency,
+                request_count=new_request_count,
+            )
             # Step 2: spawn a background task to asynchronously write the usage data to the database
-            asyncio.create_task(
+            asyncio.create_task(  # noqa: RUF006
                 llm_usage_recorder.record_usage_to_database(
                     model_info=model_info,
                     model_usage=usage,