From 6ed9349933df2bfa6fc51e0d1680dc78a37c3e9f Mon Sep 17 00:00:00 2001
From: minecraft1024a
Date: Wed, 24 Sep 2025 15:00:39 +0800
Subject: [PATCH] refactor(llm): unify LLM request handling behind a generic
 failover executor
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Previously, the request methods for text, images, and voice each carried
their own copy of the failover logic, which meant heavy duplication and
code that was hard to follow.

This change cleans that up:

- Introduce a generic `_execute_with_failover` executor that centralizes
  all of the "model failed, try the next one" handling.
- Refactor the affected request methods (text, image, voice, embedding)
  so that each focuses only on its core task.
- Upgrade `_model_scheduler` to order models by real-time load, so the
  least-busy model is tried first. The old `_select_model` is removed.

The result is more maintainable and more robust, and new request types
are easier to add.
---
 src/llm_models/utils_model.py | 588 +++++++++++++++++++--------------
 1 file changed, 329 insertions(+), 259 deletions(-)

diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index 146e5eb46..10312f27d 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -159,7 +159,7 @@ class LLMRequest:
         max_tokens: Optional[int] = None,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
-        Generate a response for an image
+        Generate a response for an image (failover built in)
         Args:
             prompt (str): the prompt
             image_base64 (str): Base64-encoded image string
@@ -167,71 +167,78 @@ class LLMRequest:
         Returns:
             (Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]): response content, reasoning content, model name, tool calls
         """
-        # Normalize the image format to ensure API compatibility
         normalized_format = _normalize_image_format(image_format)
-        # Select a model
-        start_time = time.time()
-        model_info, api_provider, client = self._select_model()
-
-        # Build the request body
-        message_builder = MessageBuilder()
-        message_builder.add_text_content(prompt)
-        message_builder.add_image_content(
-            image_base64=image_base64,
-            image_format=normalized_format,
-            support_formats=client.get_support_image_formats(),
-        )
-        messages = [message_builder.build()]
-
-        # Issue the request and handle the result
-        response = await self._execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.RESPONSE,
-            model_info=model_info,
-            message_list=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-        )
-        content = response.content or ""
-        reasoning_content = response.reasoning_content or ""
-        tool_calls = response.tool_calls
-        # Extract tagged reasoning content from the body (backward compatible)
-        if not reasoning_content and content:
-            content, extracted_reasoning = self._extract_reasoning(content)
-            reasoning_content = extracted_reasoning
-        if usage := response.usage:
-            await llm_usage_recorder.record_usage_to_database(
-                model_info=model_info,
-                model_usage=usage,
-                user_id="system",
-                time_cost=time.time() - start_time,
-                request_type=self.request_type,
-                endpoint="/chat/completions",
+        async def request_logic(
+            model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
+        ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+            start_time = time.time()
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(prompt)
+            message_builder.add_image_content(
+                image_base64=image_base64,
+                image_format=normalized_format,
+                support_formats=client.get_support_image_formats(),
             )
-        return content, (reasoning_content, model_info.name, tool_calls)
+            messages = [message_builder.build()]
+
+            response = await self._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.RESPONSE,
+                model_info=model_info,
+                message_list=messages,
+                temperature=temperature,
+                max_tokens=max_tokens,
+            )
+
+            content = response.content or ""
+            reasoning_content = response.reasoning_content or ""
response.reasoning_content or "" + tool_calls = response.tool_calls + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + await llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + time_cost=time.time() - start_time, + request_type=self.request_type, + endpoint="/chat/completions", + ) + return content, (reasoning_content, model_info.name, tool_calls) + + result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True) + if result: + return result + + # 这段代码理论上不可达,因为 raise_on_failure=True 会抛出异常 + raise RuntimeError("图片响应生成失败,所有模型均尝试失败。") async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: """ - 为语音生成响应 + 为语音生成响应(已集成故障转移) Args: voice_base64 (str): 语音的Base64编码字符串 Returns: (Optional[str]): 生成的文本描述或None """ - # 模型选择 - model_info, api_provider, client = self._select_model() - # 请求并处理返回值 - response = await self._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.AUDIO, - model_info=model_info, - audio_base64=voice_base64, - ) - return response.content or None + async def request_logic(model_info: ModelInfo, api_provider: APIProvider, client: BaseClient) -> Optional[str]: + """定义单次请求的具体逻辑""" + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.AUDIO, + model_info=model_info, + audio_base64=voice_base64, + ) + return response.content or None + + # 对于语音识别,如果所有模型都失败,我们可能不希望程序崩溃,而是返回None + result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=False) + return result async def generate_response_async( self, @@ -279,6 +286,75 @@ class LLMRequest: raise e return "所有并发请求都失败了", ("", "unknown", None) + async def _execute_with_failover( + self, + request_callable: Callable[[ModelInfo, APIProvider, BaseClient], Coroutine[Any, Any, Any]], + raise_on_failure: bool = True, + ) -> Any: + """ + 通用的故障转移执行器。 + + 它会使用智能模型调度器按最优顺序尝试模型,直到请求成功或所有模型都失败。 + + Args: + request_callable: 一个接收 (model_info, api_provider, client) 并返回协程的函数, + 用于执行实际的请求逻辑。 + raise_on_failure: 如果所有模型都失败,是否抛出异常。 + + Returns: + 请求成功时的返回结果。 + + Raises: + RuntimeError: 如果所有模型都失败且 raise_on_failure 为 True。 + """ + failed_models = set() + last_exception: Optional[Exception] = None + + # model_scheduler 现在会动态排序,所以我们只需要在循环中处理失败的模型 + while True: + model_scheduler = self._model_scheduler(failed_models) + try: + model_info, api_provider, client = next(model_scheduler) + except StopIteration: + # 没有更多可用模型了 + break + + model_name = model_info.name + logger.debug(f"正在尝试使用模型: {model_name} (剩余可用: {len(self.model_for_task.model_list) - len(failed_models)})") + + try: + # 执行传入的请求函数 + result = await request_callable(model_info, api_provider, client) + logger.debug(f"模型 '{model_name}' 成功生成回复。") + return result + + except RespNotOkException as e: + # 对于某些致命的HTTP错误(如认证失败),我们可能希望立即失败或标记该模型为永久失败 + if e.status_code in [401, 403]: + logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将永久禁用此模型在此次请求中。") + else: + logger.warning(f"模型 '{model_name}' 请求失败,HTTP状态码: {e.status_code},将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue + + except Exception as e: + # 捕获其他所有异常(包括超时、解析错误、运行时错误等) + logger.error(f"使用模型 '{model_name}' 时发生异常: {e},将尝试下一个模型。") + failed_models.add(model_name) + last_exception = e + continue + + # 所有模型都尝试失败 + logger.error("所有可用模型都已尝试失败。") + if 
+            if last_exception:
+                raise RuntimeError("All model requests failed") from last_exception
+            raise RuntimeError("All model requests failed, with no specific exception captured")
+
+        # Otherwise fall back to a default error result
+        return None
+
     async def _execute_single_request(
         self,
         prompt: str,
@@ -288,83 +364,67 @@ class LLMRequest:
         raise_when_empty: bool = True,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
-        Execute a single request, switching to the next available model in order when one fails.
+        Execute a single text-generation request through the generic failover executor.
         """
-        failed_models = set()
-        last_exception: Optional[Exception] = None
-        model_scheduler = self._model_scheduler(failed_models)
-
-        for model_info, api_provider, client in model_scheduler:
+        async def request_logic(
+            model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
+        ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+            """The per-attempt request logic"""
            start_time = time.time()
            model_name = model_info.name
-            logger.debug(f"Trying model: {model_name}")  # keep the log quiet
-            try:
-                # Check whether anti-truncation is enabled
-                # Check whether anti-truncation is enabled for this model
-                use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
-                processed_prompt = prompt
+            # Check whether anti-truncation is enabled for this model
+            use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
+            processed_prompt = prompt
+            if use_anti_truncation:
+                processed_prompt += self.anti_truncation_instruction
+                logger.info(f"Model '{model_name}' (task: '{self.task_name}') has anti-truncation enabled.")
+
+            processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
+
+            message_builder = MessageBuilder()
+            message_builder.add_text_content(processed_prompt)
+            messages = [message_builder.build()]
+            tool_built = self._build_tool_options(tools)
+
+            # Empty-reply/truncation retry logic for the current model
+            empty_retry_count = 0
+            max_empty_retry = api_provider.max_retry
+            empty_retry_interval = api_provider.retry_interval
+
+            is_empty_reply = False
+            is_truncated = False
+
+            while empty_retry_count <= max_empty_retry:
+                response = await self._execute_request(
+                    api_provider=api_provider,
+                    client=client,
+                    request_type=RequestType.RESPONSE,
+                    model_info=model_info,
+                    message_list=messages,
+                    tool_options=tool_built,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                )
+
+                content = response.content or ""
+                reasoning_content = response.reasoning_content or ""
+                tool_calls = response.tool_calls
+
+                if not reasoning_content and content:
+                    content, extracted_reasoning = self._extract_reasoning(content)
+                    reasoning_content = extracted_reasoning
+
+                is_empty_reply = not tool_calls and (not content or content.strip() == "")
+                is_truncated = False
                 if use_anti_truncation:
-                    processed_prompt += self.anti_truncation_instruction
-                    logger.info(f"Model '{model_name}' (task: '{self.task_name}') has anti-truncation enabled.")
-
-                processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
-
-                message_builder = MessageBuilder()
-                message_builder.add_text_content(processed_prompt)
-                messages = [message_builder.build()]
-                tool_built = self._build_tool_options(tools)
-
-                # Empty-reply/truncation retry logic for the current model
-                empty_retry_count = 0
-                max_empty_retry = api_provider.max_retry
-                empty_retry_interval = api_provider.retry_interval
-
-                while empty_retry_count <= max_empty_retry:
-                    response = await self._execute_request(
-                        api_provider=api_provider,
-                        client=client,
-                        request_type=RequestType.RESPONSE,
-                        model_info=model_info,
-                        message_list=messages,
-                        tool_options=tool_built,
-                        temperature=temperature,
-                        max_tokens=max_tokens,
-                    )
-
-                    content = response.content or ""
-                    reasoning_content = response.reasoning_content or ""
-                    tool_calls = response.tool_calls
-
-                    if not reasoning_content and content:
-                        content, extracted_reasoning = self._extract_reasoning(content)
-                        reasoning_content = extracted_reasoning
-
-                    is_empty_reply = not tool_calls and (not content or content.strip() == "")
-                    is_truncated = False
-                    if use_anti_truncation:
-                        if content.endswith(self.end_marker):
-                            content = content[: -len(self.end_marker)].strip()
-                        else:
-                            is_truncated = True
-
-                    if is_empty_reply or is_truncated:
-                        empty_retry_count += 1
-                        if empty_retry_count <= max_empty_retry:
-                            reason = "empty reply" if is_empty_reply else "truncation"
-                            logger.warning(
-                                f"Model '{model_name}' produced a {reason}; regenerating, attempt {empty_retry_count}/{max_empty_retry}..."
-                            )
-                            if empty_retry_interval > 0:
-                                await asyncio.sleep(empty_retry_interval)
-                            continue  # retry with the current model
-                        else:
-                            # Retries for this model are exhausted; break out so the outer loop switches models
-                            reason = "empty reply" if is_empty_reply else "truncation"
-                            logger.error(f"Model '{model_name}' still produced a {reason} after {max_empty_retry} retries.")
-                            raise RuntimeError(f"Model '{model_name}' reached the empty-reply/truncation retry limit")
+                    if content.endswith(self.end_marker):
+                        content = content[: -len(self.end_marker)].strip()
+                    else:
+                        is_truncated = True
+
                if not is_empty_reply and not is_truncated:  # got a usable response
                    if usage := response.usage:
                        await llm_usage_recorder.record_usage_to_database(
@@ -381,115 +441,115 @@ class LLMRequest:
                            raise RuntimeError("Generated an empty reply")
                        content = "The generated response was empty"
 
-                logger.debug(f"Model '{model_name}' generated a reply successfully.")  # keep the log quiet
                return content, (reasoning_content, model_name, tool_calls)
 
-            except RespNotOkException as e:
-                if e.status_code in [401, 403]:
-                    logger.error(f"Model '{model_name}' hit an auth/permission error (Code: {e.status_code}); trying the next model.")
-                    failed_models.add(model_name)
-                    last_exception = e
-                    continue  # switch to the next model
-                else:
-                    logger.error(f"Model '{model_name}' request failed with HTTP status {e.status_code}")
-                    if raise_when_empty:
-                        raise
-                    # For other HTTP errors, give up instead of trying further models
-                    return f"Request failed: {e}", ("", model_name, None)
+                # Reaching this point means an empty or truncated reply, so retry
+                empty_retry_count += 1
+                if empty_retry_count <= max_empty_retry:
+                    reason = "empty reply" if is_empty_reply else "truncation"
+                    logger.warning(
+                        f"Model '{model_name}' produced a {reason}; regenerating, attempt {empty_retry_count}/{max_empty_retry}..."
+                    )
+                    if empty_retry_interval > 0:
+                        await asyncio.sleep(empty_retry_interval)
+                    continue  # retry with the current model
 
-            except RuntimeError as e:
-                # Catch all retry failures (empty replies and network problems alike)
-                logger.error(f"Model '{model_name}' still failed after all retries: {e}; trying the next model.")
-                failed_models.add(model_name)
-                last_exception = e
-                continue  # switch to the next model
+            # If the loop exits, the retry budget is exhausted
+            reason = "empty reply" if is_empty_reply else "truncation"
+            logger.error(f"Model '{model_name}' still produced a {reason} after {max_empty_retry} retries.")
+            raise RuntimeError(f"Model '{model_name}' reached the empty-reply/truncation retry limit")
 
-            except Exception as e:
-                logger.error(f"Unknown exception while using model '{model_name}': {e}")
-                failed_models.add(model_name)
-                last_exception = e
-                continue  # switch to the next model
+        # Delegate to the generic failover executor
+        result = await self._execute_with_failover(
+            request_callable=request_logic, raise_on_failure=raise_when_empty
+        )
 
-        # Every model was tried and failed
-        logger.error("All available models have been tried and failed.")
-        if raise_when_empty:
-            if last_exception:
-                raise RuntimeError("All model requests failed") from last_exception
-            raise RuntimeError("All model requests failed, with no specific exception captured")
+        if result:
+            return result
 
+        # If every model failed and we are not raising, return a default error message
         return "All model requests failed", ("", "unknown", None)
 
     async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
-        """Get an embedding vector
+        """Get an embedding vector (failover built in)
         Args:
             embedding_input (str): the text to embed
         Returns:
             (Tuple[List[float], str]): (embedding vector, name of the model used)
         """
-        # No message body to build; use the input text directly
-        start_time = time.time()
-        model_info, api_provider, client = self._select_model()
-        # Issue the request and handle the result
-        response = await self._execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.EMBEDDING,
-            model_info=model_info,
-            embedding_input=embedding_input,
-        )
-
-        embedding = response.embedding
-
-        if usage := response.usage:
-            await llm_usage_recorder.record_usage_to_database(
+        async def request_logic(
+            model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
+        ) -> Tuple[List[float], str]:
+            """The per-attempt request logic"""
+            start_time = time.time()
+            response = await self._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.EMBEDDING,
                model_info=model_info,
-                time_cost=time.time() - start_time,
-                model_usage=usage,
-                user_id="system",
-                request_type=self.request_type,
-                endpoint="/embeddings",
+                embedding_input=embedding_input,
            )
 
-        if not embedding:
-            raise RuntimeError("Failed to get embedding")
+            embedding = response.embedding
+            if not embedding:
+                raise RuntimeError(f"Model '{model_info.name}' did not return an embedding.")
 
-        return embedding, model_info.name
+            if usage := response.usage:
+                await llm_usage_recorder.record_usage_to_database(
+                    model_info=model_info,
+                    time_cost=time.time() - start_time,
+                    model_usage=usage,
+                    user_id="system",
+                    request_type=self.request_type,
+                    endpoint="/embeddings",
+                )
 
-    def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
+            return embedding, model_info.name
+
+        result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True)
+        if result:
+            return result
+
+        # Theoretically unreachable: raise_on_failure=True raises on total failure
+        raise RuntimeError("Failed to get embedding; every model was tried and failed.")
+
+    def _model_scheduler(
+        self, failed_models: set | None = None
+    ) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
         """
-        A model scheduler that yields models in order, skipping the ones that have failed.
+        A smart model scheduler that orders models by real-time load and yields them, skipping the ones that have failed.
         """
-        for model_name in self.model_for_task.model_list:
-            if model_name in failed_models:
-                continue
+        # sourcery skip: class-extract-method
+        if failed_models is None:
+            failed_models = set()
+        # 1. Keep only the available models that have not failed yet
+        available_models = [name for name in self.model_for_task.model_list if name not in failed_models]
+
+        # 2. Sort the available models by the load-balancing score
+        # key: total_tokens + penalty * 300 + usage_penalty * 1000
+        sorted_models = sorted(
+            available_models,
+            key=lambda name: self.model_usage[name][0]
+            + self.model_usage[name][1] * 300
+            + self.model_usage[name][2] * 1000,
+        )
+
+        if not sorted_models:
+            logger.warning("Every model has failed or is unavailable; the scheduler has nothing to offer.")
+            return
+
+        logger.debug(f"Model scheduling order: {', '.join(sorted_models)}")
+
+        # 3. Yield the model info in best-first order
+        for model_name in sorted_models:
             model_info = model_config.get_model_info(model_name)
             api_provider = model_config.get_provider(model_info.api_provider)
             force_new_client = self.request_type == "embedding"
             client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
             yield model_info, api_provider, client
 
-    def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
-        """
-        Select a model by total tokens and penalty values (load balancing)
-        """
-        least_used_model_name = min(
-            self.model_usage,
-            key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
-        )
-        model_info = model_config.get_model_info(least_used_model_name)
-        api_provider = model_config.get_provider(model_info.api_provider)
-
-        # For embedding tasks, force a fresh client instance to avoid event-loop issues
-        force_new_client = self.request_type == "embedding"
-        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-        logger.debug(f"Selected model for request: {model_info.name}")
-        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # bump the usage penalty to avoid back-to-back use
-        return model_info, api_provider, client
-
     async def _execute_request(
         self,
         api_provider: APIProvider,
@@ -513,63 +573,73 @@ class LLMRequest:
         """
         retry_remain = api_provider.max_retry
         compressed_messages: Optional[List[Message]] = None
-        while retry_remain > 0:
-            try:
-                if request_type == RequestType.RESPONSE:
-                    assert message_list is not None, "message_list cannot be None for response requests"
-                    return await client.get_response(
-                        model_info=model_info,
-                        message_list=(compressed_messages or message_list),
-                        tool_options=tool_options,
-                        max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
-                        temperature=self.model_for_task.temperature if temperature is None else temperature,
-                        response_format=response_format,
-                        stream_response_handler=stream_response_handler,
-                        async_response_parser=async_response_parser,
-                        extra_params=model_info.extra_params,
-                    )
-                elif request_type == RequestType.EMBEDDING:
-                    assert embedding_input, "embedding_input cannot be empty for embedding requests"
-                    return await client.get_embedding(
-                        model_info=model_info,
-                        embedding_input=embedding_input,
-                        extra_params=model_info.extra_params,
-                    )
-                elif request_type == RequestType.AUDIO:
-                    assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
-                    return await client.get_audio_transcriptions(
-                        model_info=model_info,
-                        audio_base64=audio_base64,
-                        extra_params=model_info.extra_params,
-                    )
-            except Exception as e:
-                logger.debug(f"Request failed: {str(e)}")
-                # Handle the exception
-                total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-                self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
-                wait_interval, compressed_messages = self._default_exception_handler(
-                    e,
-                    self.task_name,
-                    model_info=model_info,
-                    api_provider=api_provider,
-                    remain_try=retry_remain,
-                    retry_interval=api_provider.retry_interval,
-                    messages=(message_list, compressed_messages is not None) if message_list else None,
-                )
-
-                if wait_interval == -1:
-                    retry_remain = 0  # stop retrying
-                elif wait_interval > 0:
-                    logger.info(f"Waiting {wait_interval} seconds before retrying...")
-                    await asyncio.sleep(wait_interval)
-            finally:
-                # Decrement in finally to avoid an infinite loop
-                retry_remain -= 1
+        # Bump the usage penalty to mark this model as in flight
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)  # done with the model, release the usage penalty
-        logger.error(f"Model '{model_info.name}' request failed after hitting the retry limit of {api_provider.max_retry}")
-        raise RuntimeError("Request failed; maximum retries reached")
+        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
+
+        try:
+            while retry_remain > 0:
+                try:
+                    if request_type == RequestType.RESPONSE:
+                        assert message_list is not None, "message_list cannot be None for response requests"
+                        return await client.get_response(
+                            model_info=model_info,
+                            message_list=(compressed_messages or message_list),
+                            tool_options=tool_options,
+                            max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
+                            temperature=self.model_for_task.temperature if temperature is None else temperature,
+                            response_format=response_format,
+                            stream_response_handler=stream_response_handler,
+                            async_response_parser=async_response_parser,
+                            extra_params=model_info.extra_params,
+                        )
+                    elif request_type == RequestType.EMBEDDING:
+                        assert embedding_input, "embedding_input cannot be empty for embedding requests"
+                        return await client.get_embedding(
+                            model_info=model_info,
+                            embedding_input=embedding_input,
+                            extra_params=model_info.extra_params,
+                        )
+                    elif request_type == RequestType.AUDIO:
+                        assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
+                        return await client.get_audio_transcriptions(
+                            model_info=model_info,
+                            audio_base64=audio_base64,
+                            extra_params=model_info.extra_params,
+                        )
+                except Exception as e:
+                    logger.debug(f"Request failed: {str(e)}")
+                    # Handle the exception
+                    total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+                    self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
+
+                    wait_interval, compressed_messages = self._default_exception_handler(
+                        e,
+                        self.task_name,
+                        model_info=model_info,
+                        api_provider=api_provider,
+                        remain_try=retry_remain,
+                        retry_interval=api_provider.retry_interval,
+                        messages=(message_list, compressed_messages is not None) if message_list else None,
+                    )
+
+                    if wait_interval == -1:
+                        retry_remain = 0  # stop retrying
+                    elif wait_interval > 0:
+                        logger.info(f"Waiting {wait_interval} seconds before retrying...")
+                        await asyncio.sleep(wait_interval)
+                finally:
+                    # Decrement in finally to avoid an infinite loop
+                    retry_remain -= 1
+
+            # Reached only when every retry failed; a successful request returns from inside the loop
+            logger.error(f"Model '{model_info.name}' request failed after hitting the retry limit of {api_provider.max_retry}")
+            raise RuntimeError("Request failed; maximum retries reached")
+        finally:
+            # Whether the request succeeded or failed, release the usage penalty here
+            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+            self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
 
     def _default_exception_handler(
         self,
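Reviewer note (not part of the patch): the core idea is that every request
type now shares one failover loop. A minimal, self-contained sketch of that
pattern for reference; the names here (try_with_failover, fetch) are
illustrative only and do not exist in the patched module:

import asyncio
from typing import Any, Awaitable, Callable, Optional, Sequence

async def try_with_failover(
    candidates: Sequence[str],
    attempt: Callable[[str], Awaitable[Any]],
    raise_on_failure: bool = True,
) -> Optional[Any]:
    """Try each candidate in order and return the first successful result."""
    last_exc: Optional[Exception] = None
    for name in candidates:
        try:
            return await attempt(name)
        except Exception as exc:  # the real executor also special-cases 401/403
            last_exc = exc
            continue
    if raise_on_failure:
        raise RuntimeError("all candidates failed") from last_exc
    return None

async def main() -> None:
    async def fetch(model: str) -> str:
        if model == "primary":
            raise TimeoutError("primary timed out")
        return f"reply from {model}"

    print(await try_with_failover(["primary", "backup"], fetch))  # reply from backup

asyncio.run(main())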
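Reviewer note (not part of the patch): the scheduler's ordering reduces each
model to a single score, total_tokens + penalty * 300 + usage_penalty * 1000,
tried smallest first. A tiny worked example with a hypothetical usage table;
the weights 300 and 1000 are taken straight from the patch:

model_usage = {
    # name: (total_tokens, penalty, usage_penalty)
    "model-a": (12000, 0, 0),  # score 12000
    "model-b": (8000, 2, 0),   # score 8000 + 2*300 = 8600
    "model-c": (5000, 1, 3),   # score 5000 + 300 + 3000 = 8300
}

def score(name: str) -> int:
    total_tokens, penalty, usage_penalty = model_usage[name]
    return total_tokens + penalty * 300 + usage_penalty * 1000

print(sorted(model_usage, key=score))  # ['model-c', 'model-b', 'model-a']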