diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index b24b1843c..e3699d540 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,3 +1,29 @@ +# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- +""" +@software: +@file: utils_model.py +@time: 2024/7/28 上午1:09 +@author: Mai ū +@contact: 2496955113@qq.com +@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。 +它被设计为一个高度容错和可扩展的系统,包含以下主要组件: + +- **模型选择器 (_ModelSelector)**: + 实现了基于负载均衡和失败惩罚的动态模型选择策略,确保在高并发或部分模型失效时系统的稳定性。 + +- **提示处理器 (_PromptProcessor)**: + 负责对输入模型的提示词进行预处理(如内容混淆、反截断指令注入)和对模型输出进行后处理(如提取思考过程、检查截断)。 + +- **请求执行器 (_RequestExecutor)**: + 封装了底层的API请求逻辑,包括自动重试、异常分类处理和消息体压缩等功能。 + +- **请求策略 (_RequestStrategy)**: + 实现了高阶请求策略,如模型间的故障转移(Failover),确保单个模型的失败不会导致整个请求失败。 + +- **LLMRequest (主接口)**: + 作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。 +""" import re import asyncio import time @@ -102,10 +128,18 @@ class RequestType(Enum): class _ModelSelector: """负责模型选择、负载均衡和动态故障切换的策略。""" - CRITICAL_PENALTY_MULTIPLIER = 5 - DEFAULT_PENALTY_INCREMENT = 1 + CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数 + DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量 def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]): + """ + 初始化模型选择器。 + + Args: + model_list (List[str]): 可用模型名称列表。 + model_usage (Dict[str, Tuple[int, int, int]]): 模型的初始使用情况, + 格式为 {model_name: (total_tokens, penalty, usage_penalty)}。 + """ self.model_list = model_list self.model_usage = model_usage @@ -132,8 +166,12 @@ class _ModelSelector: logger.warning("没有可用的模型供当前请求选择。") return None - # 根据公式查找分数最低的模型,该公式综合了总token数、模型失败惩罚值和使用频率惩罚值。 + # 核心负载均衡算法:选择一个综合得分最低的模型。 # 公式: total_tokens + penalty * 300 + usage_penalty * 1000 + # 设计思路: + # - `total_tokens`: 基础成本,优先使用累计token少的模型,实现长期均衡。 + # - `penalty * 300`: 失败惩罚项。每次失败会增加penalty,使其在短期内被选中的概率降低。权重300意味着一次失败大致相当于300个token的成本。 + # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。 least_used_model_name = min( candidate_models_usage, key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000, @@ -141,7 +179,8 @@ class _ModelSelector: model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) - # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。 + # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。 force_new_client = request_type == "embedding" client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) @@ -165,16 +204,22 @@ class _ModelSelector: total_tokens, penalty, usage_penalty = self.model_usage[model_name] penalty_increment = self.DEFAULT_PENALTY_INCREMENT + # 对严重错误施加更高的惩罚,以便快速将问题模型移出候选池 if isinstance(e, (NetworkConnectionError, ReqAbortException)): + # 网络连接错误或请求被中断,通常是基础设施问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}") elif isinstance(e, RespNotOkException): + # 对于HTTP响应错误,重点关注服务器端错误 if e.status_code >= 500: + # 5xx 错误表明服务器端出现问题,应重罚 penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}") else: + # 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚 logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}") else: + # 其他未知异常,给予基础惩罚 logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}") self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty) @@ -184,7 +229,11 @@ class _PromptProcessor: """封装所有与提示词和响应内容的预处理和后处理逻辑。""" def __init__(self): - # 内容混淆过滤指令 + """ + 初始化提示处理器。 + 这里定义了用于内容混淆和反截断的指令文本。 + """ + # 内容混淆过滤指令: 用于在API传输中增加噪音,防止内容被审查或修改 self.noise_instruction = """**【核心过滤规则】** 在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。 @@ -291,6 +340,13 @@ class _RequestExecutor: """负责执行实际的API请求,包含重试逻辑和底层异常处理。""" def __init__(self, model_selector: _ModelSelector, task_name: str): + """ + 初始化请求执行器。 + + Args: + model_selector (_ModelSelector): 模型选择器实例,用于在请求失败时更新惩罚。 + task_name (str): 当前任务的名称,用于日志记录。 + """ self.model_selector = model_selector self.task_name = task_name @@ -302,19 +358,37 @@ class _RequestExecutor: model_info: ModelInfo, **kwargs, ) -> APIResponse: - """实际执行请求的方法,包含了重试和异常处理逻辑。""" + """ + 实际执行请求的方法,包含了重试和异常处理逻辑。 + + Args: + api_provider (APIProvider): API提供商配置。 + client (BaseClient): 用于发送请求的客户端实例。 + request_type (RequestType): 请求的类型 (e.g., RESPONSE, EMBEDDING)。 + model_info (ModelInfo): 正在使用的模型的信息。 + **kwargs: 传递给客户端方法的具体参数。 + + Returns: + APIResponse: 来自API的成功响应。 + + Raises: + Exception: 如果重试后请求仍然失败,则抛出最终的异常。 + RuntimeError: 如果达到最大重试次数。 + """ retry_remain = api_provider.max_retry compressed_messages: Optional[List[Message]] = None while retry_remain > 0: try: + # 优先使用压缩后的消息列表 message_list = kwargs.get("message_list") current_messages = compressed_messages or message_list + # 根据请求类型调用不同的客户端方法 if request_type == RequestType.RESPONSE: assert current_messages is not None, "message_list cannot be None for response requests" - # 修复: 防止 'message_list' 在 kwargs 中重复 + # 修复: 防止 'message_list' 在 kwargs 中重复传递 request_params = kwargs.copy() request_params.pop("message_list", None) @@ -328,18 +402,20 @@ class _RequestExecutor: except Exception as e: logger.debug(f"请求失败: {str(e)}") + # 记录失败并更新模型的惩罚值 self.model_selector.update_failure_penalty(model_info.name, e) + # 处理异常,决定是否重试以及等待多久 wait_interval, new_compressed_messages = self._handle_exception( e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None) ) if new_compressed_messages: - compressed_messages = new_compressed_messages + compressed_messages = new_compressed_messages # 更新为压缩后的消息 if wait_interval == -1: - raise e # 如果不再重试,则传播异常 + raise e # 如果决定不再重试,则传播异常 elif wait_interval > 0: - await asyncio.sleep(wait_interval) + await asyncio.sleep(wait_interval) # 等待指定时间后重试 finally: retry_remain -= 1