feat(llm): factor latency into load-balancing decisions
To make model selection smarter, the load-balancing algorithm now takes each model's average response latency into account. Higher-latency models are penalized, so faster-responding models are preferred.

- Replaced the bare tuple used to store model usage statistics with a `namedtuple` (`ModelUsageStats`), improving readability and maintainability.
- Added an `avg_latency` term with its own weight to the model-selection scoring formula, letting the algorithm adapt dynamically to changes in model performance.
- Updated the `LLMRequest` class to compute and update each model's average latency after every successful request.
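For a concrete feel of the new scoring rule, here is a minimal standalone sketch. `ModelUsageStats` and the selection expression mirror the diff below; the weight-constant names (other than `LATENCY_WEIGHT`) and all sample numbers are invented for illustration:

    from collections import namedtuple

    ModelUsageStats = namedtuple(
        "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
    )

    PENALTY_WEIGHT = 300         # hypothetical name; the diff hardcodes 300 (one failure ~ 300 tokens)
    USAGE_PENALTY_WEIGHT = 1000  # hypothetical name; the diff hardcodes 1000 (in-flight requests dominate)
    LATENCY_WEIGHT = 200         # from the diff: one second of average latency ~ 200 tokens

    def score(stats: ModelUsageStats) -> float:
        # Lower is better: accumulated tokens + failure penalty + short-term usage penalty + latency penalty
        return (
            stats.total_tokens
            + stats.penalty * PENALTY_WEIGHT
            + stats.usage_penalty * USAGE_PENALTY_WEIGHT
            + stats.avg_latency * LATENCY_WEIGHT
        )

    candidate_models_usage = {
        "model-a": ModelUsageStats(5000, 0, 0, 0.8, 10),  # more tokens used, but fast and healthy
        "model-b": ModelUsageStats(4800, 1, 0, 2.5, 12),  # cheaper so far, but slower and failed once
    }
    best = min(candidate_models_usage, key=lambda k: score(candidate_models_usage[k]))
    print(best)  # "model-a": 5000 + 0 + 0 + 160 = 5160 beats 4800 + 300 + 0 + 500 = 5600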
@@ -23,6 +23,7 @@ import random
 import re
 import string
 import time
+from collections import namedtuple
 from collections.abc import Callable, Coroutine
 from enum import Enum
 from typing import Any
@@ -133,21 +134,26 @@ class RequestType(Enum):
 # Helper Classes for LLMRequest Refactoring
 # ==============================================================================
 
+# Named tuple for tracking per-model usage statistics
+ModelUsageStats = namedtuple(  # noqa: PYI024
+    "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
+)
+
 
 class _ModelSelector:
     """Strategy for model selection, load balancing, and dynamic failover."""
 
     CRITICAL_PENALTY_MULTIPLIER = 5  # Penalty multiplier for critical errors
     DEFAULT_PENALTY_INCREMENT = 1  # Default penalty increment
+    LATENCY_WEIGHT = 200  # Weight applied to average latency in the score
 
-    def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
+    def __init__(self, model_list: list[str], model_usage: dict[str, ModelUsageStats]):
         """
         Initialize the model selector.
 
         Args:
             model_list (List[str]): List of available model names.
-            model_usage (Dict[str, Tuple[int, int, int]]): Initial usage stats for each model,
-                in the form {model_name: (total_tokens, penalty, usage_penalty)}.
+            model_usage (Dict[str, ModelUsageStats]): Initial usage stats for each model.
         """
         self.model_list = model_list
         self.model_usage = model_usage
@@ -176,16 +182,18 @@ class _ModelSelector:
             return None
 
         # Core load-balancing algorithm: pick the model with the lowest combined score.
-        # Formula: total_tokens + penalty * 300 + usage_penalty * 1000
+        # Formula: total_tokens + penalty * 300 + usage_penalty * 1000 + avg_latency * 200
         # Rationale:
         # - `total_tokens`: base cost; preferring models with fewer accumulated tokens yields long-term balance.
         # - `penalty * 300`: failure penalty. Each failure raises `penalty`, lowering the model's short-term selection odds; a weight of 300 prices one failure at roughly 300 tokens.
         # - `usage_penalty * 1000`: short-term usage penalty. Raised when a model is selected, lowered on completion; the high weight spreads requests evenly (round-robin) while all models are healthy.
+        # - `avg_latency * 200`: latency penalty. Prefers models with faster average responses; a weight of 200 prices one second of latency at roughly 200 tokens.
         least_used_model_name = min(
             candidate_models_usage,
-            key=lambda k: candidate_models_usage[k][0]
-            + candidate_models_usage[k][1] * 300
-            + candidate_models_usage[k][2] * 1000,
+            key=lambda k: candidate_models_usage[k].total_tokens
+            + candidate_models_usage[k].penalty * 300
+            + candidate_models_usage[k].usage_penalty * 1000
+            + candidate_models_usage[k].avg_latency * self.LATENCY_WEIGHT,
         )
 
         model_info = model_config.get_model_info(least_used_model_name)
@@ -211,11 +219,11 @@ class _ModelSelector:
             increase (bool): True to raise the penalty, False to lower it.
         """
         # Fetch the model's current statistics
-        total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+        stats = self.model_usage[model_name]
         # Determine the adjustment depending on whether we are increasing or decreasing
         adjustment = 1 if increase else -1
         # Update the model's penalty value
-        self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
+        self.model_usage[model_name] = stats._replace(usage_penalty=stats.usage_penalty + adjustment)
 
     async def update_failure_penalty(self, model_name: str, e: Exception):
         """
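A note on the `_replace` calls introduced in this and the following hunks: `_replace` is the standard `collections.namedtuple` method that returns a new tuple with the named fields swapped, which is why each call reassigns the dict entry. A minimal sketch with made-up field values:

    from collections import namedtuple

    ModelUsageStats = namedtuple(
        "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
    )

    stats = ModelUsageStats(total_tokens=100, penalty=0, usage_penalty=0, avg_latency=1.2, request_count=4)
    updated = stats._replace(usage_penalty=stats.usage_penalty + 1)
    print(updated.usage_penalty)  # 1
    print(stats.usage_penalty)    # 0 -- the original is untouched; namedtuples are immutable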
@@ -223,7 +231,7 @@ class _ModelSelector:
         Critical errors (such as network or server failures) receive a higher penalty,
         pushing the load balancer to steer around these unreliable models on the next selection.
         """
-        total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+        stats = self.model_usage[model_name]
         penalty_increment = self.DEFAULT_PENALTY_INCREMENT
 
         # Apply a heavier penalty for critical errors so problem models leave the candidate pool quickly
@@ -250,7 +258,7 @@ class _ModelSelector:
             # Any other unknown exception: apply the base penalty
             logger.warning(f"Model '{model_name}' raised an unknown exception: {type(e).__name__}; adding base penalty increment: {penalty_increment}")
 
-        self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
+        self.model_usage[model_name] = stats._replace(penalty=stats.penalty + penalty_increment)
 
 
 class _PromptProcessor:
@@ -789,8 +797,11 @@ class LLMRequest:
         """
         self.task_name = request_type
         self.model_for_task = model_set
-        self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
-        """Model usage record, (total_tokens, penalty, usage_penalty)"""
+        self.model_usage: dict[str, ModelUsageStats] = {
+            model: ModelUsageStats(total_tokens=0, penalty=0, usage_penalty=0, avg_latency=0.0, request_count=0)
+            for model in self.model_for_task.model_list
+        }
+        """Model usage record"""
 
         # Initialize helper classes
         self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
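A side note on the changed initialization: `dict.fromkeys` shares a single value object across all keys, which was harmless here because tuples are immutable; the dict comprehension now spells out each zeroed field by name instead. Since `ModelUsageStats` subclasses `tuple`, both forms behave identically, and positional access keeps working during the migration. A quick demonstration (model names are made up):

    from collections import namedtuple

    ModelUsageStats = namedtuple(
        "ModelUsageStats", ["total_tokens", "penalty", "usage_penalty", "avg_latency", "request_count"]
    )

    zeroed = ModelUsageStats(total_tokens=0, penalty=0, usage_penalty=0, avg_latency=0.0, request_count=0)
    shared = dict.fromkeys(["model-a", "model-b"], zeroed)  # one shared instance -- safe, since it is immutable
    explicit = {m: ModelUsageStats(0, 0, 0, 0.0, 0) for m in ["model-a", "model-b"]}  # one instance per model
    assert shared == explicit
    assert isinstance(zeroed, tuple) and zeroed[0] == zeroed.total_tokens  # positional access still works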
@@ -992,12 +1003,21 @@ class LLMRequest:
             endpoint (str): The API endpoint of the request (e.g., "/chat/completions").
         """
         if usage:
-            # Step 1: update the in-memory token count used for load balancing
-            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-            self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
+            # Step 1: update the in-memory statistics used for load balancing
+            stats = self.model_usage[model_info.name]
+
+            # Compute the new average latency
+            new_request_count = stats.request_count + 1
+            new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count
+
+            self.model_usage[model_info.name] = stats._replace(
+                total_tokens=stats.total_tokens + usage.total_tokens,
+                avg_latency=new_avg_latency,
+                request_count=new_request_count,
+            )
 
             # Step 2: spawn a background task to write the usage data to the database asynchronously
-            asyncio.create_task(
+            asyncio.create_task(  # noqa: RUF006
                 llm_usage_recorder.record_usage_to_database(
                     model_info=model_info,
                     model_usage=usage,
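The latency update above keeps a running average without storing individual samples: multiplying the old average by the old request count recovers the latency sum, the new `time_cost` is folded in, and a single division yields the new average. A standalone check with made-up samples:

    def update_avg_latency(avg_latency: float, request_count: int, time_cost: float) -> tuple[float, int]:
        # Recover the running sum, fold in the new sample, divide by the new count
        new_request_count = request_count + 1
        new_avg_latency = (avg_latency * request_count + time_cost) / new_request_count
        return new_avg_latency, new_request_count

    avg, count = 0.0, 0
    for sample in (1.0, 2.0, 3.0):
        avg, count = update_avg_latency(avg, count, sample)
    print(avg, count)  # 2.0 3 -- matches mean([1.0, 2.0, 3.0])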