diff --git a/src/llm_models/llm_utils.py b/src/llm_models/llm_utils.py
new file mode 100644
index 000000000..fb7810a0c
--- /dev/null
+++ b/src/llm_models/llm_utils.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+"""
+@File : llm_utils.py
+@Time : 2024/05/24 17:00:00
+@Author : 墨墨
+@Version : 1.0
+@Desc : LLM相关通用工具函数
+"""
+from typing import List, Dict, Any, Tuple
+
+from src.common.logger import get_logger
+from .payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType
+
+logger = get_logger("llm_utils")
+
+def normalize_image_format(image_format: str) -> str:
+ """
+ 标准化图片格式名称,确保与各种API的兼容性
+ """
+ format_mapping = {
+ "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
+ "png": "png", "PNG": "png",
+ "webp": "webp", "WEBP": "webp",
+ "gif": "gif", "GIF": "gif",
+ "heic": "heic", "HEIC": "heic",
+ "heif": "heif", "HEIF": "heif",
+ }
+ normalized = format_mapping.get(image_format, image_format.lower())
+ logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
+ return normalized
+
+def build_tool_options(tools: List[Dict[str, Any]] | None) -> List[ToolOption] | None:
+ """构建工具选项列表"""
+ if not tools:
+ return None
+ tool_options: List[ToolOption] = []
+ for tool in tools:
+ try:
+ tool_options_builder = ToolOptionBuilder()
+ tool_options_builder.set_name(tool.get("name", ""))
+ tool_options_builder.set_description(tool.get("description", ""))
+ parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", [])
+ for param in parameters:
+ # 参数校验
+ assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
+ assert isinstance(param[0], str), "参数名称必须是字符串"
+ assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举"
+ assert isinstance(param[2], str), "参数描述必须是字符串"
+ assert isinstance(param[3], bool), "参数是否必填必须是布尔值"
+ assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None"
+
+ tool_options_builder.add_param(
+ name=param[0],
+ param_type=param[1],
+ description=param[2],
+ required=param[3],
+ enum_values=param[4],
+ )
+ tool_options.append(tool_options_builder.build())
+ except AssertionError as ae:
+ logger.error(f"工具 '{tool.get('name', 'unknown')}' 的参数定义错误: {str(ae)}")
+ except Exception as e:
+ logger.error(f"构建工具 '{tool.get('name', 'unknown')}' 失败: {str(e)}")
+
+ return tool_options or None
\ No newline at end of file
diff --git a/src/llm_models/model_client/aiohttp_gemini_client.py b/src/llm_models/model_client/aiohttp_gemini_client.py
index 7b997b680..eeb90c265 100644
--- a/src/llm_models/model_client/aiohttp_gemini_client.py
+++ b/src/llm_models/model_client/aiohttp_gemini_client.py
@@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
def _convert_tool_param(param: ToolParam) -> dict:
"""转换工具参数"""
- result = {
+ result: dict[str, Any] = {
"type": param.param_type.value,
"description": param.description,
}
@@ -132,7 +132,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
def _convert_tool_option_item(tool_option: ToolOption) -> dict:
"""转换单个工具选项"""
- function_declaration = {
+ function_declaration: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
@@ -500,7 +500,7 @@ class AiohttpGeminiClient(BaseClient):
# 直接重抛项目定义的异常
raise
except Exception as e:
- logger.debug(e)
+ logger.debug(f"请求处理中发生未知异常: {e}")
# 其他异常转换为网络连接错误
raise NetworkConnectionError() from e
diff --git a/src/llm_models/model_selector.py b/src/llm_models/model_selector.py
new file mode 100644
index 000000000..827e28842
--- /dev/null
+++ b/src/llm_models/model_selector.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+"""
+@File : model_selector.py
+@Time : 2024/05/24 16:00:00
+@Author : 墨墨
+@Version : 1.0
+@Desc : 模型选择与负载均衡器
+"""
+from typing import Dict, Tuple, Set, Optional
+
+from src.common.logger import get_logger
+from src.config.config import model_config
+from src.config.api_ada_configs import ModelInfo, APIProvider, TaskConfig
+from .model_client.base_client import BaseClient, client_registry
+
+logger = get_logger("model_selector")
+
+
+class ModelSelector:
+ """模型选择与负载均衡器"""
+
+ def __init__(self, model_set: TaskConfig, request_type: str = ""):
+ """
+ 初始化模型选择器
+
+ Args:
+ model_set (TaskConfig): 任务配置中定义的模型集合
+ request_type (str, optional): 请求类型 (例如 "embedding"). Defaults to "".
+ """
+ self.model_for_task = model_set
+ self.request_type = request_type
+ self.model_usage: Dict[str, Tuple[int, int, int]] = {
+ model: (0, 0, 0) for model in self.model_for_task.model_list
+ }
+ """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
+
+ def select_best_available_model(
+ self, failed_models_in_this_request: Set[str]
+ ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
+ """
+ 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。
+
+ Args:
+ failed_models_in_this_request (Set[str]): 当前请求中已失败的模型名称集合。
+
+ Returns:
+ Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。
+ """
+ candidate_models_usage = {
+ model_name: usage_data
+ for model_name, usage_data in self.model_usage.items()
+ if model_name not in failed_models_in_this_request
+ }
+
+ if not candidate_models_usage:
+ logger.warning("没有可用的模型供当前请求选择。")
+ return None
+
+ # 根据现有公式查找分数最低的模型
+ # 公式: total_tokens + penalty * 300 + usage_penalty * 1000
+ # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。
+ least_used_model_name = min(
+ candidate_models_usage,
+ key=lambda k: candidate_models_usage[k][0]
+ + candidate_models_usage[k][1] * 300
+ + candidate_models_usage[k][2] * 1000,
+ )
+
+ # --- 动态故障转移的核心逻辑 ---
+ # RequestStrategy 中的循环会多次调用此函数。
+ # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数,
+ # 此时由于失败模型已被标记,且其惩罚值可能已在 RequestExecutor 中增加,
+ # 此函数会自动选择一个得分更低(即更可用)的模型。
+ # 这种机制实现了动态的、基于当前系统状态的故障转移。
+ model_info = model_config.get_model_info(least_used_model_name)
+ api_provider = model_config.get_provider(model_info.api_provider)
+
+ force_new_client = self.request_type == "embedding"
+ client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+
+ logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
+
+ # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。
+ # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。
+ total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+ self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
+
+ return model_info, api_provider, client
+
+ def select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
+ """
+ 根据总tokens和惩罚值选择的模型 (负载均衡)
+ """
+ least_used_model_name = min(
+ self.model_usage,
+ key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
+ )
+ model_info = model_config.get_model_info(least_used_model_name)
+ api_provider = model_config.get_provider(model_info.api_provider)
+
+ force_new_client = self.request_type == "embedding"
+ client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+ logger.debug(f"选择请求模型: {model_info.name}")
+ total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+ self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
+ return model_info, api_provider, client
+
+ def update_model_penalty(self, model_name: str, penalty_increment: int):
+ """
+ 更新指定模型的惩罚值
+
+ Args:
+ model_name (str): 模型名称
+ penalty_increment (int): 惩罚增量
+ """
+ if model_name in self.model_usage:
+ total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+ self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
+ logger.debug(f"模型 '{model_name}' 的惩罚值增加了 {penalty_increment}")
+
+ def decrease_usage_penalty(self, model_name: str):
+ """
+ 请求结束后,减少使用惩罚值
+
+ Args:
+ model_name (str): 模型名称
+ """
+ if model_name in self.model_usage:
+ total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+ self.model_usage[model_name] = (total_tokens, penalty, usage_penalty - 1)
\ No newline at end of file
diff --git a/src/llm_models/prompt_processor.py b/src/llm_models/prompt_processor.py
new file mode 100644
index 000000000..035a8f55f
--- /dev/null
+++ b/src/llm_models/prompt_processor.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+"""
+@File : prompt_processor.py
+@Time : 2024/05/24 16:45:00
+@Author : 墨墨
+@Version : 1.0
+@Desc : 提示词预处理(反截断、内容混淆等)
+"""
+import random
+import string
+import re
+from typing import Tuple
+
+from src.common.logger import get_logger
+from src.config.api_ada_configs import APIProvider, ModelInfo
+
+logger = get_logger("prompt_processor")
+
+
+class PromptProcessor:
+ """提示词处理器"""
+
+ def __init__(self):
+ self.noise_instruction = """**【核心过滤规则】**
+在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。
+
+你的任务是【完全并彻底地忽略】这些随机字符串。
+**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
+"""
+ self.end_marker = "###MAI_RESPONSE_END###"
+ self.anti_truncation_instruction = f"""
+**【输出完成信令】**
+这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。
+例如:
+<你的回复内容>
+{self.end_marker}
+
+这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
+"""
+
+ def process_prompt(
+ self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
+ ) -> str:
+ """
+ 根据模型和API提供商的配置处理提示词
+ """
+ processed_prompt = prompt
+
+ # 1. 添加反截断指令
+ use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
+ if use_anti_truncation:
+ processed_prompt += self.anti_truncation_instruction
+ logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")
+
+ # 2. 应用内容混淆
+ if getattr(api_provider, "enable_content_obfuscation", False):
+ intensity = getattr(api_provider, "obfuscation_intensity", 1)
+ logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
+ processed_prompt = self._apply_content_obfuscation(processed_prompt, intensity)
+
+ return processed_prompt
+
+ def _apply_content_obfuscation(self, text: str, intensity: int) -> str:
+ """对文本进行混淆处理"""
+ # 在开头加入过滤规则指令
+ processed_text = self.noise_instruction + "\n\n" + text
+ logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}")
+
+ # 添加随机乱码
+ final_text = self._inject_random_noise(processed_text, intensity)
+ logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}")
+
+ return final_text
+
+ @staticmethod
+ def _inject_random_noise(text: str, intensity: int) -> str:
+ """在文本中注入随机乱码"""
+ def generate_noise(length: int) -> str:
+ chars = (
+ string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?"
+ + "一二三四五六七八九零壹贰叁" + "αβγδεζηθικλμνξοπρστυφχψω" + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵"
+ )
+ return "".join(random.choice(chars) for _ in range(length))
+
+ params = {
+ 1: {"probability": 15, "length": (3, 6)},
+ 2: {"probability": 25, "length": (5, 10)},
+ 3: {"probability": 35, "length": (8, 15)},
+ }
+ config = params.get(intensity, params[1])
+ logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}")
+
+ words = text.split()
+ result = []
+ noise_count = 0
+ for word in words:
+ result.append(word)
+ if random.randint(1, 100) <= config["probability"]:
+ noise_length = random.randint(*config["length"])
+ noise = generate_noise(noise_length)
+ result.append(noise)
+ noise_count += 1
+
+ logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}")
+ return " ".join(result)
+
+ @staticmethod
+ def extract_reasoning(content: str) -> Tuple[str, str]:
+ """CoT思维链提取,向后兼容"""
+ match = re.search(r"(?:)?(.*?)", content, re.DOTALL)
+ clean_content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip()
+ reasoning = match.group(1).strip() if match else ""
+ return clean_content, reasoning
diff --git a/src/llm_models/request_executor.py b/src/llm_models/request_executor.py
new file mode 100644
index 000000000..33b3197b0
--- /dev/null
+++ b/src/llm_models/request_executor.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+"""
+@File : request_executor.py
+@Time : 2024/05/24 16:15:00
+@Author : 墨墨
+@Version : 1.0
+@Desc : 负责执行LLM请求、处理重试及异常
+"""
+import asyncio
+from typing import List, Callable, Optional, Tuple
+
+from src.common.logger import get_logger
+from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
+from .exceptions import (
+ NetworkConnectionError,
+ ReqAbortException,
+ RespNotOkException,
+ RespParseException,
+)
+from .model_client.base_client import APIResponse, BaseClient
+from .model_selector import ModelSelector
+from .payload_content.message import Message
+from .payload_content.resp_format import RespFormat
+from .payload_content.tool_option import ToolOption
+from .utils import compress_messages
+
+logger = get_logger("request_executor")
+
+
+class RequestExecutor:
+ """请求执行器"""
+
+ def __init__(
+ self,
+ task_name: str,
+ model_set: TaskConfig,
+ api_provider: APIProvider,
+ client: BaseClient,
+ model_info: ModelInfo,
+ model_selector: ModelSelector,
+ ):
+ self.task_name = task_name
+ self.model_set = model_set
+ self.api_provider = api_provider
+ self.client = client
+ self.model_info = model_info
+ self.model_selector = model_selector
+
+ async def execute_request(
+ self,
+ request_type: str,
+ message_list: List[Message] | None = None,
+ tool_options: list[ToolOption] | None = None,
+ response_format: RespFormat | None = None,
+ stream_response_handler: Optional[Callable] = None,
+ async_response_parser: Optional[Callable] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ embedding_input: str = "",
+ audio_base64: str = "",
+ ) -> APIResponse:
+ """
+ 实际执行请求的方法, 包含了重试和异常处理逻辑
+ """
+ retry_remain = self.api_provider.max_retry
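+        # 413(请求体过大)时由 compress_messages 生成的压缩消息,重试时优先使用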
+ compressed_messages: Optional[List[Message]] = None
+ while retry_remain > 0:
+ try:
+ if request_type == "response":
+ assert message_list is not None, "message_list cannot be None for response requests"
+ return await self.client.get_response(
+ model_info=self.model_info,
+ message_list=(compressed_messages or message_list),
+ tool_options=tool_options,
+ max_tokens=self.model_set.max_tokens if max_tokens is None else max_tokens,
+ temperature=self.model_set.temperature if temperature is None else temperature,
+ response_format=response_format,
+ stream_response_handler=stream_response_handler,
+ async_response_parser=async_response_parser,
+ extra_params=self.model_info.extra_params,
+ )
+ elif request_type == "embedding":
+ assert embedding_input, "embedding_input cannot be empty for embedding requests"
+ return await self.client.get_embedding(
+ model_info=self.model_info,
+ embedding_input=embedding_input,
+ extra_params=self.model_info.extra_params,
+ )
+ elif request_type == "audio":
+ assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
+ return await self.client.get_audio_transcriptions(
+ model_info=self.model_info,
+ audio_base64=audio_base64,
+ extra_params=self.model_info.extra_params,
+ )
+ raise ValueError(f"未知的请求类型: {request_type}")
+ except Exception as e:
+ logger.debug(f"请求失败: {str(e)}")
+ self._apply_penalty_on_failure(e)
+
+ wait_interval, compressed_messages = self._default_exception_handler(
+ e,
+ remain_try=retry_remain,
+ retry_interval=self.api_provider.retry_interval,
+ messages=(message_list, compressed_messages is not None) if message_list else None,
+ )
+
+ if wait_interval == -1:
+ retry_remain = 0 # 不再重试
+ elif wait_interval > 0:
+ logger.info(f"等待 {wait_interval} 秒后重试...")
+ await asyncio.sleep(wait_interval)
+ finally:
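+                # 放在 finally 中递减,防止异常处理分支遗漏导致死循环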
+ retry_remain -= 1
+
+ self.model_selector.decrease_usage_penalty(self.model_info.name)
+ logger.error(f"模型 '{self.model_info.name}' 请求失败,达到最大重试次数 {self.api_provider.max_retry} 次")
+ raise RuntimeError("请求失败,已达到最大重试次数")
+
+ def _apply_penalty_on_failure(self, e: Exception):
+ """根据异常类型,动态调整模型的惩罚值"""
+ CRITICAL_PENALTY_MULTIPLIER = 5
+ default_penalty_increment = 1
+ penalty_increment = default_penalty_increment
+
+ if isinstance(e, (NetworkConnectionError, ReqAbortException)):
+ penalty_increment = CRITICAL_PENALTY_MULTIPLIER
+ elif isinstance(e, RespNotOkException):
+ if e.status_code >= 500:
+ penalty_increment = CRITICAL_PENALTY_MULTIPLIER
+
+ log_message = f"发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}"
+ if isinstance(e, (NetworkConnectionError, ReqAbortException)):
+ log_message = f"发生关键错误 ({type(e).__name__}),增加惩罚值: {penalty_increment}"
+ elif isinstance(e, RespNotOkException):
+ log_message = f"发生响应错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}"
+ logger.warning(f"模型 '{self.model_info.name}' {log_message}")
+
+ self.model_selector.update_model_penalty(self.model_info.name, penalty_increment)
+
+ def _default_exception_handler(
+ self,
+ e: Exception,
+ remain_try: int,
+ retry_interval: int = 10,
+ messages: Tuple[List[Message], bool] | None = None,
+ ) -> Tuple[int, List[Message] | None]:
+ """默认异常处理函数"""
+ model_name = self.model_info.name
+
+ if isinstance(e, NetworkConnectionError):
+ return self._check_retry(
+ remain_try,
+ retry_interval,
+ can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
+ cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数",
+ )
+ elif isinstance(e, ReqAbortException):
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
+ return -1, None
+ elif isinstance(e, RespNotOkException):
+ return self._handle_resp_not_ok(e, remain_try, retry_interval, messages)
+ elif isinstance(e, RespParseException):
+ logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
+ logger.debug(f"附加内容: {str(e.ext_info)}")
+ return -1, None
+ else:
+ logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
+ return -1, None
+
+ def _handle_resp_not_ok(
+ self,
+ e: RespNotOkException,
+ remain_try: int,
+ retry_interval: int = 10,
+ messages: tuple[list[Message], bool] | None = None,
+ ):
+ """处理响应错误异常"""
+ model_name = self.model_info.name
+ if e.status_code in [400, 401, 402, 403, 404]:
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}")
+ return -1, None
+ elif e.status_code == 413:
+ if messages and not messages[1]:
+ return self._check_retry(
+ remain_try, 0,
+ can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
+ cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,压缩后仍失败",
+ can_retry_callable=compress_messages, messages=messages[0],
+ )
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,无法压缩,放弃请求。")
+ return -1, None
+ elif e.status_code == 429:
+ return self._check_retry(
+ remain_try, retry_interval,
+ can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
+ cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数",
+ )
+ elif e.status_code >= 500:
+ return self._check_retry(
+ remain_try, retry_interval,
+ can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
+ cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数",
+ )
+ else:
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}")
+ return -1, None
+
+ @staticmethod
+ def _check_retry(
+ remain_try: int,
+ retry_interval: int,
+ can_retry_msg: str,
+ cannot_retry_msg: str,
+ can_retry_callable: Callable | None = None,
+ **kwargs,
+ ) -> Tuple[int, List[Message] | None]:
+ """辅助函数:检查是否可以重试"""
+ if remain_try > 0:
+ logger.warning(f"{can_retry_msg}")
+ if can_retry_callable is not None:
+ return retry_interval, can_retry_callable(**kwargs)
+ return retry_interval, None
+ else:
+ logger.warning(f"{cannot_retry_msg}")
+ return -1, None
\ No newline at end of file
diff --git a/src/llm_models/request_strategy.py b/src/llm_models/request_strategy.py
new file mode 100644
index 000000000..3a694f526
--- /dev/null
+++ b/src/llm_models/request_strategy.py
@@ -0,0 +1,206 @@
+# -*- coding: utf-8 -*-
+"""
+@File : request_strategy.py
+@Time : 2024/05/24 16:30:00
+@Author : 墨墨
+@Version : 1.0
+@Desc : 高级请求策略(并发、故障转移)
+"""
+import asyncio
+import random
+from typing import List, Tuple, Optional, Dict, Any, Callable, Coroutine
+
+from src.common.logger import get_logger
+from src.config.api_ada_configs import TaskConfig
+from .model_client.base_client import APIResponse
+from .model_selector import ModelSelector
+from .payload_content.message import MessageBuilder
+from .payload_content.tool_option import ToolCall
+from .prompt_processor import PromptProcessor
+from .request_executor import RequestExecutor
+
+logger = get_logger("request_strategy")
+
+
+class RequestStrategy:
+ """高级请求策略"""
+
+ def __init__(self, model_set: TaskConfig, model_selector: ModelSelector, task_name: str):
+ self.model_set = model_set
+ self.model_selector = model_selector
+ self.task_name = task_name
+
+ async def execute_with_fallback(
+ self,
+ base_payload: Dict[str, Any],
+ raise_when_empty: bool = True,
+ ) -> Dict[str, Any]:
+ """
+ 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
+ """
+ failed_models_in_this_request = set()
+ max_attempts = len(self.model_set.model_list)
+ last_exception: Optional[Exception] = None
+
+ for attempt in range(max_attempts):
+ model_selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request)
+
+ if model_selection_result is None:
+ logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
+ break
+
+ model_info, api_provider, client = model_selection_result
+ model_name = model_info.name
+ logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...")
+
+ try:
+ # 1. Process Prompt
+ prompt_processor: PromptProcessor = base_payload["prompt_processor"]
+ raw_prompt = base_payload["prompt"]
+ processed_prompt = prompt_processor.process_prompt(
+ raw_prompt, model_info, api_provider, self.task_name
+ )
+
+ # 2. Build Message
+ message_builder = MessageBuilder().add_text_content(processed_prompt)
+ messages = [message_builder.build()]
+
+ # 3. Create payload for executor
+ executor_payload = {
+ "request_type": "response", # Strategy only handles response type
+ "message_list": messages,
+ "tool_options": base_payload["tool_options"],
+ "temperature": base_payload["temperature"],
+ "max_tokens": base_payload["max_tokens"],
+ }
+
+ executor = RequestExecutor(
+ task_name=self.task_name,
+ model_set=self.model_set,
+ api_provider=api_provider,
+ client=client,
+ model_info=model_info,
+ model_selector=self.model_selector,
+ )
+ response = await self._execute_and_handle_empty_retry(executor, executor_payload, prompt_processor)
+
+ # 4. Post-process response
+ # The reasoning content is now extracted here, after a successful, de-truncated response is received.
+ final_content, reasoning_content = prompt_processor.extract_reasoning(response.content or "")
+ response.content = final_content # Update response with cleaned content
+
+ tool_calls = response.tool_calls
+
+ if not final_content and not tool_calls:
+ if raise_when_empty:
+ raise RuntimeError("所选模型生成了空回复。")
+ content = "生成的响应为空" # Fallback message
+
+ logger.debug(f"模型 '{model_name}' 成功生成了回复。")
+ return {
+ "content": response.content,
+ "reasoning_content": reasoning_content,
+ "model_name": model_name,
+ "tool_calls": tool_calls,
+ "model_info": model_info,
+ "usage": response.usage,
+ "success": True,
+ }
+
+ except Exception as e:
+ logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
+ failed_models_in_this_request.add(model_info.name)
+ last_exception = e
+
+ logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
+ if raise_when_empty:
+ if last_exception:
+ raise RuntimeError("所有模型均未能生成响应。") from last_exception
+ raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
+ return {
+ "content": "所有模型都请求失败",
+ "reasoning_content": "",
+ "model_name": "unknown",
+ "tool_calls": None,
+ "model_info": None,
+ "usage": None,
+ "success": False,
+ }
+
+ async def execute_concurrently(
+ self,
+ coro_callable: Callable[..., Coroutine[Any, Any, Any]],
+ concurrency_count: int,
+ *args,
+ **kwargs,
+ ) -> Any:
+ """
+ 执行并发请求并从成功的结果中随机选择一个。
+ """
+ logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
+ tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
+
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ successful_results = [res for res in results if not isinstance(res, Exception)]
+
+ if successful_results:
+ selected = random.choice(successful_results)
+ logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
+ return selected
+
+ for i, res in enumerate(results):
+ if isinstance(res, Exception):
+ logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")
+
+ first_exception = next((res for res in results if isinstance(res, Exception)), None)
+ if first_exception:
+ raise first_exception
+
+ raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
+
+ async def _execute_and_handle_empty_retry(
+ self, executor: RequestExecutor, payload: Dict[str, Any], prompt_processor: PromptProcessor
+ ) -> APIResponse:
+ """
+ 在单个模型内部处理空回复/截断的重试逻辑
+ """
+ empty_retry_count = 0
+ max_empty_retry = executor.api_provider.max_retry
+ empty_retry_interval = executor.api_provider.retry_interval
+ use_anti_truncation = getattr(executor.model_info, "use_anti_truncation", False)
+ end_marker = prompt_processor.end_marker
+
+ while empty_retry_count <= max_empty_retry:
+ response = await executor.execute_request(**payload)
+
+ content = response.content or ""
+ tool_calls = response.tool_calls
+
+ is_empty_reply = not tool_calls and (not content or content.strip() == "")
+ is_truncated = False
+ if use_anti_truncation and end_marker:
+ if content.endswith(end_marker):
+ # 移除结束标记
+ response.content = content[: -len(end_marker)].strip()
+ else:
+ is_truncated = True
+
+ if is_empty_reply or is_truncated:
+ empty_retry_count += 1
+ if empty_retry_count <= max_empty_retry:
+ reason = "空回复" if is_empty_reply else "截断"
+ logger.warning(
+ f"模型 '{executor.model_info.name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..."
+ )
+ if empty_retry_interval > 0:
+ await asyncio.sleep(empty_retry_interval)
+ continue
+ else:
+ reason = "空回复" if is_empty_reply else "截断"
+ raise RuntimeError(f"模型 '{executor.model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
+
+ # 成功获取响应
+ return response
+
+ # 此处理论上不会到达,因为循环要么返回要么抛异常
+ raise RuntimeError("空回复/截断重Test逻辑出现未知错误")
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index ec1a996bf..1414aacb1 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -1,154 +1,36 @@
-import re
-import asyncio
+# -*- coding: utf-8 -*-
+"""
+@File : utils_model.py
+@Time : 2024/05/24 17:15:00
+@Author : 墨墨
+@Version : 2.0 (Refactored)
+@Desc : LLM请求协调器
+"""
import time
-import random
-
-from enum import Enum
-from rich.traceback import install
-from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
+from typing import Tuple, List, Dict, Optional, Any
from src.common.logger import get_logger
-from src.config.config import model_config
-from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
-from .payload_content.message import MessageBuilder, Message
-from .payload_content.resp_format import RespFormat
-from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
-from .model_client.base_client import BaseClient, APIResponse, client_registry
-from .utils import compress_messages, llm_usage_recorder
-from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
-
-install(extra_lines=3)
+from src.config.api_ada_configs import TaskConfig, ModelInfo
+from .llm_utils import build_tool_options, normalize_image_format
+from .model_selector import ModelSelector
+from .payload_content.message import MessageBuilder
+from .payload_content.tool_option import ToolCall
+from .prompt_processor import PromptProcessor
+from .request_strategy import RequestStrategy
+from .utils import llm_usage_recorder
logger = get_logger("model_utils")
-# 常见Error Code Mapping
-error_code_mapping = {
- 400: "参数不正确",
- 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
- 402: "账号余额不足",
- 403: "需要实名,或余额不足",
- 404: "Not Found",
- 429: "请求过于频繁,请稍后再试",
- 500: "服务器内部故障",
- 503: "服务器负载过高",
-}
-
-
-def _normalize_image_format(image_format: str) -> str:
- """
- 标准化图片格式名称,确保与各种API的兼容性
-
- Args:
- image_format (str): 原始图片格式
-
- Returns:
- str: 标准化后的图片格式
- """
- format_mapping = {
- "jpg": "jpeg",
- "JPG": "jpeg",
- "JPEG": "jpeg",
- "jpeg": "jpeg",
- "png": "png",
- "PNG": "png",
- "webp": "webp",
- "WEBP": "webp",
- "gif": "gif",
- "GIF": "gif",
- "heic": "heic",
- "HEIC": "heic",
- "heif": "heif",
- "HEIF": "heif",
- }
-
- normalized = format_mapping.get(image_format, image_format.lower())
- logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
- return normalized
-
-
-class RequestType(Enum):
- """请求类型枚举"""
-
- RESPONSE = "response"
- EMBEDDING = "embedding"
- AUDIO = "audio"
-
-
-async def execute_concurrently(
- coro_callable: Callable[..., Coroutine[Any, Any, Any]],
- concurrency_count: int,
- *args,
- **kwargs,
-) -> Any:
- """
- 执行并发请求并从成功的结果中随机选择一个。
-
- Args:
- coro_callable (Callable): 要并发执行的协程函数。
- concurrency_count (int): 并发执行的次数。
- *args: 传递给协程函数的位置参数。
- **kwargs: 传递给协程函数的关键字参数。
-
- Returns:
- Any: 其中一个成功执行的结果。
-
- Raises:
- RuntimeError: 如果所有并发请求都失败。
- """
- logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
- tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
-
- results = await asyncio.gather(*tasks, return_exceptions=True)
- successful_results = [res for res in results if not isinstance(res, Exception)]
-
- if successful_results:
- selected = random.choice(successful_results)
- logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
- return selected
-
- # 如果所有请求都失败了,记录所有异常并抛出第一个
- for i, res in enumerate(results):
- if isinstance(res, Exception):
- logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")
-
- first_exception = next((res for res in results if isinstance(res, Exception)), None)
- if first_exception:
- raise first_exception
-
- raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
-
-
class LLMRequest:
- """LLM请求类"""
+ """LLM请求协调器"""
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
- self.model_usage: Dict[str, Tuple[int, int, int]] = {
- model: (0, 0, 0) for model in self.model_for_task.model_list
- }
- """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
-
- # 内容混淆过滤指令
- self.noise_instruction = """**【核心过滤规则】**
-在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。
-
-你的任务是【完全并彻底地忽略】这些随机字符串。
-**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
-"""
-
- # 反截断指令
- self.end_marker = "###MAI_RESPONSE_END###"
- self.anti_truncation_instruction = f"""
-**【输出完成信令】**
-这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。
-例如:
-<你的回复内容>
-{self.end_marker}
-
-这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
-"""
+ self.model_selector = ModelSelector(model_set, request_type)
+ self.prompt_processor = PromptProcessor()
+ self.request_strategy = RequestStrategy(model_set, self.model_selector, request_type)
async def generate_response_for_image(
self,
@@ -158,25 +40,18 @@ class LLMRequest:
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
- """
- 为图像生成响应
- Args:
- prompt (str): 提示词
- image_base64 (str): 图像的Base64编码字符串
- image_format (str): 图像格式(如 'png', 'jpeg' 等)
- Returns:
- (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
- """
- # 标准化图片格式以确保API兼容性
- normalized_format = _normalize_image_format(image_format)
-
- # 模型选择
+ """为图像生成响应"""
start_time = time.time()
- model_info, api_provider, client = self._select_model()
-
- # 请求体构建
+
+ # 1. 选择模型
+ model_info, api_provider, client = self.model_selector.select_model()
+
+ # 2. 准备消息体
+ processed_prompt = self.prompt_processor.process_prompt(prompt, model_info, api_provider, self.task_name)
+ normalized_format = normalize_image_format(image_format)
+
message_builder = MessageBuilder()
- message_builder.add_text_content(prompt)
+ message_builder.add_text_content(processed_prompt)
message_builder.add_image_content(
image_base64=image_base64,
image_format=normalized_format,
@@ -184,51 +59,47 @@ class LLMRequest:
)
messages = [message_builder.build()]
- # 请求并处理返回值
- response = await self._execute_request(
+ # 3. 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行)
+ from .request_executor import RequestExecutor
+ executor = RequestExecutor(
+ task_name=self.task_name,
+ model_set=self.model_for_task,
api_provider=api_provider,
client=client,
- request_type=RequestType.RESPONSE,
model_info=model_info,
+ model_selector=self.model_selector,
+ )
+ response = await executor.execute_request(
+ request_type="response",
message_list=messages,
temperature=temperature,
max_tokens=max_tokens,
)
- content = response.content or ""
- reasoning_content = response.reasoning_content or ""
+
+ # 4. 处理响应
+ content, reasoning_content = self.prompt_processor.extract_reasoning(response.content or "")
tool_calls = response.tool_calls
-        # 从内容中提取<think>标签的推理内容(向后兼容)
- if not reasoning_content and content:
- content, extracted_reasoning = self._extract_reasoning(content)
- reasoning_content = extracted_reasoning
+
if usage := response.usage:
- await llm_usage_recorder.record_usage_to_database(
- model_info=model_info,
- model_usage=usage,
- user_id="system",
- time_cost=time.time() - start_time,
- request_type=self.request_type,
- endpoint="/chat/completions",
- )
+ await self._record_usage(model_info, usage, time.time() - start_time)
+
return content, (reasoning_content, model_info.name, tool_calls)
async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
- """
- 为语音生成响应
- Args:
- voice_base64 (str): 语音的Base64编码字符串
- Returns:
- (Optional[str]): 生成的文本描述或None
- """
- # 模型选择
- model_info, api_provider, client = self._select_model()
-
- # 请求并处理返回值
- response = await self._execute_request(
+ """为语音生成响应"""
+ model_info, api_provider, client = self.model_selector.select_model()
+
+ from .request_executor import RequestExecutor
+ executor = RequestExecutor(
+ task_name=self.task_name,
+ model_set=self.model_for_task,
api_provider=api_provider,
client=client,
- request_type=RequestType.AUDIO,
model_info=model_info,
+ model_selector=self.model_selector,
+ )
+ response = await executor.execute_request(
+ request_type="audio",
audio_base64=voice_base64,
)
return response.content or None
@@ -241,680 +112,78 @@ class LLMRequest:
tools: Optional[List[Dict[str, Any]]] = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
- """
- 异步生成响应,支持并发请求
- Args:
- prompt (str): 提示词
- temperature (float, optional): 温度参数
- max_tokens (int, optional): 最大token数
- tools: 工具配置
- raise_when_empty: 是否在空回复时抛出异常
- Returns:
- (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
- """
- # 检查是否需要并发请求
+ """异步生成响应,支持并发和故障转移"""
+
+ # 1. 准备基础请求载荷
+ tool_built = build_tool_options(tools)
+ base_payload = {
+ "prompt": prompt,
+ "tool_options": tool_built,
+ "temperature": temperature,
+ "max_tokens": max_tokens,
+ "prompt_processor": self.prompt_processor,
+ }
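+        # 上述键由 RequestStrategy.execute_with_fallback 读取,其中 prompt_processor 用于反截断与内容混淆处理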
+
+ # 2. 根据配置选择执行策略
concurrency_count = getattr(self.model_for_task, "concurrency_count", 1)
-
+
if concurrency_count <= 1:
- # 单次请求
- return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty)
-
- # 并发请求
- try:
- # 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False,
- # 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理
- content, (reasoning_content, model_name, tool_calls) = await execute_concurrently(
- self._execute_single_request,
+ # 单次请求,但使用带故障转移的策略
+ result = await self.request_strategy.execute_with_fallback(
+ base_payload, raise_when_empty
+ )
+ else:
+ # 并发请求策略
+ result = await self.request_strategy.execute_concurrently(
+ self.request_strategy.execute_with_fallback,
concurrency_count,
- prompt,
- temperature,
- max_tokens,
- tools,
+ base_payload,
raise_when_empty=False,
)
- return content, (reasoning_content, model_name, tool_calls)
- except Exception as e:
- logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
- if raise_when_empty:
- raise e
- return "所有并发请求都失败了", ("", "unknown", None)
-
- async def _execute_single_request(
- self,
- prompt: str,
- temperature: Optional[float] = None,
- max_tokens: Optional[int] = None,
- tools: Optional[List[Dict[str, Any]]] = None,
- raise_when_empty: bool = True,
- ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
- """
- 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
- """
- failed_models_in_this_request = set()
- # 迭代次数等于模型总数,以确保每个模型在当前请求中最多只尝试一次
- max_attempts = len(self.model_for_task.model_list)
- last_exception: Optional[Exception] = None
-
- for attempt in range(max_attempts):
- # 根据负载均衡和当前故障选择最佳可用模型
- model_selection_result = self._select_best_available_model(failed_models_in_this_request)
-
- if model_selection_result is None:
- logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
- break # 没有更多模型可供尝试
-
- model_info, api_provider, client = model_selection_result
- model_name = model_info.name
- logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...")
-
- start_time = time.time()
-
- try:
- # --- 为当前模型尝试进行设置 ---
- # 检查是否为该模型启用反截断
- use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
- processed_prompt = prompt
- if use_anti_truncation:
- processed_prompt += self.anti_truncation_instruction
- logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。")
-
- processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
-
- message_builder = MessageBuilder()
- message_builder.add_text_content(processed_prompt)
- messages = [message_builder.build()]
- tool_built = self._build_tool_options(tools)
-
- # --- 当前选定模型内的空回复/截断重试逻辑 ---
- empty_retry_count = 0
- max_empty_retry = api_provider.max_retry
- empty_retry_interval = api_provider.retry_interval
-
- while empty_retry_count <= max_empty_retry:
- response = await self._execute_request(
- api_provider=api_provider,
- client=client,
- request_type=RequestType.RESPONSE,
- model_info=model_info,
- message_list=messages,
- tool_options=tool_built,
- temperature=temperature,
- max_tokens=max_tokens,
- )
-
- content = response.content or ""
- reasoning_content = response.reasoning_content or ""
- tool_calls = response.tool_calls
-
-                    # 向后兼容<think>标签(如果 reasoning_content 为空)
- if not reasoning_content and content:
- content, extracted_reasoning = self._extract_reasoning(content)
- reasoning_content = extracted_reasoning
-
- is_empty_reply = not tool_calls and (not content or content.strip() == "")
- is_truncated = False
- if use_anti_truncation:
- if content.endswith(self.end_marker):
- content = content[: -len(self.end_marker)].strip()
- else:
- is_truncated = True
-
- if is_empty_reply or is_truncated:
- empty_retry_count += 1
- if empty_retry_count <= max_empty_retry:
- reason = "空回复" if is_empty_reply else "截断"
- logger.warning(
- f"模型 '{model_name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..."
- )
- if empty_retry_interval > 0:
- await asyncio.sleep(empty_retry_interval)
- continue # 使用当前模型重试
- else:
- reason = "空回复" if is_empty_reply else "截断"
- logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。将此模型标记为当前请求失败。")
- raise RuntimeError(f"模型 '{model_name}' 已达到空回复/截断的最大内部重试次数。")
-
- # --- 从当前模型获取成功响应 ---
- if usage := response.usage:
- await llm_usage_recorder.record_usage_to_database(
- model_info=model_info,
- model_usage=usage,
- time_cost=time.time() - start_time,
- user_id="system",
- request_type=self.request_type,
- endpoint="/chat/completions",
- )
-
- # 处理成功执行后响应仍然为空的情况
- if not content and not tool_calls:
- if raise_when_empty:
- raise RuntimeError("所选模型生成了空回复。")
- content = "生成的响应为空" # Fallback message
-
- logger.debug(f"模型 '{model_name}' 成功生成了回复。")
- return content, (reasoning_content, model_name, tool_calls) # 成功,立即返回
-
- # --- 当前模型尝试过程中的异常处理 ---
- except Exception as e: # 捕获当前模型尝试过程中的所有异常
- # 修复 NameError: model_name 在异常处理块中未定义,应使用 model_info.name
- logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
- failed_models_in_this_request.add(model_info.name)
- last_exception = e # 存储异常以供最终报告
- # 继续循环以尝试下一个可用模型
-
- # 如果循环结束未能返回,则表示当前请求的所有模型都已失败
- logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
- if raise_when_empty:
- if last_exception:
- raise RuntimeError("所有模型均未能生成响应。") from last_exception
- raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
- return "所有模型都请求失败", ("", "unknown", None)
+
+ # 3. 处理最终结果
+        content = result["content"]
+        reasoning_content = result["reasoning_content"]
+        model_name = result["model_name"]
+        tool_calls = result["tool_calls"]
+
+        # 4. 记录用量 (result 中已包含最终使用的 model_info 与 usage)
+        # TODO: 基于 result["model_info"] 与 result["usage"] 调用 _record_usage, 此处暂时省略
+
+ return content, (reasoning_content, model_name, tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
- """获取嵌入向量
- Args:
- embedding_input (str): 获取嵌入的目标
- Returns:
- (Tuple[List[float], str]): (嵌入向量,使用的模型名称)
- """
- # 无需构建消息体,直接使用输入文本
+ """获取嵌入向量"""
start_time = time.time()
- model_info, api_provider, client = self._select_model()
-
- # 请求并处理返回值
- response = await self._execute_request(
+ model_info, api_provider, client = self.model_selector.select_model()
+
+ from .request_executor import RequestExecutor
+ executor = RequestExecutor(
+ task_name=self.task_name,
+ model_set=self.model_for_task,
api_provider=api_provider,
client=client,
- request_type=RequestType.EMBEDDING,
model_info=model_info,
+ model_selector=self.model_selector,
+ )
+ response = await executor.execute_request(
+ request_type="embedding",
embedding_input=embedding_input,
)
-
+
embedding = response.embedding
-
- if usage := response.usage:
- await llm_usage_recorder.record_usage_to_database(
- model_info=model_info,
- time_cost=time.time() - start_time,
- model_usage=usage,
- user_id="system",
- request_type=self.request_type,
- endpoint="/embeddings",
- )
-
if not embedding:
raise RuntimeError("获取embedding失败")
-
+
+ if usage := response.usage:
+ await self._record_usage(model_info, usage, time.time() - start_time, "/embeddings")
+
return embedding, model_info.name
- def _select_best_available_model(self, failed_models_in_this_request: set) -> Tuple[ModelInfo, APIProvider, BaseClient] | None:
- """
- 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。
-
- 参数:
- failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。
-
- 返回:
- Tuple[ModelInfo, APIProvider, BaseClient] | None: 选定的模型详细信息,如果无可用模型则返回 None。
- """
- candidate_models_usage = {}
- # 过滤掉当前请求中已失败的模型
- for model_name, usage_data in self.model_usage.items():
- if model_name not in failed_models_in_this_request:
- candidate_models_usage[model_name] = usage_data
-
- if not candidate_models_usage:
- logger.warning("没有可用的模型供当前请求选择。")
- return None
-
- # 根据现有公式查找分数最低的模型,该公式综合了总token数、模型惩罚值和使用频率惩罚值。
- # 公式: total_tokens + penalty * 300 + usage_penalty * 1000
- # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。
- least_used_model_name = min(
- candidate_models_usage,
- key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
+ async def _record_usage(self, model_info: ModelInfo, usage, time_cost, endpoint="/chat/completions"):
+ """记录模型用量"""
+ await llm_usage_recorder.record_usage_to_database(
+ model_info=model_info,
+ model_usage=usage,
+ user_id="system",
+ time_cost=time_cost,
+ request_type=self.request_type,
+ endpoint=endpoint,
)
-
- # --- 动态故障转移的核心逻辑 ---
- # _execute_single_request 中的循环会多次调用此函数。
- # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数,
- # 此时由于失败模型已被标记,且其惩罚值可能已在 _execute_request 中增加,
- # _select_best_available_model 会自动选择一个得分更低(即更可用)的模型。
- # 这种机制实现了动态的、基于当前系统状态的故障转移。
-
- model_info = model_config.get_model_info(least_used_model_name)
- api_provider = model_config.get_provider(model_info.api_provider)
-
- # 对于嵌入任务,如果需要,强制创建新的客户端实例(从原始 _select_model 复制)
- force_new_client = self.request_type == "embedding"
- client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-
- logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
-
- # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。
- # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
- self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
-
- return model_info, api_provider, client
-
- def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
- """
- 一个模型调度器,按顺序提供模型,并跳过已失败的模型。
- """
- for model_name in self.model_for_task.model_list:
- if model_name in failed_models:
- continue
-
- model_info = model_config.get_model_info(model_name)
- api_provider = model_config.get_provider(model_info.api_provider)
- force_new_client = self.request_type == "embedding"
- client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-
- yield model_info, api_provider, client
-
- def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
- """
- 根据总tokens和惩罚值选择的模型 (负载均衡)
- """
- least_used_model_name = min(
- self.model_usage,
- key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
- )
- model_info = model_config.get_model_info(least_used_model_name)
- api_provider = model_config.get_provider(model_info.api_provider)
-
- # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
- force_new_client = self.request_type == "embedding"
- client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
- logger.debug(f"选择请求模型: {model_info.name}")
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
- self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
- return model_info, api_provider, client
-
- async def _execute_request(
- self,
- api_provider: APIProvider,
- client: BaseClient,
- request_type: RequestType,
- model_info: ModelInfo,
- message_list: List[Message] | None = None,
- tool_options: list[ToolOption] | None = None,
- response_format: RespFormat | None = None,
- stream_response_handler: Optional[Callable] = None,
- async_response_parser: Optional[Callable] = None,
- temperature: Optional[float] = None,
- max_tokens: Optional[int] = None,
- embedding_input: str = "",
- audio_base64: str = "",
- ) -> APIResponse:
- """
- 实际执行请求的方法
-
- 包含了重试和异常处理逻辑
- """
- retry_remain = api_provider.max_retry
- compressed_messages: Optional[List[Message]] = None
- while retry_remain > 0:
- try:
- if request_type == RequestType.RESPONSE:
- assert message_list is not None, "message_list cannot be None for response requests"
- return await client.get_response(
- model_info=model_info,
- message_list=(compressed_messages or message_list),
- tool_options=tool_options,
- max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
- temperature=self.model_for_task.temperature if temperature is None else temperature,
- response_format=response_format,
- stream_response_handler=stream_response_handler,
- async_response_parser=async_response_parser,
- extra_params=model_info.extra_params,
- )
- elif request_type == RequestType.EMBEDDING:
- assert embedding_input, "embedding_input cannot be empty for embedding requests"
- return await client.get_embedding(
- model_info=model_info,
- embedding_input=embedding_input,
- extra_params=model_info.extra_params,
- )
- elif request_type == RequestType.AUDIO:
- assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
- return await client.get_audio_transcriptions(
- model_info=model_info,
- audio_base64=audio_base64,
- extra_params=model_info.extra_params,
- )
- except Exception as e:
- logger.debug(f"请求失败: {str(e)}")
- # 处理异常
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-
- # --- 增强动态故障转移的智能性 ---
- # 根据异常类型和严重程度,动态调整模型的惩罚值。
- # 关键错误(如网络连接、服务器错误)会获得更高的惩罚,
- # 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。
- CRITICAL_PENALTY_MULTIPLIER = 5 # 关键错误时的惩罚系数
- default_penalty_increment = 1 # 普通错误时的基础惩罚
-
- penalty_increment = default_penalty_increment
-
- if isinstance(e, NetworkConnectionError):
- # 网络连接问题表明模型服务器不稳定,增加较高惩罚
- penalty_increment = CRITICAL_PENALTY_MULTIPLIER
- # 修复 NameError: model_name 在此处未定义,应使用 model_info.name
- logger.warning(f"模型 '{model_info.name}' 发生网络连接错误,增加惩罚值: {penalty_increment}")
- elif isinstance(e, ReqAbortException):
- # 请求被中止,可能是服务器端原因或服务不稳定,增加较高惩罚
- penalty_increment = CRITICAL_PENALTY_MULTIPLIER
- # 修复 NameError: model_name 在此处未定义,应使用 model_info.name
- logger.warning(f"模型 '{model_info.name}' 请求被中止,增加惩罚值: {penalty_increment}")
- elif isinstance(e, RespNotOkException):
- if e.status_code >= 500:
- # 服务器错误 (5xx) 表明服务器端问题,应显著增加惩罚
- penalty_increment = CRITICAL_PENALTY_MULTIPLIER
- logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}")
- elif e.status_code == 429:
- # 请求过于频繁,是暂时性问题,但仍需惩罚,此处使用默认基础值
- # penalty_increment = 2 # 可以选择一个中间值,例如2,表示比普通错误重,但比关键错误轻
- logger.warning(f"模型 '{model_name}' 请求过于频繁 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
- else:
- # 其他客户端错误 (4xx)。通常不重试,_handle_resp_not_ok 会处理。
- # 如果 _handle_resp_not_ok 返回 retry_interval, 则进入这里的 exception 块。
- logger.warning(f"模型 '{model_name}' 发生非致命的响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
- else:
- # 其他未捕获的异常,增加基础惩罚
- logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")
-
- self.model_usage[model_info.name] = (total_tokens, penalty + penalty_increment, usage_penalty)
- # --- 结束增强 ---
- # 移除冗余的、错误的惩罚值更新行,保留上面正确的动态惩罚更新
- # self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
-
- wait_interval, compressed_messages = self._default_exception_handler(
- e,
- self.task_name,
- model_info=model_info,
- api_provider=api_provider,
- remain_try=retry_remain,
- retry_interval=api_provider.retry_interval,
- messages=(message_list, compressed_messages is not None) if message_list else None,
- )
-
- if wait_interval == -1:
- retry_remain = 0 # 不再重试
- elif wait_interval > 0:
- logger.info(f"等待 {wait_interval} 秒后重试...")
- await asyncio.sleep(wait_interval)
- finally:
- # 放在finally防止死循环
- retry_remain -= 1
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
- self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值
- logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
- raise RuntimeError("请求失败,已达到最大重试次数")
-
- def _default_exception_handler(
- self,
- e: Exception,
- task_name: str,
- model_info: ModelInfo,
- api_provider: APIProvider,
- remain_try: int,
- retry_interval: int = 10,
- messages: Tuple[List[Message], bool] | None = None,
- ) -> Tuple[int, List[Message] | None]:
- """
- 默认异常处理函数
- Args:
- e (Exception): 异常对象
- task_name (str): 任务名称
- model_info (ModelInfo): 模型信息
- api_provider (APIProvider): API提供商
- remain_try (int): 剩余尝试次数
- retry_interval (int): 重试间隔
- messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
- Returns:
- (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
- """
- model_name = model_info.name if model_info else "unknown"
-
- if isinstance(e, NetworkConnectionError): # 网络连接错误
- return self._check_retry(
- remain_try,
- retry_interval,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确",
- )
- elif isinstance(e, ReqAbortException):
- logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
- return -1, None # 不再重试请求该模型
- elif isinstance(e, RespNotOkException):
- return self._handle_resp_not_ok(
- e,
- task_name,
- model_info,
- api_provider,
- remain_try,
- retry_interval,
- messages,
- )
- elif isinstance(e, RespParseException):
- # 响应解析错误
- logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
- logger.debug(f"附加内容: {str(e.ext_info)}")
- return -1, None # 不再重试请求该模型
- else:
- logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
- return -1, None # 不再重试请求该模型
-
- @staticmethod
- def _check_retry(
- remain_try: int,
- retry_interval: int,
- can_retry_msg: str,
- cannot_retry_msg: str,
- can_retry_callable: Callable | None = None,
- **kwargs,
- ) -> Tuple[int, List[Message] | None]:
- """辅助函数:检查是否可以重试
- Args:
- remain_try (int): 剩余尝试次数
- retry_interval (int): 重试间隔
- can_retry_msg (str): 可以重试时的提示信息
- cannot_retry_msg (str): 不可以重试时的提示信息
- can_retry_callable (Callable | None): 可以重试时调用的函数(如果有)
- **kwargs: 其他参数
-
- Returns:
- (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
- """
- if remain_try > 0:
- # 还有重试机会
- logger.warning(f"{can_retry_msg}")
- if can_retry_callable is not None:
- return retry_interval, can_retry_callable(**kwargs)
- else:
- return retry_interval, None
- else:
- # 达到最大重试次数
- logger.warning(f"{cannot_retry_msg}")
- return -1, None # 不再重试请求该模型
-
- def _handle_resp_not_ok(
- self,
- e: RespNotOkException,
- task_name: str,
- model_info: ModelInfo,
- api_provider: APIProvider,
- remain_try: int,
- retry_interval: int = 10,
- messages: tuple[list[Message], bool] | None = None,
- ):
- model_name = model_info.name
- """
- 处理响应错误异常
- Args:
- e (RespNotOkException): 响应错误异常对象
- task_name (str): 任务名称
- model_info (ModelInfo): 模型信息
- api_provider (APIProvider): API提供商
- remain_try (int): 剩余尝试次数
- retry_interval (int): 重试间隔
- messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
- Returns:
- (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
- """
- # 响应错误
- if e.status_code in [400, 401, 402, 403, 404]:
- model_name = model_info.name
- # 客户端错误
- logger.warning(
- f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
- )
- return -1, None # 不再重试请求该模型
- elif e.status_code == 413:
- if messages and not messages[1]:
- # 消息列表不为空且未压缩,尝试压缩消息
- return self._check_retry(
- remain_try,
- 0,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
- can_retry_callable=compress_messages,
- messages=messages[0],
- )
- # 没有消息可压缩
- logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
- return -1, None
- elif e.status_code == 429:
- # 请求过于频繁
- return self._check_retry(
- remain_try,
- retry_interval,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
- )
- elif e.status_code >= 500:
- # 服务器错误
- return self._check_retry(
- remain_try,
- retry_interval,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
- )
- else:
- # 未知错误
- logger.warning(
- f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
- )
- return -1, None
-
- @staticmethod
- def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
- # sourcery skip: extract-method
- """构建工具选项列表"""
- if not tools:
- return None
- tool_options: List[ToolOption] = []
- for tool in tools:
- tool_legal = True
- tool_options_builder = ToolOptionBuilder()
- tool_options_builder.set_name(tool.get("name", ""))
- tool_options_builder.set_description(tool.get("description", ""))
- parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", [])
- for param in parameters:
- try:
- assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
- assert isinstance(param[0], str), "参数名称必须是字符串"
- assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举"
- assert isinstance(param[2], str), "参数描述必须是字符串"
- assert isinstance(param[3], bool), "参数是否必填必须是布尔值"
- assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None"
- tool_options_builder.add_param(
- name=param[0],
- param_type=param[1],
- description=param[2],
- required=param[3],
- enum_values=param[4],
- )
- except AssertionError as ae:
- tool_legal = False
- logger.error(f"{param[0]} 参数定义错误: {str(ae)}")
- except Exception as e:
- tool_legal = False
- logger.error(f"构建工具参数失败: {str(e)}")
- if tool_legal:
- tool_options.append(tool_options_builder.build())
- return tool_options or None
-
- @staticmethod
- def _extract_reasoning(content: str) -> Tuple[str, str]:
- """CoT思维链提取,向后兼容"""
- match = re.search(r"(?:)?(.*?)", content, re.DOTALL)
- content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip()
- reasoning = match[1].strip() if match else ""
- return content, reasoning
-
- def _apply_content_obfuscation(self, text: str, api_provider) -> str:
- """根据API提供商配置对文本进行混淆处理"""
- if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation:
- logger.debug(f"API提供商 '{api_provider.name}' 未启用内容混淆")
- return text
-
- intensity = getattr(api_provider, "obfuscation_intensity", 1)
- logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
-
- # 在开头加入过滤规则指令
- processed_text = self.noise_instruction + "\n\n" + text
- logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}")
-
- # 添加随机乱码
- final_text = self._inject_random_noise(processed_text, intensity)
- logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}")
-
- return final_text
-
- @staticmethod
- def _inject_random_noise(text: str, intensity: int) -> str:
- """在文本中注入随机乱码"""
- import random
- import string
-
- def generate_noise(length: int) -> str:
- """生成指定长度的随机乱码字符"""
- chars = (
- string.ascii_letters # a-z, A-Z
- + string.digits # 0-9
- + "!@#$%^&*()_+-=[]{}|;:,.<>?" # 特殊符号
- + "一二三四五六七八九零壹贰叁" # 中文字符
- + "αβγδεζηθικλμνξοπρστυφχψω" # 希腊字母
- + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" # 数学符号
- )
- return "".join(random.choice(chars) for _ in range(length))
-
- # 强度参数映射
- params = {
- 1: {"probability": 15, "length": (3, 6)}, # 低强度:15%概率,3-6个字符
- 2: {"probability": 25, "length": (5, 10)}, # 中强度:25%概率,5-10个字符
- 3: {"probability": 35, "length": (8, 15)}, # 高强度:35%概率,8-15个字符
- }
-
- config = params.get(intensity, params[1])
- logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}")
-
- # 按词分割处理
- words = text.split()
- result = []
- noise_count = 0
-
- for word in words:
- result.append(word)
- # 根据概率插入乱码
- if random.randint(1, 100) <= config["probability"]:
- noise_length = random.randint(*config["length"])
- noise = generate_noise(noise_length)
- result.append(noise)
- noise_count += 1
-
- logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}")
- return " ".join(result)