docs(llm): 为 LLM 工具模块添加全面的文档和注释
为 `utils_model.py` 模块及其核心类(`_ModelSelector`、`_PromptProcessor`、`_RequestExecutor`)添加了详细的文档字符串。 同时,增加了大量的行内注释,以阐明复杂的逻辑,例如: - 模型选择的负载均衡算法 - 针对不同错误的失败惩罚计算 - 对嵌入任务的特殊客户端处理 此举旨在提高 LLM 交互核心逻辑的可读性和可维护性。
This commit is contained in:
@@ -1,3 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@software:
|
||||
@file: utils_model.py
|
||||
@time: 2024/7/28 上午1:09
|
||||
@author: Mai ū
|
||||
@contact: 2496955113@qq.com
|
||||
@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。
|
||||
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
|
||||
|
||||
- **模型选择器 (_ModelSelector)**:
|
||||
实现了基于负载均衡和失败惩罚的动态模型选择策略,确保在高并发或部分模型失效时系统的稳定性。
|
||||
|
||||
- **提示处理器 (_PromptProcessor)**:
|
||||
负责对输入模型的提示词进行预处理(如内容混淆、反截断指令注入)和对模型输出进行后处理(如提取思考过程、检查截断)。
|
||||
|
||||
- **请求执行器 (_RequestExecutor)**:
|
||||
封装了底层的API请求逻辑,包括自动重试、异常分类处理和消息体压缩等功能。
|
||||
|
||||
- **请求策略 (_RequestStrategy)**:
|
||||
实现了高阶请求策略,如模型间的故障转移(Failover),确保单个模型的失败不会导致整个请求失败。
|
||||
|
||||
- **LLMRequest (主接口)**:
|
||||
作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。
|
||||
"""
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
@@ -102,10 +128,18 @@ class RequestType(Enum):
|
||||
class _ModelSelector:
|
||||
"""负责模型选择、负载均衡和动态故障切换的策略。"""
|
||||
|
||||
CRITICAL_PENALTY_MULTIPLIER = 5
|
||||
DEFAULT_PENALTY_INCREMENT = 1
|
||||
CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数
|
||||
DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量
|
||||
|
||||
def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]):
|
||||
"""
|
||||
初始化模型选择器。
|
||||
|
||||
Args:
|
||||
model_list (List[str]): 可用模型名称列表。
|
||||
model_usage (Dict[str, Tuple[int, int, int]]): 模型的初始使用情况,
|
||||
格式为 {model_name: (total_tokens, penalty, usage_penalty)}。
|
||||
"""
|
||||
self.model_list = model_list
|
||||
self.model_usage = model_usage
|
||||
|
||||
@@ -132,8 +166,12 @@ class _ModelSelector:
|
||||
logger.warning("没有可用的模型供当前请求选择。")
|
||||
return None
|
||||
|
||||
# 根据公式查找分数最低的模型,该公式综合了总token数、模型失败惩罚值和使用频率惩罚值。
|
||||
# 核心负载均衡算法:选择一个综合得分最低的模型。
|
||||
# 公式: total_tokens + penalty * 300 + usage_penalty * 1000
|
||||
# 设计思路:
|
||||
# - `total_tokens`: 基础成本,优先使用累计token少的模型,实现长期均衡。
|
||||
# - `penalty * 300`: 失败惩罚项。每次失败会增加penalty,使其在短期内被选中的概率降低。权重300意味着一次失败大致相当于300个token的成本。
|
||||
# - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。
|
||||
least_used_model_name = min(
|
||||
candidate_models_usage,
|
||||
key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
|
||||
@@ -141,7 +179,8 @@ class _ModelSelector:
|
||||
|
||||
model_info = model_config.get_model_info(least_used_model_name)
|
||||
api_provider = model_config.get_provider(model_info.api_provider)
|
||||
# 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
|
||||
# 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。
|
||||
# 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。
|
||||
force_new_client = request_type == "embedding"
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||
|
||||
@@ -165,16 +204,22 @@ class _ModelSelector:
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_name]
|
||||
penalty_increment = self.DEFAULT_PENALTY_INCREMENT
|
||||
|
||||
# 对严重错误施加更高的惩罚,以便快速将问题模型移出候选池
|
||||
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
|
||||
# 网络连接错误或请求被中断,通常是基础设施问题,应重罚
|
||||
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
|
||||
logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}")
|
||||
elif isinstance(e, RespNotOkException):
|
||||
# 对于HTTP响应错误,重点关注服务器端错误
|
||||
if e.status_code >= 500:
|
||||
# 5xx 错误表明服务器端出现问题,应重罚
|
||||
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
|
||||
logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}")
|
||||
else:
|
||||
# 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚
|
||||
logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
|
||||
else:
|
||||
# 其他未知异常,给予基础惩罚
|
||||
logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")
|
||||
|
||||
self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
|
||||
@@ -184,7 +229,11 @@ class _PromptProcessor:
|
||||
"""封装所有与提示词和响应内容的预处理和后处理逻辑。"""
|
||||
|
||||
def __init__(self):
|
||||
# 内容混淆过滤指令
|
||||
"""
|
||||
初始化提示处理器。
|
||||
这里定义了用于内容混淆和反截断的指令文本。
|
||||
"""
|
||||
# 内容混淆过滤指令: 用于在API传输中增加噪音,防止内容被审查或修改
|
||||
self.noise_instruction = """**【核心过滤规则】**
|
||||
在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。
|
||||
|
||||
@@ -291,6 +340,13 @@ class _RequestExecutor:
|
||||
"""负责执行实际的API请求,包含重试逻辑和底层异常处理。"""
|
||||
|
||||
def __init__(self, model_selector: _ModelSelector, task_name: str):
|
||||
"""
|
||||
初始化请求执行器。
|
||||
|
||||
Args:
|
||||
model_selector (_ModelSelector): 模型选择器实例,用于在请求失败时更新惩罚。
|
||||
task_name (str): 当前任务的名称,用于日志记录。
|
||||
"""
|
||||
self.model_selector = model_selector
|
||||
self.task_name = task_name
|
||||
|
||||
@@ -302,19 +358,37 @@ class _RequestExecutor:
|
||||
model_info: ModelInfo,
|
||||
**kwargs,
|
||||
) -> APIResponse:
|
||||
"""实际执行请求的方法,包含了重试和异常处理逻辑。"""
|
||||
"""
|
||||
实际执行请求的方法,包含了重试和异常处理逻辑。
|
||||
|
||||
Args:
|
||||
api_provider (APIProvider): API提供商配置。
|
||||
client (BaseClient): 用于发送请求的客户端实例。
|
||||
request_type (RequestType): 请求的类型 (e.g., RESPONSE, EMBEDDING)。
|
||||
model_info (ModelInfo): 正在使用的模型的信息。
|
||||
**kwargs: 传递给客户端方法的具体参数。
|
||||
|
||||
Returns:
|
||||
APIResponse: 来自API的成功响应。
|
||||
|
||||
Raises:
|
||||
Exception: 如果重试后请求仍然失败,则抛出最终的异常。
|
||||
RuntimeError: 如果达到最大重试次数。
|
||||
"""
|
||||
retry_remain = api_provider.max_retry
|
||||
compressed_messages: Optional[List[Message]] = None
|
||||
|
||||
while retry_remain > 0:
|
||||
try:
|
||||
# 优先使用压缩后的消息列表
|
||||
message_list = kwargs.get("message_list")
|
||||
current_messages = compressed_messages or message_list
|
||||
|
||||
# 根据请求类型调用不同的客户端方法
|
||||
if request_type == RequestType.RESPONSE:
|
||||
assert current_messages is not None, "message_list cannot be None for response requests"
|
||||
|
||||
# 修复: 防止 'message_list' 在 kwargs 中重复
|
||||
# 修复: 防止 'message_list' 在 kwargs 中重复传递
|
||||
request_params = kwargs.copy()
|
||||
request_params.pop("message_list", None)
|
||||
|
||||
@@ -328,18 +402,20 @@ class _RequestExecutor:
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"请求失败: {str(e)}")
|
||||
# 记录失败并更新模型的惩罚值
|
||||
self.model_selector.update_failure_penalty(model_info.name, e)
|
||||
|
||||
# 处理异常,决定是否重试以及等待多久
|
||||
wait_interval, new_compressed_messages = self._handle_exception(
|
||||
e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None)
|
||||
)
|
||||
if new_compressed_messages:
|
||||
compressed_messages = new_compressed_messages
|
||||
compressed_messages = new_compressed_messages # 更新为压缩后的消息
|
||||
|
||||
if wait_interval == -1:
|
||||
raise e # 如果不再重试,则传播异常
|
||||
raise e # 如果决定不再重试,则传播异常
|
||||
elif wait_interval > 0:
|
||||
await asyncio.sleep(wait_interval)
|
||||
await asyncio.sleep(wait_interval) # 等待指定时间后重试
|
||||
finally:
|
||||
retry_remain -= 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user