refactor(llm): decouple LLM request logic into dedicated components
The monolithic `LLMRequest` class has been refactored into a coordinator that delegates work to several dedicated components. The change follows the single-responsibility principle and improves the structure, maintainability, and extensibility of the code.

The core logic is extracted into the following new classes:

- `ModelSelector`: encapsulates the logic for choosing the best model based on load and availability.
- `PromptProcessor`: handles all prompt modifications and parsing of response content.
- `RequestStrategy`: manages the request execution flow, including failover and concurrent-request strategies.

The new architecture makes the system more modular, easier to test, and simpler to extend with new request strategies in the future.
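A minimal sketch of how the coordinator described above might wire these components together. The `LLMRequest` internals are not shown here (its diff is suppressed below), so the wiring, the `task_config` instance, and the "chat" task name are assumptions for illustration; the constructors and payload keys mirror the new modules added in this commit.

# Sketch only: how the refactored coordinator might use the new components.
# `task_config` (a TaskConfig) and the task name "chat" are illustrative.
selector = ModelSelector(model_set=task_config, request_type="response")
processor = PromptProcessor()
strategy = RequestStrategy(model_set=task_config, model_selector=selector, task_name="chat")

async def generate(prompt: str) -> dict:
    payload = {
        "prompt": prompt,
        "prompt_processor": processor,
        "tool_options": None,
        "temperature": None,
        "max_tokens": None,
    }
    # Selects the best available model, retries/fails over internally,
    # and returns a dict with "content", "reasoning_content", "usage", ...
    return await strategy.execute_with_fallback(payload)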
src/llm_models/llm_utils.py (new file, 65 lines)

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""
@File : llm_utils.py
@Time : 2024/05/24 17:00:00
@Author : 墨墨
@Version : 1.0
@Desc : LLM相关通用工具函数
"""
from typing import List, Dict, Any, Tuple

from src.common.logger import get_logger
from .payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType

logger = get_logger("llm_utils")


def normalize_image_format(image_format: str) -> str:
    """
    标准化图片格式名称,确保与各种API的兼容性
    """
    format_mapping = {
        "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
        "png": "png", "PNG": "png",
        "webp": "webp", "WEBP": "webp",
        "gif": "gif", "GIF": "gif",
        "heic": "heic", "HEIC": "heic",
        "heif": "heif", "HEIF": "heif",
    }
    normalized = format_mapping.get(image_format, image_format.lower())
    logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
    return normalized


def build_tool_options(tools: List[Dict[str, Any]] | None) -> List[ToolOption] | None:
    """构建工具选项列表"""
    if not tools:
        return None
    tool_options: List[ToolOption] = []
    for tool in tools:
        try:
            tool_options_builder = ToolOptionBuilder()
            tool_options_builder.set_name(tool.get("name", ""))
            tool_options_builder.set_description(tool.get("description", ""))
            parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = tool.get("parameters", [])
            for param in parameters:
                # 参数校验
                assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
                assert isinstance(param[0], str), "参数名称必须是字符串"
                assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举"
                assert isinstance(param[2], str), "参数描述必须是字符串"
                assert isinstance(param[3], bool), "参数是否必填必须是布尔值"
                assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None"

                tool_options_builder.add_param(
                    name=param[0],
                    param_type=param[1],
                    description=param[2],
                    required=param[3],
                    enum_values=param[4],
                )
            tool_options.append(tool_options_builder.build())
        except AssertionError as ae:
            logger.error(f"工具 '{tool.get('name', 'unknown')}' 的参数定义错误: {str(ae)}")
        except Exception as e:
            logger.error(f"构建工具 '{tool.get('name', 'unknown')}' 失败: {str(e)}")

    return tool_options or None
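As an aside (not part of the diff), the 5-tuple parameter shape that build_tool_options validates can be illustrated with a self-contained snippet; ToolParamType is stubbed here and the member name STRING is an assumption made only for this example:

# Self-contained illustration of the expected tool-definition shape.
from enum import Enum

class ToolParamType(Enum):  # stub; the real enum lives in payload_content.tool_option
    STRING = "string"

tools = [
    {
        "name": "get_weather",
        "description": "查询指定城市的天气",
        "parameters": [
            # (name, type, description, required, enum_values)
            ("city", ToolParamType.STRING, "城市名称", True, None),
        ],
    },
]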
@@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
 
     def _convert_tool_param(param: ToolParam) -> dict:
         """转换工具参数"""
-        result = {
+        result: dict[str, Any] = {
             "type": param.param_type.value,
             "description": param.description,
         }
@@ -132,7 +132,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
 
     def _convert_tool_option_item(tool_option: ToolOption) -> dict:
         """转换单个工具选项"""
-        function_declaration = {
+        function_declaration: dict[str, Any] = {
             "name": tool_option.name,
             "description": tool_option.description,
         }
@@ -500,7 +500,7 @@ class AiohttpGeminiClient(BaseClient):
                 # 直接重抛项目定义的异常
                 raise
             except Exception as e:
-                logger.debug(e)
+                logger.debug(f"请求处理中发生未知异常: {e}")
                 # 其他异常转换为网络连接错误
                 raise NetworkConnectionError() from e
src/llm_models/model_selector.py (new file, 130 lines)

@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
"""
@File : model_selector.py
@Time : 2024/05/24 16:00:00
@Author : 墨墨
@Version : 1.0
@Desc : 模型选择与负载均衡器
"""
from typing import Dict, Tuple, Set, Optional

from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import ModelInfo, APIProvider, TaskConfig
from .model_client.base_client import BaseClient, client_registry

logger = get_logger("model_selector")


class ModelSelector:
    """模型选择与负载均衡器"""

    def __init__(self, model_set: TaskConfig, request_type: str = ""):
        """
        初始化模型选择器

        Args:
            model_set (TaskConfig): 任务配置中定义的模型集合
            request_type (str, optional): 请求类型 (例如 "embedding"). Defaults to "".
        """
        self.model_for_task = model_set
        self.request_type = request_type
        self.model_usage: Dict[str, Tuple[int, int, int]] = {
            model: (0, 0, 0) for model in self.model_for_task.model_list
        }
        """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""

    def select_best_available_model(
        self, failed_models_in_this_request: Set[str]
    ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
        """
        从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。

        Args:
            failed_models_in_this_request (Set[str]): 当前请求中已失败的模型名称集合。

        Returns:
            Optional[Tuple[ModelInfo, APIProvider, BaseClient]]: 选定的模型详细信息,如果无可用模型则返回 None。
        """
        candidate_models_usage = {
            model_name: usage_data
            for model_name, usage_data in self.model_usage.items()
            if model_name not in failed_models_in_this_request
        }

        if not candidate_models_usage:
            logger.warning("没有可用的模型供当前请求选择。")
            return None

        # 根据现有公式查找分数最低的模型
        # 公式: total_tokens + penalty * 300 + usage_penalty * 1000
        # 较高的 usage_penalty (由于被选中的模型会被增加) 和 penalty (由于模型失败) 会使模型得分更高,从而降低被选中的几率。
        least_used_model_name = min(
            candidate_models_usage,
            key=lambda k: candidate_models_usage[k][0]
            + candidate_models_usage[k][1] * 300
            + candidate_models_usage[k][2] * 1000,
        )

        # --- 动态故障转移的核心逻辑 ---
        # RequestStrategy 中的循环会多次调用此函数。
        # 如果当前选定的模型因异常而失败,下次循环会重新调用此函数,
        # 此时由于失败模型已被标记,且其惩罚值可能已在 RequestExecutor 中增加,
        # 此函数会自动选择一个得分更低(即更可用)的模型。
        # 这种机制实现了动态的、基于当前系统状态的故障转移。
        model_info = model_config.get_model_info(least_used_model_name)
        api_provider = model_config.get_provider(model_info.api_provider)

        force_new_client = self.request_type == "embedding"
        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)

        logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")

        # 增加所选模型的请求使用惩罚值,以反映其当前使用情况/选择。
        # 这有助于在同一请求的后续选择或未来请求中实现动态负载均衡。
        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)

        return model_info, api_provider, client

    def select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
        """
        根据总tokens和惩罚值选择的模型 (负载均衡)
        """
        least_used_model_name = min(
            self.model_usage,
            key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
        )
        model_info = model_config.get_model_info(least_used_model_name)
        api_provider = model_config.get_provider(model_info.api_provider)

        force_new_client = self.request_type == "embedding"
        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
        logger.debug(f"选择请求模型: {model_info.name}")
        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
        return model_info, api_provider, client

    def update_model_penalty(self, model_name: str, penalty_increment: int):
        """
        更新指定模型的惩罚值

        Args:
            model_name (str): 模型名称
            penalty_increment (int): 惩罚增量
        """
        if model_name in self.model_usage:
            total_tokens, penalty, usage_penalty = self.model_usage[model_name]
            self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
            logger.debug(f"模型 '{model_name}' 的惩罚值增加了 {penalty_increment}")

    def decrease_usage_penalty(self, model_name: str):
        """
        请求结束后,减少使用惩罚值

        Args:
            model_name (str): 模型名称
        """
        if model_name in self.model_usage:
            total_tokens, penalty, usage_penalty = self.model_usage[model_name]
            self.model_usage[model_name] = (total_tokens, penalty, usage_penalty - 1)
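As an illustration (not part of the diff), the load-balancing score used by ModelSelector can be reproduced standalone; the usage tuples below are made-up numbers, lower score wins:

# (total_tokens, penalty, usage_penalty) per model, as kept in ModelSelector.model_usage
model_usage = {
    "model-a": (12000, 0, 1),
    "model-b": (8000, 5, 0),
    "model-c": (20000, 0, 0),
}

def score(usage: tuple[int, int, int]) -> int:
    total_tokens, penalty, usage_penalty = usage
    return total_tokens + penalty * 300 + usage_penalty * 1000

best = min(model_usage, key=lambda name: score(model_usage[name]))
print(best)  # "model-b": 8000 + 5*300 = 9500, the lowest score here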
src/llm_models/prompt_processor.py (new file, 113 lines)

@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
@File : prompt_processor.py
@Time : 2024/05/24 16:45:00
@Author : 墨墨
@Version : 1.0
@Desc : 提示词预处理(反截断、内容混淆等)
"""
import random
import string
import re
from typing import Tuple

from src.common.logger import get_logger
from src.config.api_ada_configs import APIProvider, ModelInfo

logger = get_logger("prompt_processor")


class PromptProcessor:
    """提示词处理器"""

    def __init__(self):
        self.noise_instruction = """**【核心过滤规则】**
在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。

你的任务是【完全并彻底地忽略】这些随机字符串。
**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
"""
        self.end_marker = "###MAI_RESPONSE_END###"
        self.anti_truncation_instruction = f"""
**【输出完成信令】**
这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。
例如:
<你的回复内容>
{self.end_marker}

这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
"""

    def process_prompt(
        self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
    ) -> str:
        """
        根据模型和API提供商的配置处理提示词
        """
        processed_prompt = prompt

        # 1. 添加反截断指令
        use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
        if use_anti_truncation:
            processed_prompt += self.anti_truncation_instruction
            logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")

        # 2. 应用内容混淆
        if getattr(api_provider, "enable_content_obfuscation", False):
            intensity = getattr(api_provider, "obfuscation_intensity", 1)
            logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
            processed_prompt = self._apply_content_obfuscation(processed_prompt, intensity)

        return processed_prompt

    def _apply_content_obfuscation(self, text: str, intensity: int) -> str:
        """对文本进行混淆处理"""
        # 在开头加入过滤规则指令
        processed_text = self.noise_instruction + "\n\n" + text
        logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}")

        # 添加随机乱码
        final_text = self._inject_random_noise(processed_text, intensity)
        logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}")

        return final_text

    @staticmethod
    def _inject_random_noise(text: str, intensity: int) -> str:
        """在文本中注入随机乱码"""
        def generate_noise(length: int) -> str:
            chars = (
                string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?"
                + "一二三四五六七八九零壹贰叁" + "αβγδεζηθικλμνξοπρστυφχψω" + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵"
            )
            return "".join(random.choice(chars) for _ in range(length))

        params = {
            1: {"probability": 15, "length": (3, 6)},
            2: {"probability": 25, "length": (5, 10)},
            3: {"probability": 35, "length": (8, 15)},
        }
        config = params.get(intensity, params[1])
        logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}")

        words = text.split()
        result = []
        noise_count = 0
        for word in words:
            result.append(word)
            if random.randint(1, 100) <= config["probability"]:
                noise_length = random.randint(*config["length"])
                noise = generate_noise(noise_length)
                result.append(noise)
                noise_count += 1

        logger.debug(f"共注入 {noise_count} 个乱码片段,原词数: {len(words)}")
        return " ".join(result)

    @staticmethod
    def extract_reasoning(content: str) -> Tuple[str, str]:
        """CoT思维链提取,向后兼容"""
        match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
        clean_content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
        reasoning = match.group(1).strip() if match else ""
        return clean_content, reasoning
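As an illustration (not part of the diff), the <think> extraction in extract_reasoning behaves like this self-contained snippet; the sample string is made up:

import re

def extract_reasoning(content: str) -> tuple[str, str]:
    # Mirrors PromptProcessor.extract_reasoning: strip one </think> block
    # and return (clean_content, reasoning).
    match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
    clean_content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
    reasoning = match.group(1).strip() if match else ""
    return clean_content, reasoning

print(extract_reasoning("<think>先检查输入</think>最终回答"))
# -> ('最终回答', '先检查输入')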
src/llm_models/request_executor.py (new file, 226 lines)

@@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
"""
@File : request_executor.py
@Time : 2024/05/24 16:15:00
@Author : 墨墨
@Version : 1.0
@Desc : 负责执行LLM请求、处理重试及异常
"""
import asyncio
from typing import List, Callable, Optional, Tuple

from src.common.logger import get_logger
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .exceptions import (
    NetworkConnectionError,
    ReqAbortException,
    RespNotOkException,
    RespParseException,
)
from .model_client.base_client import APIResponse, BaseClient
from .model_selector import ModelSelector
from .payload_content.message import Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption
from .utils import compress_messages

logger = get_logger("request_executor")


class RequestExecutor:
    """请求执行器"""

    def __init__(
        self,
        task_name: str,
        model_set: TaskConfig,
        api_provider: APIProvider,
        client: BaseClient,
        model_info: ModelInfo,
        model_selector: ModelSelector,
    ):
        self.task_name = task_name
        self.model_set = model_set
        self.api_provider = api_provider
        self.client = client
        self.model_info = model_info
        self.model_selector = model_selector

    async def execute_request(
        self,
        request_type: str,
        message_list: List[Message] | None = None,
        tool_options: list[ToolOption] | None = None,
        response_format: RespFormat | None = None,
        stream_response_handler: Optional[Callable] = None,
        async_response_parser: Optional[Callable] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        embedding_input: str = "",
        audio_base64: str = "",
    ) -> APIResponse:
        """
        实际执行请求的方法, 包含了重试和异常处理逻辑
        """
        retry_remain = self.api_provider.max_retry
        compressed_messages: Optional[List[Message]] = None
        while retry_remain > 0:
            try:
                if request_type == "response":
                    assert message_list is not None, "message_list cannot be None for response requests"
                    return await self.client.get_response(
                        model_info=self.model_info,
                        message_list=(compressed_messages or message_list),
                        tool_options=tool_options,
                        max_tokens=self.model_set.max_tokens if max_tokens is None else max_tokens,
                        temperature=self.model_set.temperature if temperature is None else temperature,
                        response_format=response_format,
                        stream_response_handler=stream_response_handler,
                        async_response_parser=async_response_parser,
                        extra_params=self.model_info.extra_params,
                    )
                elif request_type == "embedding":
                    assert embedding_input, "embedding_input cannot be empty for embedding requests"
                    return await self.client.get_embedding(
                        model_info=self.model_info,
                        embedding_input=embedding_input,
                        extra_params=self.model_info.extra_params,
                    )
                elif request_type == "audio":
                    assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
                    return await self.client.get_audio_transcriptions(
                        model_info=self.model_info,
                        audio_base64=audio_base64,
                        extra_params=self.model_info.extra_params,
                    )
                raise ValueError(f"未知的请求类型: {request_type}")
            except Exception as e:
                logger.debug(f"请求失败: {str(e)}")
                self._apply_penalty_on_failure(e)

                wait_interval, compressed_messages = self._default_exception_handler(
                    e,
                    remain_try=retry_remain,
                    retry_interval=self.api_provider.retry_interval,
                    messages=(message_list, compressed_messages is not None) if message_list else None,
                )

                if wait_interval == -1:
                    retry_remain = 0  # 不再重试
                elif wait_interval > 0:
                    logger.info(f"等待 {wait_interval} 秒后重试...")
                    await asyncio.sleep(wait_interval)
            finally:
                retry_remain -= 1

        self.model_selector.decrease_usage_penalty(self.model_info.name)
        logger.error(f"模型 '{self.model_info.name}' 请求失败,达到最大重试次数 {self.api_provider.max_retry} 次")
        raise RuntimeError("请求失败,已达到最大重试次数")

    def _apply_penalty_on_failure(self, e: Exception):
        """根据异常类型,动态调整模型的惩罚值"""
        CRITICAL_PENALTY_MULTIPLIER = 5
        default_penalty_increment = 1
        penalty_increment = default_penalty_increment

        if isinstance(e, (NetworkConnectionError, ReqAbortException)):
            penalty_increment = CRITICAL_PENALTY_MULTIPLIER
        elif isinstance(e, RespNotOkException):
            if e.status_code >= 500:
                penalty_increment = CRITICAL_PENALTY_MULTIPLIER

        log_message = f"发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}"
        if isinstance(e, (NetworkConnectionError, ReqAbortException)):
            log_message = f"发生关键错误 ({type(e).__name__}),增加惩罚值: {penalty_increment}"
        elif isinstance(e, RespNotOkException):
            log_message = f"发生响应错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}"
        logger.warning(f"模型 '{self.model_info.name}' {log_message}")

        self.model_selector.update_model_penalty(self.model_info.name, penalty_increment)

    def _default_exception_handler(
        self,
        e: Exception,
        remain_try: int,
        retry_interval: int = 10,
        messages: Tuple[List[Message], bool] | None = None,
    ) -> Tuple[int, List[Message] | None]:
        """默认异常处理函数"""
        model_name = self.model_info.name

        if isinstance(e, NetworkConnectionError):
            return self._check_retry(
                remain_try,
                retry_interval,
                can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
                cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数",
            )
        elif isinstance(e, ReqAbortException):
            logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
            return -1, None
        elif isinstance(e, RespNotOkException):
            return self._handle_resp_not_ok(e, remain_try, retry_interval, messages)
        elif isinstance(e, RespParseException):
            logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
            logger.debug(f"附加内容: {str(e.ext_info)}")
            return -1, None
        else:
            logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
            return -1, None

    def _handle_resp_not_ok(
        self,
        e: RespNotOkException,
        remain_try: int,
        retry_interval: int = 10,
        messages: tuple[list[Message], bool] | None = None,
    ):
        """处理响应错误异常"""
        model_name = self.model_info.name
        if e.status_code in [400, 401, 402, 403, 404]:
            logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}")
            return -1, None
        elif e.status_code == 413:
            if messages and not messages[1]:
                return self._check_retry(
                    remain_try, 0,
                    can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
                    cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,压缩后仍失败",
                    can_retry_callable=compress_messages, messages=messages[0],
                )
            logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,无法压缩,放弃请求。")
            return -1, None
        elif e.status_code == 429:
            return self._check_retry(
                remain_try, retry_interval,
                can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
                cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数",
            )
        elif e.status_code >= 500:
            return self._check_retry(
                remain_try, retry_interval,
                can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
                cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数",
            )
        else:
            logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}")
            return -1, None

    @staticmethod
    def _check_retry(
        remain_try: int,
        retry_interval: int,
        can_retry_msg: str,
        cannot_retry_msg: str,
        can_retry_callable: Callable | None = None,
        **kwargs,
    ) -> Tuple[int, List[Message] | None]:
        """辅助函数:检查是否可以重试"""
        if remain_try > 0:
            logger.warning(f"{can_retry_msg}")
            if can_retry_callable is not None:
                return retry_interval, can_retry_callable(**kwargs)
            return retry_interval, None
        else:
            logger.warning(f"{cannot_retry_msg}")
            return -1, None
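As an illustration (not part of the diff), the failure-penalty policy in _apply_penalty_on_failure reduces to the following self-contained sketch; the exception classes are stubbed so the snippet runs on its own:

class NetworkConnectionError(Exception):  # stub of the project's exception
    pass

class ReqAbortException(Exception):  # stub
    pass

class RespNotOkException(Exception):  # stub carrying only the status code
    def __init__(self, status_code: int):
        self.status_code = status_code

CRITICAL_PENALTY_MULTIPLIER = 5

def penalty_for(e: Exception) -> int:
    # Connection problems, aborted requests and 5xx responses are treated as
    # critical; everything else gets the base penalty of 1.
    if isinstance(e, (NetworkConnectionError, ReqAbortException)):
        return CRITICAL_PENALTY_MULTIPLIER
    if isinstance(e, RespNotOkException) and e.status_code >= 500:
        return CRITICAL_PENALTY_MULTIPLIER
    return 1

print(penalty_for(NetworkConnectionError()))  # 5
print(penalty_for(RespNotOkException(429)))   # 1 (rate limit, not critical)
print(penalty_for(RespNotOkException(503)))   # 5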
src/llm_models/request_strategy.py (new file, 206 lines)

@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
"""
@File : request_strategy.py
@Time : 2024/05/24 16:30:00
@Author : 墨墨
@Version : 1.0
@Desc : 高级请求策略(并发、故障转移)
"""
import asyncio
import random
from typing import List, Tuple, Optional, Dict, Any, Callable, Coroutine

from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from .model_client.base_client import APIResponse
from .model_selector import ModelSelector
from .payload_content.message import MessageBuilder
from .payload_content.tool_option import ToolCall
from .prompt_processor import PromptProcessor
from .request_executor import RequestExecutor

logger = get_logger("request_strategy")


class RequestStrategy:
    """高级请求策略"""

    def __init__(self, model_set: TaskConfig, model_selector: ModelSelector, task_name: str):
        self.model_set = model_set
        self.model_selector = model_selector
        self.task_name = task_name

    async def execute_with_fallback(
        self,
        base_payload: Dict[str, Any],
        raise_when_empty: bool = True,
    ) -> Dict[str, Any]:
        """
        执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
        """
        failed_models_in_this_request = set()
        max_attempts = len(self.model_set.model_list)
        last_exception: Optional[Exception] = None

        for attempt in range(max_attempts):
            model_selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request)

            if model_selection_result is None:
                logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
                break

            model_info, api_provider, client = model_selection_result
            model_name = model_info.name
            logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...")

            try:
                # 1. Process Prompt
                prompt_processor: PromptProcessor = base_payload["prompt_processor"]
                raw_prompt = base_payload["prompt"]
                processed_prompt = prompt_processor.process_prompt(
                    raw_prompt, model_info, api_provider, self.task_name
                )

                # 2. Build Message
                message_builder = MessageBuilder().add_text_content(processed_prompt)
                messages = [message_builder.build()]

                # 3. Create payload for executor
                executor_payload = {
                    "request_type": "response",  # Strategy only handles response type
                    "message_list": messages,
                    "tool_options": base_payload["tool_options"],
                    "temperature": base_payload["temperature"],
                    "max_tokens": base_payload["max_tokens"],
                }

                executor = RequestExecutor(
                    task_name=self.task_name,
                    model_set=self.model_set,
                    api_provider=api_provider,
                    client=client,
                    model_info=model_info,
                    model_selector=self.model_selector,
                )
                response = await self._execute_and_handle_empty_retry(executor, executor_payload, prompt_processor)

                # 4. Post-process response
                # The reasoning content is now extracted here, after a successful, de-truncated response is received.
                final_content, reasoning_content = prompt_processor.extract_reasoning(response.content or "")
                response.content = final_content  # Update response with cleaned content

                tool_calls = response.tool_calls

                if not final_content and not tool_calls:
                    if raise_when_empty:
                        raise RuntimeError("所选模型生成了空回复。")
                    response.content = "生成的响应为空"  # Fallback message

                logger.debug(f"模型 '{model_name}' 成功生成了回复。")
                return {
                    "content": response.content,
                    "reasoning_content": reasoning_content,
                    "model_name": model_name,
                    "tool_calls": tool_calls,
                    "model_info": model_info,
                    "usage": response.usage,
                    "success": True,
                }

            except Exception as e:
                logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
                failed_models_in_this_request.add(model_info.name)
                last_exception = e

        logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
        if raise_when_empty:
            if last_exception:
                raise RuntimeError("所有模型均未能生成响应。") from last_exception
            raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
        return {
            "content": "所有模型都请求失败",
            "reasoning_content": "",
            "model_name": "unknown",
            "tool_calls": None,
            "model_info": None,
            "usage": None,
            "success": False,
        }

    async def execute_concurrently(
        self,
        coro_callable: Callable[..., Coroutine[Any, Any, Any]],
        concurrency_count: int,
        *args,
        **kwargs,
    ) -> Any:
        """
        执行并发请求并从成功的结果中随机选择一个。
        """
        logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
        tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]

        results = await asyncio.gather(*tasks, return_exceptions=True)
        successful_results = [res for res in results if not isinstance(res, Exception)]

        if successful_results:
            selected = random.choice(successful_results)
            logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
            return selected

        for i, res in enumerate(results):
            if isinstance(res, Exception):
                logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")

        first_exception = next((res for res in results if isinstance(res, Exception)), None)
        if first_exception:
            raise first_exception

        raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")

    async def _execute_and_handle_empty_retry(
        self, executor: RequestExecutor, payload: Dict[str, Any], prompt_processor: PromptProcessor
    ) -> APIResponse:
        """
        在单个模型内部处理空回复/截断的重试逻辑
        """
        empty_retry_count = 0
        max_empty_retry = executor.api_provider.max_retry
        empty_retry_interval = executor.api_provider.retry_interval
        use_anti_truncation = getattr(executor.model_info, "use_anti_truncation", False)
        end_marker = prompt_processor.end_marker

        while empty_retry_count <= max_empty_retry:
            response = await executor.execute_request(**payload)

            content = response.content or ""
            tool_calls = response.tool_calls

            is_empty_reply = not tool_calls and (not content or content.strip() == "")
            is_truncated = False
            if use_anti_truncation and end_marker:
                if content.endswith(end_marker):
                    # 移除结束标记
                    response.content = content[: -len(end_marker)].strip()
                else:
                    is_truncated = True

            if is_empty_reply or is_truncated:
                empty_retry_count += 1
                if empty_retry_count <= max_empty_retry:
                    reason = "空回复" if is_empty_reply else "截断"
                    logger.warning(
                        f"模型 '{executor.model_info.name}' 检测到{reason},正在进行内部重试 ({empty_retry_count}/{max_empty_retry})..."
                    )
                    if empty_retry_interval > 0:
                        await asyncio.sleep(empty_retry_interval)
                    continue
                else:
                    reason = "空回复" if is_empty_reply else "截断"
                    raise RuntimeError(f"模型 '{executor.model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")

            # 成功获取响应
            return response

        # 此处理论上不会到达,因为循环要么返回要么抛异常
        raise RuntimeError("空回复/截断重试逻辑出现未知错误")
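As an illustration (not part of the diff), the concurrent-selection pattern in execute_concurrently can be sketched standalone: fire N copies of a coroutine, gather results with exceptions captured, and randomly pick one success. The fake_request coroutine is a made-up stand-in for a real model call:

import asyncio
import random

async def fake_request(prompt: str) -> str:
    # Stand-in for a real LLM call; fails randomly to exercise the error path.
    await asyncio.sleep(0.01)
    if random.random() < 0.3:
        raise RuntimeError("simulated failure")
    return "ok"

async def run_concurrently(concurrency_count: int = 3) -> str:
    tasks = [fake_request("prompt") for _ in range(concurrency_count)]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    successes = [r for r in results if not isinstance(r, Exception)]
    if successes:
        return random.choice(successes)  # pick one successful result at random
    raise next(r for r in results if isinstance(r, Exception))

print(asyncio.run(run_concurrently()))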
File diff suppressed because it is too large