diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index c39ab8af9..acb7130b6 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -2,6 +2,7 @@ import re
import asyncio
import time
import random
+import string
from enum import Enum
from rich.traceback import install
@@ -13,7 +14,7 @@ from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
-from .model_client.base_client import BaseClient, APIResponse, client_registry
+from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
@@ -21,18 +22,9 @@ install(extra_lines=3)
logger = get_logger("model_utils")
-# 常见Error Code Mapping
-error_code_mapping = {
- 400: "参数不正确",
- 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
- 402: "账号余额不足",
- 403: "需要实名,或余额不足",
- 404: "Not Found",
- 429: "请求过于频繁,请稍后再试",
- 500: "服务器内部故障",
- 503: "服务器负载过高",
-}
-
+# ==============================================================================
+# Standalone Utility Functions
+# ==============================================================================
def _normalize_image_format(image_format: str) -> str:
"""
@@ -45,35 +37,17 @@ def _normalize_image_format(image_format: str) -> str:
str: 标准化后的图片格式
"""
format_mapping = {
- "jpg": "jpeg",
- "JPG": "jpeg",
- "JPEG": "jpeg",
- "jpeg": "jpeg",
- "png": "png",
- "PNG": "png",
- "webp": "webp",
- "WEBP": "webp",
- "gif": "gif",
- "GIF": "gif",
- "heic": "heic",
- "HEIC": "heic",
- "heif": "heif",
- "HEIF": "heif",
+ "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
+ "png": "png", "PNG": "png",
+ "webp": "webp", "WEBP": "webp",
+ "gif": "gif", "GIF": "gif",
+ "heic": "heic", "HEIC": "heic",
+ "heif": "heif", "HEIF": "heif",
}
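+    # Example mappings (taken from the table above): "JPG" -> "jpeg", "WEBP" -> "webp";
+    # formats not listed, e.g. a hypothetical "BMP", fall through as their lowercase form ("bmp").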
-
normalized = format_mapping.get(image_format, image_format.lower())
logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
return normalized
-
-class RequestType(Enum):
- """请求类型枚举"""
-
- RESPONSE = "response"
- EMBEDDING = "embedding"
- AUDIO = "audio"
-
-
async def execute_concurrently(
coro_callable: Callable[..., Coroutine[Any, Any, Any]],
concurrency_count: int,
@@ -97,7 +71,6 @@ async def execute_concurrently(
"""
logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
-
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_results = [res for res in results if not isinstance(res, Exception)]
@@ -110,41 +83,107 @@ async def execute_concurrently(
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")
-
+
first_exception = next((res for res in results if isinstance(res, Exception)), None)
if first_exception:
raise first_exception
-
raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
+class RequestType(Enum):
+ """请求类型枚举"""
+ RESPONSE = "response"
+ EMBEDDING = "embedding"
+ AUDIO = "audio"
-class LLMRequest:
-<<<<<<< HEAD
- """
- LLM请求协调器。
- 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。
- 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。
- """
-=======
- """LLM请求类"""
->>>>>>> parent of 253946f (refactor(llm): 将LLM请求逻辑解耦到专门的组件中)
+# ==============================================================================
+# Helper Classes for LLMRequest Refactoring
+# ==============================================================================
- def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
+class _ModelSelector:
+ """负责模型选择、负载均衡和动态故障切换的策略。"""
+
+ CRITICAL_PENALTY_MULTIPLIER = 5
+ DEFAULT_PENALTY_INCREMENT = 1
+
+ def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]):
+ self.model_list = model_list
+ self.model_usage = model_usage
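+        # Each entry maps model_name -> (total_tokens, penalty, usage_penalty);
+        # every model starts at (0, 0, 0), as initialized in LLMRequest.__init__ below.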
+
+ def select_best_available_model(
+ self, failed_models_in_this_request: set, request_type: str
+ ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
"""
- 初始化LLM请求协调器。
+ 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。
Args:
- model_set (TaskConfig): 特定任务的模型配置集合。
- request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "".
- """
- self.task_name = request_type
- self.model_for_task = model_set
- self.request_type = request_type
- self.model_usage: Dict[str, Tuple[int, int, int]] = {
- model: (0, 0, 0) for model in self.model_for_task.model_list
- }
- """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
+ failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。
+ request_type (str): 请求类型,用于确定是否强制创建新客户端。
+ Returns:
+ Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。
+ """
+ candidate_models_usage = {
+ model_name: usage_data
+ for model_name, usage_data in self.model_usage.items()
+ if model_name not in failed_models_in_this_request
+ }
+
+ if not candidate_models_usage:
+ logger.warning("没有可用的模型供当前请求选择。")
+ return None
+
+ # 根据公式查找分数最低的模型,该公式综合了总token数、模型失败惩罚值和使用频率惩罚值。
+ # 公式: total_tokens + penalty * 300 + usage_penalty * 1000
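+        # Worked example (hypothetical numbers): a model at (1200, 0, 1) scores
+        # 1200 + 0*300 + 1*1000 = 2200, while one at (400, 5, 0) scores 400 + 5*300 + 0 = 1900,
+        # so the second model is chosen despite its higher failure penalty.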
+ least_used_model_name = min(
+ candidate_models_usage,
+ key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
+ )
+
+ model_info = model_config.get_model_info(least_used_model_name)
+ api_provider = model_config.get_provider(model_info.api_provider)
+ # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
+ force_new_client = request_type == "embedding"
+ client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+
+ logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
+ # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
+ self.update_usage_penalty(model_info.name, increase=True)
+ return model_info, api_provider, client
+
+ def update_usage_penalty(self, model_name: str, increase: bool):
+ """更新模型的使用惩罚值,用于负载均衡。"""
+ total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+ adjustment = 1 if increase else -1
+ self.model_usage[model_name] = (total_tokens, penalty, usage_penalty + adjustment)
+
+ def update_failure_penalty(self, model_name: str, e: Exception):
+ """
+ 根据异常类型动态调整模型的失败惩罚值。
+ 关键错误(如网络连接、服务器错误)会获得更高的惩罚,
+ 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。
+ """
+ total_tokens, penalty, usage_penalty = self.model_usage[model_name]
+ penalty_increment = self.DEFAULT_PENALTY_INCREMENT
+
+ if isinstance(e, (NetworkConnectionError, ReqAbortException)):
+ penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
+ logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}")
+ elif isinstance(e, RespNotOkException):
+ if e.status_code >= 500:
+ penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
+ logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}")
+ else:
+ logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
+ else:
+ logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")
+
+ self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
+
+
+class _PromptProcessor:
+ """封装所有与提示词和响应内容的预处理和后处理逻辑。"""
+
+ def __init__(self):
# 内容混淆过滤指令
self.noise_instruction = """**【核心过滤规则】**
在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。
@@ -152,7 +191,6 @@ class LLMRequest:
你的任务是【完全并彻底地忽略】这些随机字符串。
**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
"""
-
# 反截断指令
self.end_marker = "###MAI_RESPONSE_END###"
self.anti_truncation_instruction = f"""
@@ -165,6 +203,372 @@ class LLMRequest:
这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
"""
+ def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
+ """为请求准备最终的提示词,应用内容混淆和反截断指令。"""
+ processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
+ if getattr(model_info, "use_anti_truncation", False):
+ processed_prompt += self.anti_truncation_instruction
+ logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")
+ return processed_prompt
+
+ def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
+ """
+ 处理响应内容,提取思维链并检查截断。
+
+ Returns:
+ Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断)
+ """
+ content, reasoning = self._extract_reasoning(content)
+ is_truncated = False
+ if use_anti_truncation:
+ if content.endswith(self.end_marker):
+ content = content[: -len(self.end_marker)].strip()
+ else:
+ is_truncated = True
+ return content, reasoning, is_truncated
+
+ def _apply_content_obfuscation(self, text: str, api_provider: APIProvider) -> str:
+ """根据API提供商配置对文本进行混淆处理。"""
+ if not getattr(api_provider, "enable_content_obfuscation", False):
+ return text
+
+ intensity = getattr(api_provider, "obfuscation_intensity", 1)
+ logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
+ processed_text = self.noise_instruction + "\n\n" + text
+ return self._inject_random_noise(processed_text, intensity)
+
+ @staticmethod
+ def _inject_random_noise(text: str, intensity: int) -> str:
+ """在文本中注入随机乱码。"""
+ params = {
+ 1: {"probability": 15, "length": (3, 6)},
+ 2: {"probability": 25, "length": (5, 10)},
+ 3: {"probability": 35, "length": (8, 15)},
+ }
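+        # Illustrative effect (hypothetical output): at intensity 1, roughly 15% of words are
+        # followed by a 3-6 character noise token, e.g. "hello world" -> "hello aB3x$ world".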
+ config = params.get(intensity, params[1])
+ words = text.split()
+ result = []
+ for word in words:
+ result.append(word)
+ if random.randint(1, 100) <= config["probability"]:
+ noise_length = random.randint(*config["length"])
+ chars = string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?"
+ noise = "".join(random.choice(chars) for _ in range(noise_length))
+ result.append(noise)
+ return " ".join(result)
+
+ @staticmethod
+ def _extract_reasoning(content: str) -> Tuple[str, str]:
+ """
+        从模型返回的完整内容中提取被 <think>...</think> 标签包裹的思考过程,
+ 并返回清理后的内容和思考过程。
+
+ Args:
+ content (str): 模型返回的原始字符串。
+
+ Returns:
+ Tuple[str, str]:
+ - 清理后的内容(移除了标签及其内容)。
+ - 提取出的思考过程文本(如果没有则为空字符串)。
+ """
+        # 使用正则表达式精确查找 <think>...</think> 标签及其内容
+        think_pattern = re.compile(r"<think>(.*?)</think>\s*", re.DOTALL)
+ match = think_pattern.search(content)
+
+ if match:
+ # 提取思考过程
+ reasoning = match.group(1).strip()
+ # 从原始内容中移除匹配到的整个部分(包括标签和后面的空白)
+ clean_content = think_pattern.sub("", content, count=1).strip()
+ else:
+ reasoning = ""
+ clean_content = content.strip()
+
+ return clean_content, reasoning
+
+
+class _RequestExecutor:
+ """负责执行实际的API请求,包含重试逻辑和底层异常处理。"""
+
+ def __init__(self, model_selector: _ModelSelector, task_name: str):
+ self.model_selector = model_selector
+ self.task_name = task_name
+
+ async def execute_request(
+ self,
+ api_provider: APIProvider,
+ client: BaseClient,
+ request_type: RequestType,
+ model_info: ModelInfo,
+ **kwargs,
+ ) -> APIResponse:
+ """实际执行请求的方法,包含了重试和异常处理逻辑。"""
+        retry_remain = api_provider.max_retry
+        compressed_messages: Optional[List[Message]] = None
+        # Pop message_list up front so it is not passed twice (explicitly and again via **kwargs).
+        message_list = kwargs.pop("message_list", None)
+
+        while retry_remain > 0:
+            try:
+                current_messages = compressed_messages or message_list
+
+ if request_type == RequestType.RESPONSE:
+ assert current_messages is not None, "message_list cannot be None for response requests"
+ return await client.get_response(model_info=model_info, message_list=current_messages, **kwargs)
+ elif request_type == RequestType.EMBEDDING:
+ return await client.get_embedding(model_info=model_info, **kwargs)
+ elif request_type == RequestType.AUDIO:
+ return await client.get_audio_transcriptions(model_info=model_info, **kwargs)
+
+ except Exception as e:
+ logger.debug(f"请求失败: {str(e)}")
+ self.model_selector.update_failure_penalty(model_info.name, e)
+
+                wait_interval, new_compressed_messages = self._handle_exception(
+                    e, model_info, api_provider, retry_remain, (message_list, compressed_messages is not None)
+                )
+ if new_compressed_messages:
+ compressed_messages = new_compressed_messages
+
+ if wait_interval == -1:
+ raise e # 如果不再重试,则传播异常
+ elif wait_interval > 0:
+ await asyncio.sleep(wait_interval)
+ finally:
+ retry_remain -= 1
+
+ logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
+ raise RuntimeError("请求失败,已达到最大重试次数")
+
+ def _handle_exception(
+ self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
+ ) -> Tuple[int, Optional[List[Message]]]:
+ """
+ 默认异常处理函数,决定是否重试。
+
+ Returns:
+ (等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息))
+ """
+ model_name = model_info.name
+ retry_interval = api_provider.retry_interval
+
+ if isinstance(e, (NetworkConnectionError, ReqAbortException)):
+ return self._check_retry(remain_try, retry_interval, "连接异常", model_name)
+ elif isinstance(e, RespNotOkException):
+ return self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
+ elif isinstance(e, RespParseException):
+ logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}")
+ return -1, None
+ else:
+ logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}")
+ return -1, None
+
+ def _handle_resp_not_ok(
+ self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
+ ) -> Tuple[int, Optional[List[Message]]]:
+ """处理非200的HTTP响应异常。"""
+ model_name = model_info.name
+ if e.status_code in [400, 401, 402, 403, 404]:
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。")
+ return -1, None
+ elif e.status_code == 413:
+ messages, is_compressed = messages_info
+ if messages and not is_compressed:
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试。")
+ return 0, compress_messages(messages)
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大且无法压缩,放弃请求。")
+ return -1, None
+ elif e.status_code == 429 or e.status_code >= 500:
+ reason = "请求过于频繁" if e.status_code == 429 else "服务器错误"
+ return self._check_retry(remain_try, api_provider.retry_interval, reason, model_name)
+ else:
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}")
+ return -1, None
+
+ def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]:
+ """辅助函数:检查是否可以重试。"""
+ if remain_try > 1: # 剩余次数大于1才重试
+ logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。")
+ return interval, None
+ logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。")
+ return -1, None
+
+
+class _RequestStrategy:
+ """
+ 封装高级请求策略,如故障转移。
+ 此类协调模型选择、提示处理和请求执行,以实现健壮的请求处理,
+ 即使在单个模型或API端点失败的情况下也能正常工作。
+ """
+
+ def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str):
+ """
+ 初始化请求策略。
+
+ Args:
+ model_selector (_ModelSelector): 模型选择器实例。
+ prompt_processor (_PromptProcessor): 提示处理器实例。
+ executor (_RequestExecutor): 请求执行器实例。
+ model_list (List[str]): 可用模型列表。
+ task_name (str): 当前任务的名称。
+ """
+ self.model_selector = model_selector
+ self.prompt_processor = prompt_processor
+ self.executor = executor
+ self.model_list = model_list
+ self.task_name = task_name
+
+ async def execute_with_failover(
+ self,
+ request_type: RequestType,
+ raise_when_empty: bool = True,
+ **kwargs,
+ ) -> Tuple[APIResponse, ModelInfo]:
+ """
+ 执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
+ """
+ failed_models_in_this_request = set()
+ max_attempts = len(self.model_list)
+ last_exception: Optional[Exception] = None
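+        # Each model is tried at most once per request: return on the first success, otherwise
+        # fall through to the failure handling (raise or fallback response) below.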
+
+ for attempt in range(max_attempts):
+ selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value))
+ if selection_result is None:
+ logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
+ break
+
+ model_info, api_provider, client = selection_result
+ logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...")
+
+ try:
+ # 准备请求参数
+ request_kwargs = kwargs.copy()
+ if request_type == RequestType.RESPONSE and "prompt" in request_kwargs:
+ prompt = request_kwargs.pop("prompt")
+ processed_prompt = self.prompt_processor.prepare_prompt(
+ prompt, model_info, api_provider, self.task_name
+ )
+ message = MessageBuilder().add_text_content(processed_prompt).build()
+ request_kwargs["message_list"] = [message]
+
+ # 合并模型特定的额外参数
+ if model_info.extra_params:
+ request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})}
+
+ response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs)
+
+ # 成功,立即返回
+ logger.debug(f"模型 '{model_info.name}' 成功生成了回复。")
+ self.model_selector.update_usage_penalty(model_info.name, increase=False)
+ return response, model_info
+
+ except Exception as e:
+ logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
+ failed_models_in_this_request.add(model_info.name)
+ last_exception = e
+ # 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率
+
+ logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
+ if raise_when_empty:
+ if last_exception:
+ raise RuntimeError("所有模型均未能生成响应。") from last_exception
+ raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
+
+ # 如果不抛出异常,返回一个备用响应
+ fallback_model_info = model_config.get_model_info(self.model_list[0])
+ return APIResponse(content="所有模型都请求失败"), fallback_model_info
+
+
+ async def _try_model_request(
+ self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs
+ ) -> APIResponse:
+ """
+ 为单个模型尝试请求,包含空回复/截断的内部重试逻辑。
+ 如果模型返回空回复或响应被截断,此方法将自动重试请求,直到达到最大重试次数。
+
+ Args:
+ model_info (ModelInfo): 要使用的模型信息。
+ api_provider (APIProvider): API提供商信息。
+ client (BaseClient): API客户端实例。
+ request_type (RequestType): 请求类型。
+ **kwargs: 传递给执行器的请求参数。
+
+ Returns:
+ APIResponse: 成功的API响应。
+
+ Raises:
+ RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。
+ """
+ max_empty_retry = api_provider.max_retry
+
+ for i in range(max_empty_retry + 1):
+ response = await self.executor.execute_request(
+ api_provider, client, request_type, model_info, **kwargs
+ )
+
+ if request_type != RequestType.RESPONSE:
+ return response # 对于非响应类型,直接返回
+
+ # --- 响应内容处理和空回复/截断检查 ---
+ content = response.content or ""
+ use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
+ processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation)
+
+ # 更新响应对象
+ response.content = processed_content
+ response.reasoning_content = response.reasoning_content or reasoning
+
+ is_empty_reply = not response.tool_calls and not (response.content and response.content.strip())
+
+ if not is_empty_reply and not is_truncated:
+ return response # 成功获取有效响应
+
+ if i < max_empty_retry:
+ reason = "空回复" if is_empty_reply else "截断"
+ logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...")
+ if api_provider.retry_interval > 0:
+ await asyncio.sleep(api_provider.retry_interval)
+ else:
+ reason = "空回复" if is_empty_reply else "截断"
+ logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
+ raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。")
+
+ raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里
+
+
+# ==============================================================================
+# Main Facade Class
+# ==============================================================================
+
+class LLMRequest:
+ """
+ LLM请求协调器。
+ 封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。
+ 为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。
+ """
+
+ def __init__(self, model_set: TaskConfig, request_type: str = ""):
+ """
+ 初始化LLM请求协调器。
+
+ Args:
+ model_set (TaskConfig): 特定任务的模型配置集合。
+ request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "".
+ """
+ self.task_name = request_type
+ self.model_for_task = model_set
+ self.model_usage: Dict[str, Tuple[int, int, int]] = {
+ model: (0, 0, 0) for model in self.model_for_task.model_list
+ }
+ """模型使用量记录,(total_tokens, penalty, usage_penalty)"""
+
+ # 初始化辅助类
+ self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
+ self._prompt_processor = _PromptProcessor()
+ self._executor = _RequestExecutor(self._model_selector, self.task_name)
+ self._strategy = _RequestStrategy(
+ self._model_selector, self._prompt_processor, self._executor, self.model_for_task.model_list, self.task_name
+ )
+
async def generate_response_for_image(
self,
prompt: str,
@@ -174,77 +578,57 @@ class LLMRequest:
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
- 为图像生成响应
+ 为图像生成响应。
+
Args:
prompt (str): 提示词
image_base64 (str): 图像的Base64编码字符串
image_format (str): 图像格式(如 'png', 'jpeg' 等)
+
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
- # 标准化图片格式以确保API兼容性
- normalized_format = _normalize_image_format(image_format)
-
- # 模型选择
start_time = time.time()
- model_info, api_provider, client = self._select_model()
-
- # 请求体构建
- message_builder = MessageBuilder()
- message_builder.add_text_content(prompt)
- message_builder.add_image_content(
+
+ # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
+ selection_result = self._model_selector.select_best_available_model(set(), "response")
+ if not selection_result:
+ raise RuntimeError("无法为图像响应选择可用模型。")
+ model_info, api_provider, client = selection_result
+
+ normalized_format = _normalize_image_format(image_format)
+ message = MessageBuilder().add_text_content(prompt).add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
- )
- messages = [message_builder.build()]
+ ).build()
- # 请求并处理返回值
- response = await self._execute_request(
- api_provider=api_provider,
- client=client,
- request_type=RequestType.RESPONSE,
- model_info=model_info,
- message_list=messages,
+ response = await self._executor.execute_request(
+ api_provider, client, RequestType.RESPONSE, model_info,
+ message_list=[message],
temperature=temperature,
max_tokens=max_tokens,
)
- content = response.content or ""
- reasoning_content = response.reasoning_content or ""
- tool_calls = response.tool_calls
-        # 从内容中提取 <think> 标签的推理内容(向后兼容)
- if not reasoning_content and content:
- content, extracted_reasoning = self._extract_reasoning(content)
- reasoning_content = extracted_reasoning
- if usage := response.usage:
- await llm_usage_recorder.record_usage_to_database(
- model_info=model_info,
- model_usage=usage,
- user_id="system",
- time_cost=time.time() - start_time,
- request_type=self.request_type,
- endpoint="/chat/completions",
- )
- return content, (reasoning_content, model_info.name, tool_calls)
+
+        # Restore the usage penalty on success, mirroring the failover path in _RequestStrategy.
+        self._model_selector.update_usage_penalty(model_info.name, increase=False)
+        self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
+        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
+        reasoning = response.reasoning_content or reasoning
+
+ return content, (reasoning, model_info.name, response.tool_calls)
async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
"""
- 为语音生成响应
- Args:
- voice_base64 (str): 语音的Base64编码字符串
- Returns:
- (Optional[str]): 生成的文本描述或None
- """
- # 模型选择
- model_info, api_provider, client = self._select_model()
+ 为语音生成响应(语音转文字)。
+ 使用故障转移策略来确保即使主模型失败也能获得结果。
- # 请求并处理返回值
- response = await self._execute_request(
- api_provider=api_provider,
- client=client,
- request_type=RequestType.AUDIO,
- model_info=model_info,
- audio_base64=voice_base64,
+ Args:
+ voice_base64 (str): 语音的Base64编码字符串。
+
+ Returns:
+ Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。
+ """
+ response, _ = await self._strategy.execute_with_failover(
+ RequestType.AUDIO, audio_base64=voice_base64
)
return response.content or None
@@ -257,44 +641,36 @@ class LLMRequest:
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
- 异步生成响应,支持并发请求
+ 异步生成响应,支持并发请求。
+
Args:
prompt (str): 提示词
temperature (float, optional): 温度参数
max_tokens (int, optional): 最大token数
tools: 工具配置
- raise_when_empty: 是否在空回复时抛出异常
+ raise_when_empty (bool): 是否在空回复时抛出异常
+
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
- # 检查是否需要并发请求
concurrency_count = getattr(self.model_for_task, "concurrency_count", 1)
if concurrency_count <= 1:
- # 单次请求
- return await self._execute_single_request(prompt, temperature, max_tokens, tools, raise_when_empty)
-
- # 并发请求
+ return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty)
+
try:
- # 为 _execute_single_request 传递参数时,将 raise_when_empty 设为 False,
- # 这样单个请求失败时不会立即抛出异常,而是由 gather 统一处理
- content, (reasoning_content, model_name, tool_calls) = await execute_concurrently(
- self._execute_single_request,
+ return await execute_concurrently(
+ self._execute_single_text_request,
concurrency_count,
- prompt,
- temperature,
- max_tokens,
- tools,
- raise_when_empty=False,
+ prompt, temperature, max_tokens, tools, raise_when_empty=False
)
- return content, (reasoning_content, model_name, tool_calls)
except Exception as e:
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
if raise_when_empty:
raise e
return "所有并发请求都失败了", ("", "unknown", None)
- async def _execute_single_request(
+ async def _execute_single_text_request(
self,
prompt: str,
temperature: Optional[float] = None,
@@ -303,633 +679,101 @@ class LLMRequest:
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
- 执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
+ 执行单次文本生成请求的内部方法。
+ 这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期,
+ 包括工具构建、故障转移执行和用量记录。
+
+ Args:
+ prompt (str): 用户的提示。
+ temperature (Optional[float]): 生成温度。
+ max_tokens (Optional[int]): 最大生成令牌数。
+ tools (Optional[List[Dict[str, Any]]]): 可用工具列表。
+ raise_when_empty (bool): 如果响应为空是否引发异常。
+
+ Returns:
+ Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+ (响应内容, (推理过程, 模型名称, 工具调用))
"""
- failed_models_in_this_request = set()
- # 迭代次数等于模型总数,以确保每个模型在当前请求中最多只尝试一次
- max_attempts = len(self.model_for_task.model_list)
- last_exception: Optional[Exception] = None
+ start_time = time.time()
+ tool_options = self._build_tool_options(tools)
- for attempt in range(max_attempts):
- # 根据负载均衡和当前故障选择最佳可用模型
- model_selection_result = self._select_best_available_model(failed_models_in_this_request)
+ response, model_info = await self._strategy.execute_with_failover(
+ RequestType.RESPONSE,
+ raise_when_empty=raise_when_empty,
+ prompt=prompt, # 传递原始prompt,由strategy处理
+ tool_options=tool_options,
+ temperature=self.model_for_task.temperature if temperature is None else temperature,
+ max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
+ )
- if model_selection_result is None:
- logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
- break # 没有更多模型可供尝试
+ self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
- model_info, api_provider, client = model_selection_result
- model_name = model_info.name
- logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...")
+ if not response.content and not response.tool_calls:
+ if raise_when_empty:
+ raise RuntimeError("所选模型生成了空回复。")
+ response.content = "生成的响应为空"
- start_time = time.time()
-
- try:
- # --- 为当前模型尝试进行设置 ---
- # 检查是否为该模型启用反截断
- use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
- processed_prompt = prompt
- if use_anti_truncation:
- processed_prompt += self.anti_truncation_instruction
- logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。")
-
- processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
-
- message_builder = MessageBuilder()
- message_builder.add_text_content(processed_prompt)
- messages = [message_builder.build()]
- tool_built = self._build_tool_options(tools)
-
- # --- 当前选定模型内的空回复/截断重试逻辑 ---
- empty_retry_count = 0
- max_empty_retry = api_provider.max_retry
- empty_retry_interval = api_provider.retry_interval
-
- while empty_retry_count <= max_empty_retry:
- response = await self._execute_request(
- api_provider=api_provider,
- client=client,
- request_type=RequestType.RESPONSE,
- model_info=model_info,
- message_list=messages,
- tool_options=tool_built,
- temperature=temperature,
- max_tokens=max_tokens,
- )
-
- content = response.content or ""
- reasoning_content = response.reasoning_content or ""
- tool_calls = response.tool_calls
-
-                    # 向后兼容 <think> 标签(如果 reasoning_content 为空)
- if not reasoning_content and content:
- content, extracted_reasoning = self._extract_reasoning(content)
- reasoning_content = extracted_reasoning
-
- is_empty_reply = not tool_calls and (not content or content.strip() == "")
- is_truncated = False
- if use_anti_truncation:
- if content.endswith(self.end_marker):
- content = content[: -len(self.end_marker)].strip()
- else:
- is_truncated = True
-
- if is_empty_reply or is_truncated:
- empty_retry_count += 1
- if empty_retry_count <= max_empty_retry:
- reason = "空回复" if is_empty_reply else "截断"
- logger.warning(
- f"模型 '{model_name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..."
- )
- if empty_retry_interval > 0:
- await asyncio.sleep(empty_retry_interval)
- continue # 使用当前模型重试
- else:
- reason = "空回复" if is_empty_reply else "截断"
- logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。将此模型标记为当前请求失败。")
- raise RuntimeError(f"模型 '{model_name}' 已达到空回复/截断的最大内部重试次数。")
-
- # --- 从当前模型获取成功响应 ---
- if usage := response.usage:
- await llm_usage_recorder.record_usage_to_database(
- model_info=model_info,
- model_usage=usage,
- time_cost=time.time() - start_time,
- user_id="system",
- request_type=self.request_type,
- endpoint="/chat/completions",
- )
-
- # 处理成功执行后响应仍然为空的情况
- if not content and not tool_calls:
- if raise_when_empty:
- raise RuntimeError("所选模型生成了空回复。")
- content = "生成的响应为空" # Fallback message
-
- logger.debug(f"模型 '{model_name}' 成功生成了回复。")
- return content, (reasoning_content, model_name, tool_calls) # 成功,立即返回
-
- # --- 当前模型尝试过程中的异常处理 ---
- except Exception as e: # 捕获当前模型尝试过程中的所有异常
- # 修复 NameError: model_name 在异常处理块中未定义,应使用 model_info.name
- logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
- failed_models_in_this_request.add(model_info.name)
- last_exception = e # 存储异常以供最终报告
- # 继续循环以尝试下一个可用模型
-
- # 如果循环结束未能返回,则表示当前请求的所有模型都已失败
- logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
- if raise_when_empty:
- if last_exception:
- raise RuntimeError("所有模型均未能生成响应。") from last_exception
- raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
- return "所有模型都请求失败", ("", "unknown", None)
+ return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
- """获取嵌入向量
+ """
+ 获取嵌入向量。
+
Args:
embedding_input (str): 获取嵌入的目标
+
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
- # 无需构建消息体,直接使用输入文本
start_time = time.time()
- model_info, api_provider, client = self._select_model()
-
- # 请求并处理返回值
- response = await self._execute_request(
- api_provider=api_provider,
- client=client,
- request_type=RequestType.EMBEDDING,
- model_info=model_info,
- embedding_input=embedding_input,
+ response, model_info = await self._strategy.execute_with_failover(
+ RequestType.EMBEDDING,
+ embedding_input=embedding_input
)
+
+ self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
+
+ if not response.embedding:
+ raise RuntimeError("获取embedding失败")
+
+ return response.embedding, model_info.name
- embedding = response.embedding
-
- if usage := response.usage:
- await llm_usage_recorder.record_usage_to_database(
+ def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
+ """异步记录用量到数据库。"""
+ if usage:
+ # 更新内存中的token计数
+ total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+ self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
+
+ asyncio.create_task(llm_usage_recorder.record_usage_to_database(
model_info=model_info,
- time_cost=time.time() - start_time,
model_usage=usage,
user_id="system",
- request_type=self.request_type,
- endpoint="/embeddings",
- )
-
- if not embedding:
- raise RuntimeError("获取embedding失败")
-
- return embedding, model_info.name
-
- def _select_best_available_model(self, failed_models_in_this_request: set) -> Tuple[ModelInfo, APIProvider, BaseClient] | None:
- """
- 从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。
-
- 参数:
- failed_models_in_this_request (set): 当前请求中已失败的模型名称集合。
-
- 返回:
- Tuple[ModelInfo, APIProvider, BaseClient] | None: 选定的模型详细信息,如果无可用模型则返回 None。
- """
- candidate_models_usage = {}
- # 过滤掉当前请求中已失败的模型
- for model_name, usage_data in self.model_usage.items():
- if model_name not in failed_models_in_this_request:
- candidate_models_usage[model_name] = usage_data
-
- if not candidate_models_usage:
- logger.warning("没有可用的模型供当前请求选择。")
- return None
-
- # 根据现有公式查找分数最低的模型,该公式综合了总token数、模型惩罚值和使用频率惩罚值。
- # 公式: total_tokens + penalty * 300 + usage_penalty * 1000
- # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。
- least_used_model_name = min(
- candidate_models_usage,
- key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
- )
-
- # --- 动态故障转移的核心逻辑 ---
- # _execute_single_request 中的循环会多次调用此函数。
- # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数,
- # 此时由于失败模型已被标记,且其惩罚值可能已在 _execute_request 中增加,
- # _select_best_available_model 会自动选择一个得分更低(即更可用)的模型。
- # 这种机制实现了动态的、基于当前系统状态的故障转移。
-
- model_info = model_config.get_model_info(least_used_model_name)
- api_provider = model_config.get_provider(model_info.api_provider)
-
- # 对于嵌入任务,如果需要,强制创建新的客户端实例(从原始 _select_model 复制)
- force_new_client = self.request_type == "embedding"
- client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-
- logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
-
- # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。
- # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
- self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
-
- return model_info, api_provider, client
-
- def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
- """
- 一个模型调度器,按顺序提供模型,并跳过已失败的模型。
- """
- for model_name in self.model_for_task.model_list:
- if model_name in failed_models:
- continue
-
- model_info = model_config.get_model_info(model_name)
- api_provider = model_config.get_provider(model_info.api_provider)
- force_new_client = self.request_type == "embedding"
- client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-
- yield model_info, api_provider, client
-
- def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
- """
- 根据总tokens和惩罚值选择的模型 (负载均衡)
- """
- least_used_model_name = min(
- self.model_usage,
- key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
- )
- model_info = model_config.get_model_info(least_used_model_name)
- api_provider = model_config.get_provider(model_info.api_provider)
-
- # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
- force_new_client = self.request_type == "embedding"
- client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
- logger.debug(f"选择请求模型: {model_info.name}")
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
- self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
- return model_info, api_provider, client
-
- async def _execute_request(
- self,
- api_provider: APIProvider,
- client: BaseClient,
- request_type: RequestType,
- model_info: ModelInfo,
- message_list: List[Message] | None = None,
- tool_options: list[ToolOption] | None = None,
- response_format: RespFormat | None = None,
- stream_response_handler: Optional[Callable] = None,
- async_response_parser: Optional[Callable] = None,
- temperature: Optional[float] = None,
- max_tokens: Optional[int] = None,
- embedding_input: str = "",
- audio_base64: str = "",
- ) -> APIResponse:
- """
- 实际执行请求的方法
-
- 包含了重试和异常处理逻辑
- """
- retry_remain = api_provider.max_retry
- compressed_messages: Optional[List[Message]] = None
- while retry_remain > 0:
- try:
- if request_type == RequestType.RESPONSE:
- assert message_list is not None, "message_list cannot be None for response requests"
- return await client.get_response(
- model_info=model_info,
- message_list=(compressed_messages or message_list),
- tool_options=tool_options,
- max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
- temperature=self.model_for_task.temperature if temperature is None else temperature,
- response_format=response_format,
- stream_response_handler=stream_response_handler,
- async_response_parser=async_response_parser,
- extra_params=model_info.extra_params,
- )
- elif request_type == RequestType.EMBEDDING:
- assert embedding_input, "embedding_input cannot be empty for embedding requests"
- return await client.get_embedding(
- model_info=model_info,
- embedding_input=embedding_input,
- extra_params=model_info.extra_params,
- )
- elif request_type == RequestType.AUDIO:
- assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
- return await client.get_audio_transcriptions(
- model_info=model_info,
- audio_base64=audio_base64,
- extra_params=model_info.extra_params,
- )
- except Exception as e:
- logger.debug(f"请求失败: {str(e)}")
- # 处理异常
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-
- # --- 增强动态故障转移的智能性 ---
- # 根据异常类型和严重程度,动态调整模型的惩罚值。
- # 关键错误(如网络连接、服务器错误)会获得更高的惩罚,
- # 促使负载均衡算法在下次选择时优先规避这些不可靠的模型。
- CRITICAL_PENALTY_MULTIPLIER = 5 # 关键错误时的惩罚系数
- default_penalty_increment = 1 # 普通错误时的基础惩罚
-
- penalty_increment = default_penalty_increment
-
- if isinstance(e, NetworkConnectionError):
- # 网络连接问题表明模型服务器不稳定,增加较高惩罚
- penalty_increment = CRITICAL_PENALTY_MULTIPLIER
- # 修复 NameError: model_name 在此处未定义,应使用 model_info.name
- logger.warning(f"模型 '{model_info.name}' 发生网络连接错误,增加惩罚值: {penalty_increment}")
- elif isinstance(e, ReqAbortException):
- # 请求被中止,可能是服务器端原因或服务不稳定,增加较高惩罚
- penalty_increment = CRITICAL_PENALTY_MULTIPLIER
- # 修复 NameError: model_name 在此处未定义,应使用 model_info.name
- logger.warning(f"模型 '{model_info.name}' 请求被中止,增加惩罚值: {penalty_increment}")
- elif isinstance(e, RespNotOkException):
- if e.status_code >= 500:
- # 服务器错误 (5xx) 表明服务器端问题,应显著增加惩罚
- penalty_increment = CRITICAL_PENALTY_MULTIPLIER
- logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}")
- elif e.status_code == 429:
- # 请求过于频繁,是暂时性问题,但仍需惩罚,此处使用默认基础值
- # penalty_increment = 2 # 可以选择一个中间值,例如2,表示比普通错误重,但比关键错误轻
- logger.warning(f"模型 '{model_name}' 请求过于频繁 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
- else:
- # 其他客户端错误 (4xx)。通常不重试,_handle_resp_not_ok 会处理。
- # 如果 _handle_resp_not_ok 返回 retry_interval, 则进入这里的 exception 块。
- logger.warning(f"模型 '{model_name}' 发生非致命的响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
- else:
- # 其他未捕获的异常,增加基础惩罚
- logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")
-
- self.model_usage[model_info.name] = (total_tokens, penalty + penalty_increment, usage_penalty)
- # --- 结束增强 ---
- # 移除冗余的、错误的惩罚值更新行,保留上面正确的动态惩罚更新
- # self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
-
- wait_interval, compressed_messages = self._default_exception_handler(
- e,
- self.task_name,
- model_info=model_info,
- api_provider=api_provider,
- remain_try=retry_remain,
- retry_interval=api_provider.retry_interval,
- messages=(message_list, compressed_messages is not None) if message_list else None,
- )
-
- if wait_interval == -1:
- retry_remain = 0 # 不再重试
- elif wait_interval > 0:
- logger.info(f"等待 {wait_interval} 秒后重试...")
- await asyncio.sleep(wait_interval)
- finally:
- # 放在finally防止死循环
- retry_remain -= 1
- total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
- self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值
- logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
- raise RuntimeError("请求失败,已达到最大重试次数")
-
- def _default_exception_handler(
- self,
- e: Exception,
- task_name: str,
- model_info: ModelInfo,
- api_provider: APIProvider,
- remain_try: int,
- retry_interval: int = 10,
- messages: Tuple[List[Message], bool] | None = None,
- ) -> Tuple[int, List[Message] | None]:
- """
- 默认异常处理函数
- Args:
- e (Exception): 异常对象
- task_name (str): 任务名称
- model_info (ModelInfo): 模型信息
- api_provider (APIProvider): API提供商
- remain_try (int): 剩余尝试次数
- retry_interval (int): 重试间隔
- messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
- Returns:
- (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
- """
- model_name = model_info.name if model_info else "unknown"
-
- if isinstance(e, NetworkConnectionError): # 网络连接错误
- return self._check_retry(
- remain_try,
- retry_interval,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确",
- )
- elif isinstance(e, ReqAbortException):
- logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
- return -1, None # 不再重试请求该模型
- elif isinstance(e, RespNotOkException):
- return self._handle_resp_not_ok(
- e,
- task_name,
- model_info,
- api_provider,
- remain_try,
- retry_interval,
- messages,
- )
- elif isinstance(e, RespParseException):
- # 响应解析错误
- logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
- logger.debug(f"附加内容: {str(e.ext_info)}")
- return -1, None # 不再重试请求该模型
- else:
- logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
- return -1, None # 不再重试请求该模型
-
- @staticmethod
- def _check_retry(
- remain_try: int,
- retry_interval: int,
- can_retry_msg: str,
- cannot_retry_msg: str,
- can_retry_callable: Callable | None = None,
- **kwargs,
- ) -> Tuple[int, List[Message] | None]:
- """辅助函数:检查是否可以重试
- Args:
- remain_try (int): 剩余尝试次数
- retry_interval (int): 重试间隔
- can_retry_msg (str): 可以重试时的提示信息
- cannot_retry_msg (str): 不可以重试时的提示信息
- can_retry_callable (Callable | None): 可以重试时调用的函数(如果有)
- **kwargs: 其他参数
-
- Returns:
- (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
- """
- if remain_try > 0:
- # 还有重试机会
- logger.warning(f"{can_retry_msg}")
- if can_retry_callable is not None:
- return retry_interval, can_retry_callable(**kwargs)
- else:
- return retry_interval, None
- else:
- # 达到最大重试次数
- logger.warning(f"{cannot_retry_msg}")
- return -1, None # 不再重试请求该模型
-
- def _handle_resp_not_ok(
- self,
- e: RespNotOkException,
- task_name: str,
- model_info: ModelInfo,
- api_provider: APIProvider,
- remain_try: int,
- retry_interval: int = 10,
- messages: tuple[list[Message], bool] | None = None,
- ):
- model_name = model_info.name
- """
- 处理响应错误异常
- Args:
- e (RespNotOkException): 响应错误异常对象
- task_name (str): 任务名称
- model_info (ModelInfo): 模型信息
- api_provider (APIProvider): API提供商
- remain_try (int): 剩余尝试次数
- retry_interval (int): 重试间隔
- messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
- Returns:
- (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
- """
- # 响应错误
- if e.status_code in [400, 401, 402, 403, 404]:
- model_name = model_info.name
- # 客户端错误
- logger.warning(
- f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
- )
- return -1, None # 不再重试请求该模型
- elif e.status_code == 413:
- if messages and not messages[1]:
- # 消息列表不为空且未压缩,尝试压缩消息
- return self._check_retry(
- remain_try,
- 0,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
- can_retry_callable=compress_messages,
- messages=messages[0],
- )
- # 没有消息可压缩
- logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
- return -1, None
- elif e.status_code == 429:
- # 请求过于频繁
- return self._check_retry(
- remain_try,
- retry_interval,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
- )
- elif e.status_code >= 500:
- # 服务器错误
- return self._check_retry(
- remain_try,
- retry_interval,
- can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
- cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
- )
- else:
- # 未知错误
- logger.warning(
- f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
- )
- return -1, None
+ time_cost=time_cost,
+ request_type=self.task_name,
+ endpoint=endpoint,
+ ))
@staticmethod
def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
- # sourcery skip: extract-method
- """构建工具选项列表"""
+ """构建工具选项列表。"""
if not tools:
return None
tool_options: List[ToolOption] = []
for tool in tools:
- tool_legal = True
- tool_options_builder = ToolOptionBuilder()
- tool_options_builder.set_name(tool.get("name", ""))
- tool_options_builder.set_description(tool.get("description", ""))
- parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", [])
- for param in parameters:
- try:
+ try:
+ builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", ""))
+ for param in tool.get("parameters", []):
+ # 参数格式验证
assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
- assert isinstance(param[0], str), "参数名称必须是字符串"
- assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举"
- assert isinstance(param[2], str), "参数描述必须是字符串"
- assert isinstance(param[3], bool), "参数是否必填必须是布尔值"
- assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None"
- tool_options_builder.add_param(
+ builder.add_param(
name=param[0],
param_type=param[1],
description=param[2],
required=param[3],
enum_values=param[4],
)
- except AssertionError as ae:
- tool_legal = False
- logger.error(f"{param[0]} 参数定义错误: {str(ae)}")
- except Exception as e:
- tool_legal = False
- logger.error(f"构建工具参数失败: {str(e)}")
- if tool_legal:
- tool_options.append(tool_options_builder.build())
+ tool_options.append(builder.build())
+ except (KeyError, IndexError, TypeError, AssertionError) as e:
+ logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}")
return tool_options or None
-
- @staticmethod
- def _extract_reasoning(content: str) -> Tuple[str, str]:
- """CoT思维链提取,向后兼容"""
-        match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
-        content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
- reasoning = match[1].strip() if match else ""
- return content, reasoning
-
- def _apply_content_obfuscation(self, text: str, api_provider) -> str:
- """根据API提供商配置对文本进行混淆处理"""
- if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation:
- logger.debug(f"API提供商 '{api_provider.name}' 未启用内容混淆")
- return text
-
- intensity = getattr(api_provider, "obfuscation_intensity", 1)
- logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
-
- # 在开头加入过滤规则指令
- processed_text = self.noise_instruction + "\n\n" + text
- logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}")
-
- # 添加随机乱码
- final_text = self._inject_random_noise(processed_text, intensity)
- logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}")
-
- return final_text
-
- @staticmethod
- def _inject_random_noise(text: str, intensity: int) -> str:
- """在文本中注入随机乱码"""
- import random
- import string
-
- def generate_noise(length: int) -> str:
- """生成指定长度的随机乱码字符"""
- chars = (
- string.ascii_letters # a-z, A-Z
- + string.digits # 0-9
- + "!@#$%^&*()_+-=[]{}|;:,.<>?" # 特殊符号
- + "一二三四五六七八九零壹贰叁" # 中文字符
- + "αβγδεζηθικλμνξοπρστυφχψω" # 希腊字母
- + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵" # 数学符号
- )
- return "".join(random.choice(chars) for _ in range(length))
-
- # 强度参数映射
- params = {
- 1: {"probability": 15, "length": (3, 6)}, # 低强度:15%概率,3-6个字符
- 2: {"probability": 25, "length": (5, 10)}, # 中强度:25%概率,5-10个字符
- 3: {"probability": 35, "length": (8, 15)}, # 高强度:35%概率,8-15个字符
- }
-
- config = params.get(intensity, params[1])
- logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}")
-
- # 按词分割处理
- words = text.split()
- result = []
- noise_count = 0
-
- for word in words:
- result.append(word)
- # 根据概率插入乱码
- if random.randint(1, 100) <= config["probability"]:
- noise_length = random.randint(*config["length"])
- noise = generate_noise(noise_length)
- result.append(noise)
- noise_count += 1
-
- logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}")
- return " ".join(result)