refactor(llm): improve module clarity with docstrings and unified logging
This commit introduces a comprehensive refactoring of the `llm_models` module to enhance code clarity, maintainability, and robustness. Key changes include:

- **Comprehensive Documentation**: Added detailed docstrings and inline comments to `PromptProcessor`, `RequestExecutor`, `RequestStrategy`, and `LLMRequest`. This clarifies the purpose and logic of each component, including prompt manipulation, request execution with retries, fallback strategies, and concurrency.
- **Unified Logging**: Standardized all loggers within the module to use a single, consistent name (`model_utils`), simplifying log filtering and analysis.
- **Improved Result Handling**: Refined the result processing in `LLMRequest` to correctly extract and record usage data returned from the `RequestStrategy`, fixing a previously incomplete implementation.
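As a rough, self-contained illustration of the unified-logging convention described above (using the standard `logging` module as a stand-in for `src.common.logger.get_logger`; not part of the diff):

```python
import logging

def get_logger(name: str) -> logging.Logger:
    """Stand-in for src.common.logger.get_logger, used only in this sketch."""
    logging.basicConfig(level=logging.DEBUG, format="%(name)s | %(levelname)s | %(message)s")
    return logging.getLogger(name)

# Before this commit each submodule used its own logger name
# ("llm_utils", "model_selector", "request_executor", "request_strategy").
# After the commit they all share one name, so a single filter covers the whole module.
logger = get_logger("model_utils")
logger.info("prompt processed")   # model_utils | INFO | prompt processed
logger.debug("request retried")   # model_utils | DEBUG | request retried
```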
@@ -11,7 +11,7 @@ from typing import List, Dict, Any, Tuple
from src.common.logger import get_logger
from .payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType

logger = get_logger("llm_utils")
logger = get_logger("model_utils")

def normalize_image_format(image_format: str) -> str:
"""

@@ -13,7 +13,7 @@ from src.config.config import model_config
from src.config.api_ada_configs import ModelInfo, APIProvider, TaskConfig
from .model_client.base_client import BaseClient, client_registry

logger = get_logger("model_selector")
logger = get_logger("model_utils")


class ModelSelector:

@@ -18,16 +18,28 @@ logger = get_logger("prompt_processor")


class PromptProcessor:
"""提示词处理器"""
"""
提示词处理器。
负责对发送给模型的原始prompt进行预处理,以增强模型性能或实现特定功能。
主要功能包括:
1. **反截断**:在prompt末尾添加一个特殊的结束标记指令,帮助判断模型输出是否被截断。
2. **内容混淆**:向prompt中注入随机的“噪音”字符串,并附带指令让模型忽略它们,
可能用于绕过某些平台的审查或内容策略。
3. **思维链提取**:从模型的响应中分离出思考过程(被<think>标签包裹)和最终答案。
"""

def __init__(self):
"""初始化Prompt处理器,定义所需的指令文本。"""
# 指导模型忽略噪音字符串的指令
self.noise_instruction = """**【核心过滤规则】**
在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。

你的任务是【完全并彻底地忽略】这些随机字符串。
**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
"""
# 定义一个独特的结束标记,用于反截断检查
self.end_marker = "###MAI_RESPONSE_END###"
# 指导模型在回复末尾添加结束标记的指令
self.anti_truncation_instruction = f"""
**【输出完成信令】**
这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `{self.end_marker}` 作为结束标志。
@@ -42,17 +54,26 @@ class PromptProcessor:
self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
) -> str:
"""
根据模型和API提供商的配置处理提示词
根据模型和API提供商的配置,对输入的prompt进行预处理。

Args:
prompt (str): 原始的用户输入prompt。
model_info (ModelInfo): 当前使用的模型信息。
api_provider (APIProvider): 当前API提供商的配置。
task_name (str): 当前任务的名称,用于日志记录。

Returns:
str: 经过处理后的、最终将发送给模型的prompt。
"""
processed_prompt = prompt

# 1. 添加反截断指令
# 步骤 1: 根据模型配置添加反截断指令
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
if use_anti_truncation:
processed_prompt += self.anti_truncation_instruction
logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")

# 2. 应用内容混淆
# 步骤 2: 根据API提供商配置应用内容混淆
if getattr(api_provider, "enable_content_obfuscation", False):
intensity = getattr(api_provider, "obfuscation_intensity", 1)
logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
@@ -61,12 +82,15 @@ class PromptProcessor:
return processed_prompt

def _apply_content_obfuscation(self, text: str, intensity: int) -> str:
"""对文本进行混淆处理"""
# 在开头加入过滤规则指令
"""
对文本应用内容混淆处理。
首先添加过滤规则指令,然后注入随机噪音。
"""
# 在文本开头加入指导模型忽略噪音的指令
processed_text = self.noise_instruction + "\n\n" + text
logger.debug(f"已添加过滤规则指令,文本长度: {len(text)} -> {len(processed_text)}")

# 添加随机乱码
# 在文本中注入随机乱码
final_text = self._inject_random_noise(processed_text, intensity)
logger.debug(f"乱码注入完成,最终文本长度: {len(final_text)}")

@@ -74,20 +98,31 @@ class PromptProcessor:

@staticmethod
def _inject_random_noise(text: str, intensity: int) -> str:
"""在文本中注入随机乱码"""
"""
根据指定的强度,在文本的词语之间随机注入噪音字符串。

Args:
text (str): 待注入噪音的文本。
intensity (int): 混淆强度 (1, 2, or 3),决定噪音的注入概率和长度。

Returns:
str: 注入噪音后的文本。
"""
def generate_noise(length: int) -> str:
"""生成指定长度的随机噪音字符串。"""
chars = (
string.ascii_letters + string.digits + "!@#$%^&*()_+-=[]{}|;:,.<>?"
+ "一二三四五六七八九零壹贰叁" + "αβγδεζηθικλμνξοπρστυφχψω" + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵"
)
return "".join(random.choice(chars) for _ in range(length))

# 根据强度级别定义注入参数
params = {
1: {"probability": 15, "length": (3, 6)},
2: {"probability": 25, "length": (5, 10)},
3: {"probability": 35, "length": (8, 15)},
1: {"probability": 15, "length": (3, 6)}, # 低强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
}
config = params.get(intensity, params[1])
config = params.get(intensity, params[1]) # 默认为低强度
logger.debug(f"乱码注入参数: 概率={config['probability']}%, 长度范围={config['length']}")

words = text.split()

@@ -95,6 +130,7 @@ class PromptProcessor:
noise_count = 0
for word in words:
result.append(word)
# 按概率决定是否注入噪音
if random.randint(1, 100) <= config["probability"]:
noise_length = random.randint(*config["length"])
noise = generate_noise(noise_length)
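To make the intensity table above concrete, here is a minimal, self-contained sketch of the injection loop (not part of the diff; the character pool is shortened and the logging and counting details are omitted):

```python
import random
import string

def inject_random_noise(text: str, intensity: int = 1) -> str:
    # Same probability/length parameters as the hunk above.
    params = {
        1: {"probability": 15, "length": (3, 6)},    # low
        2: {"probability": 25, "length": (5, 10)},   # medium
        3: {"probability": 35, "length": (8, 15)},   # high
    }
    config = params.get(intensity, params[1])
    chars = string.ascii_letters + string.digits  # shortened character pool for the sketch
    result = []
    for word in text.split():
        result.append(word)
        # With the configured probability, append a random noise token after the word.
        if random.randint(1, 100) <= config["probability"]:
            noise_len = random.randint(*config["length"])
            result.append("".join(random.choice(chars) for _ in range(noise_len)))
    return " ".join(result)

print(inject_random_noise("please continue the story from yesterday", intensity=2))
```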
@@ -106,8 +142,22 @@ class PromptProcessor:

@staticmethod
def extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取,向后兼容"""
"""
从模型返回的完整内容中提取被<think>...</think>标签包裹的思考过程,
并返回清理后的内容和思考过程。

Args:
content (str): 模型返回的原始字符串。

Returns:
Tuple[str, str]:
- 清理后的内容(移除了<think>标签及其内容)。
- 提取出的思考过程文本(如果没有则为空字符串)。
"""
# 使用正则表达式查找<think>标签
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
# 从内容中移除<think>标签及其包裹的所有内容
clean_content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
# 如果找到匹配项,则提取思考过程
reasoning = match.group(1).strip() if match else ""
return clean_content, reasoning

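As a quick illustration of the extraction behaviour documented in this hunk, a self-contained example using the same regular expressions (hypothetical input, not from the repository):

```python
import re
from typing import Tuple

def extract_reasoning(content: str) -> Tuple[str, str]:
    # Same pattern as the method above: the opening <think> tag is optional.
    match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
    clean = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
    reasoning = match.group(1).strip() if match else ""
    return clean, reasoning

raw = "<think>先分析用户意图,再组织答案。</think>你好,这是最终回答。"
clean, reasoning = extract_reasoning(raw)
print(clean)      # 你好,这是最终回答。
print(reasoning)  # 先分析用户意图,再组织答案。
```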
@@ -24,11 +24,15 @@ from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption
from .utils import compress_messages

logger = get_logger("request_executor")
logger = get_logger("model_utils")


class RequestExecutor:
"""请求执行器"""
"""
请求执行器。
负责直接与模型客户端交互,执行API请求。
它包含了核心的请求重试、异常分类处理、模型惩罚机制和消息压缩等底层逻辑。
"""

def __init__(
self,
@@ -39,6 +43,17 @@ class RequestExecutor:
model_info: ModelInfo,
model_selector: ModelSelector,
):
"""
初始化请求执行器。

Args:
task_name (str): 当前任务的名称。
model_set (TaskConfig): 任务相关的模型配置。
api_provider (APIProvider): API提供商配置。
client (BaseClient): 用于发送请求的客户端实例。
model_info (ModelInfo): 当前请求要使用的模型信息。
model_selector (ModelSelector): 模型选择器实例,用于更新模型状态(如惩罚值)。
"""
self.task_name = task_name
self.model_set = model_set
self.api_provider = api_provider
|
||||
audio_base64: str = "",
|
||||
) -> APIResponse:
|
||||
"""
|
||||
实际执行请求的方法, 包含了重试和异常处理逻辑
|
||||
实际执行API请求,并包含完整的重试和异常处理逻辑。
|
||||
|
||||
Args:
|
||||
request_type (str): 请求类型 ('response', 'embedding', 'audio')。
|
||||
message_list (List[Message] | None, optional): 消息列表。 Defaults to None.
|
||||
tool_options (list[ToolOption] | None, optional): 工具选项。 Defaults to None.
|
||||
response_format (RespFormat | None, optional): 响应格式要求。 Defaults to None.
|
||||
stream_response_handler (Optional[Callable], optional): 流式响应处理器。 Defaults to None.
|
||||
async_response_parser (Optional[Callable], optional): 异步响应解析器。 Defaults to None.
|
||||
temperature (Optional[float], optional): 温度参数。 Defaults to None.
|
||||
max_tokens (Optional[int], optional): 最大token数。 Defaults to None.
|
||||
embedding_input (str, optional): embedding输入文本。 Defaults to "".
|
||||
audio_base64 (str, optional): 音频base64数据。 Defaults to "".
|
||||
|
||||
Returns:
|
||||
APIResponse: 从模型客户端返回的API响应对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果请求类型未知。
|
||||
RuntimeError: 如果所有重试都失败。
|
||||
"""
|
||||
retry_remain = self.api_provider.max_retry
|
||||
compressed_messages: Optional[List[Message]] = None
|
||||
|
||||
# 循环进行重试
|
||||
while retry_remain > 0:
|
||||
try:
|
||||
# 根据请求类型调用不同的客户端方法
|
||||
if request_type == "response":
|
||||
assert message_list is not None, "message_list cannot be None for response requests"
|
||||
return await self.client.get_response(
|
||||
@@ -96,8 +133,10 @@ class RequestExecutor:
|
||||
raise ValueError(f"未知的请求类型: {request_type}")
|
||||
except Exception as e:
|
||||
logger.debug(f"请求失败: {str(e)}")
|
||||
# 对失败的模型应用惩罚
|
||||
self._apply_penalty_on_failure(e)
|
||||
|
||||
# 使用默认异常处理器来决定下一步操作(等待、重试、压缩或终止)
|
||||
wait_interval, compressed_messages = self._default_exception_handler(
|
||||
e,
|
||||
remain_try=retry_remain,
|
||||
@@ -106,29 +145,35 @@ class RequestExecutor:
|
||||
)
|
||||
|
||||
if wait_interval == -1:
|
||||
retry_remain = 0 # 不再重试
|
||||
retry_remain = 0 # 处理器决定不再重试
|
||||
elif wait_interval > 0:
|
||||
logger.info(f"等待 {wait_interval} 秒后重试...")
|
||||
await asyncio.sleep(wait_interval)
|
||||
finally:
|
||||
retry_remain -= 1
|
||||
|
||||
self.model_selector.decrease_usage_penalty(self.model_info.name)
|
||||
# 所有重试次数用尽后
|
||||
self.model_selector.decrease_usage_penalty(self.model_info.name) # 减少因使用而增加的基础惩罚
|
||||
logger.error(f"模型 '{self.model_info.name}' 请求失败,达到最大重试次数 {self.api_provider.max_retry} 次")
|
||||
raise RuntimeError("请求失败,已达到最大重试次数")
|
||||
|
||||
def _apply_penalty_on_failure(self, e: Exception):
|
||||
"""根据异常类型,动态调整模型的惩罚值"""
|
||||
"""
|
||||
根据异常类型,动态调整失败模型的惩罚值。
|
||||
关键错误(如网络问题、服务器5xx错误)会施加更重的惩罚。
|
||||
"""
|
||||
CRITICAL_PENALTY_MULTIPLIER = 5
|
||||
default_penalty_increment = 1
|
||||
penalty_increment = default_penalty_increment
|
||||
|
||||
# 对严重错误施加更高的惩罚
|
||||
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
|
||||
penalty_increment = CRITICAL_PENALTY_MULTIPLIER
|
||||
elif isinstance(e, RespNotOkException):
|
||||
if e.status_code >= 500:
|
||||
if e.status_code >= 500: # 服务器内部错误
|
||||
penalty_increment = CRITICAL_PENALTY_MULTIPLIER
|
||||
|
||||
# 记录日志
|
||||
log_message = f"发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}"
|
||||
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
|
||||
log_message = f"发生关键错误 ({type(e).__name__}),增加惩罚值: {penalty_increment}"
|
||||
@@ -136,6 +181,7 @@ class RequestExecutor:
|
||||
log_message = f"发生响应错误 (状态码: {e.status_code}),增加惩罚值: {penalty_increment}"
|
||||
logger.warning(f"模型 '{self.model_info.name}' {log_message}")
|
||||
|
||||
# 更新模型的惩罚值
|
||||
self.model_selector.update_model_penalty(self.model_info.name, penalty_increment)
|
||||
|
||||
def _default_exception_handler(
|
||||
@@ -145,7 +191,15 @@ class RequestExecutor:
|
||||
retry_interval: int = 10,
|
||||
messages: Tuple[List[Message], bool] | None = None,
|
||||
) -> Tuple[int, List[Message] | None]:
|
||||
"""默认异常处理函数"""
|
||||
"""
|
||||
默认的异常分类处理器。
|
||||
根据异常类型决定是否重试、等待多久以及是否需要压缩消息。
|
||||
|
||||
Returns:
|
||||
Tuple[int, List[Message] | None]:
|
||||
- 等待时间(秒)。-1表示不重试。
|
||||
- 压缩后的消息列表(如果有)。
|
||||
"""
|
||||
model_name = self.model_info.name
|
||||
|
||||
if isinstance(e, NetworkConnectionError):
|
||||
@@ -157,16 +211,16 @@ class RequestExecutor:
|
||||
)
|
||||
elif isinstance(e, ReqAbortException):
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
|
||||
return -1, None
|
||||
return -1, None # 请求被中断,不重试
|
||||
elif isinstance(e, RespNotOkException):
|
||||
return self._handle_resp_not_ok(e, remain_try, retry_interval, messages)
|
||||
elif isinstance(e, RespParseException):
|
||||
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
|
||||
logger.debug(f"附加内容: {str(e.ext_info)}")
|
||||
return -1, None
|
||||
return -1, None # 解析错误通常不可重试
|
||||
else:
|
||||
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
|
||||
return -1, None
|
||||
return -1, None # 未知异常,不重试
|
||||
|
||||
def _handle_resp_not_ok(
|
||||
self,
|
||||
@@ -174,28 +228,33 @@ class RequestExecutor:
|
||||
remain_try: int,
|
||||
retry_interval: int = 10,
|
||||
messages: tuple[list[Message], bool] | None = None,
|
||||
):
|
||||
"""处理响应错误异常"""
|
||||
) -> Tuple[int, Optional[List[Message]]]:
|
||||
"""处理HTTP状态码非200的异常。"""
|
||||
model_name = self.model_info.name
|
||||
# 客户端错误 (4xx),通常不可重试
|
||||
if e.status_code in [400, 401, 402, 403, 404]:
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}")
|
||||
return -1, None
|
||||
# 请求体过大 (413)
|
||||
elif e.status_code == 413:
|
||||
if messages and not messages[1]:
|
||||
# 如果消息存在且尚未被压缩,尝试压缩后重试一次
|
||||
if messages and not messages[1]: # messages[1] is a flag indicating if it's already compressed
|
||||
return self._check_retry(
|
||||
remain_try, 0,
|
||||
remain_try, 0, # 立即重试
|
||||
can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
|
||||
cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,压缩后仍失败",
|
||||
can_retry_callable=compress_messages, messages=messages[0],
|
||||
)
|
||||
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 请求体过大,无法压缩,放弃请求。")
|
||||
return -1, None
|
||||
# 请求过于频繁 (429)
|
||||
elif e.status_code == 429:
|
||||
return self._check_retry(
|
||||
remain_try, retry_interval,
|
||||
can_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
|
||||
cannot_retry_msg=f"任务-'{self.task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数",
|
||||
)
|
||||
# 服务器错误 (5xx),可以重试
|
||||
elif e.status_code >= 500:
|
||||
return self._check_retry(
|
||||
remain_try, retry_interval,
|
||||
@@ -215,9 +274,12 @@ class RequestExecutor:
|
||||
can_retry_callable: Callable | None = None,
|
||||
**kwargs,
|
||||
) -> Tuple[int, List[Message] | None]:
|
||||
"""辅助函数:检查是否可以重试"""
|
||||
"""
|
||||
辅助函数:检查是否可以重试,并执行可选的回调函数(如消息压缩)。
|
||||
"""
|
||||
if remain_try > 0:
|
||||
logger.warning(f"{can_retry_msg}")
|
||||
# 如果有可执行的回调(例如压缩函数),执行它并返回结果
|
||||
if can_retry_callable is not None:
|
||||
return retry_interval, can_retry_callable(**kwargs)
|
||||
return retry_interval, None
|
||||
|
||||
@@ -19,13 +19,24 @@ from .payload_content.tool_option import ToolCall
|
||||
from .prompt_processor import PromptProcessor
|
||||
from .request_executor import RequestExecutor
|
||||
|
||||
logger = get_logger("request_strategy")
|
||||
logger = get_logger("model_utils")
|
||||
|
||||
|
||||
class RequestStrategy:
|
||||
"""高级请求策略"""
|
||||
"""
|
||||
高级请求策略模块。
|
||||
负责实现复杂的请求逻辑,如模型的故障转移(fallback)和并发请求。
|
||||
"""
|
||||
|
||||
def __init__(self, model_set: TaskConfig, model_selector: ModelSelector, task_name: str):
|
||||
"""
|
||||
初始化请求策略。
|
||||
|
||||
Args:
|
||||
model_set (TaskConfig): 特定任务的模型配置。
|
||||
model_selector (ModelSelector): 模型选择器实例。
|
||||
task_name (str): 当前任务的名称。
|
||||
"""
|
||||
self.model_set = model_set
|
||||
self.model_selector = model_selector
|
||||
self.task_name = task_name
|
||||
@@ -37,43 +48,56 @@ class RequestStrategy:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行单次请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
|
||||
|
||||
该方法会按顺序尝试任务配置中的所有可用模型,直到一个模型成功返回响应。
|
||||
如果所有模型都失败,将根据 `raise_when_empty` 参数决定是抛出异常还是返回一个失败结果。
|
||||
|
||||
Args:
|
||||
base_payload (Dict[str, Any]): 基础请求载荷,包含prompt、工具选项等。
|
||||
raise_when_empty (bool, optional): 如果所有模型都失败或返回空内容,是否抛出异常。 Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 一个包含响应结果的字典,包括内容、模型信息、用量和成功状态。
|
||||
"""
|
||||
# 记录在本次请求中已经失败的模型,避免重复尝试
|
||||
failed_models_in_this_request = set()
|
||||
max_attempts = len(self.model_set.model_list)
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
# 选择一个当前最佳且未失败的模型
|
||||
model_selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request)
|
||||
|
||||
if model_selection_result is None:
|
||||
logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
|
||||
break
|
||||
break # 没有更多可用模型,跳出循环
|
||||
|
||||
model_info, api_provider, client = model_selection_result
|
||||
model_name = model_info.name
|
||||
logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_name}'...")
|
||||
|
||||
try:
|
||||
# 1. Process Prompt
|
||||
# 步骤 1: 预处理Prompt
|
||||
prompt_processor: PromptProcessor = base_payload["prompt_processor"]
|
||||
raw_prompt = base_payload["prompt"]
|
||||
processed_prompt = prompt_processor.process_prompt(
|
||||
raw_prompt, model_info, api_provider, self.task_name
|
||||
)
|
||||
|
||||
# 2. Build Message
|
||||
# 步骤 2: 构建消息体
|
||||
message_builder = MessageBuilder().add_text_content(processed_prompt)
|
||||
messages = [message_builder.build()]
|
||||
|
||||
# 3. Create payload for executor
|
||||
# 步骤 3: 为执行器创建载荷
|
||||
executor_payload = {
|
||||
"request_type": "response", # Strategy only handles response type
|
||||
"request_type": "response", # 策略模式目前只处理'response'类型请求
|
||||
"message_list": messages,
|
||||
"tool_options": base_payload["tool_options"],
|
||||
"temperature": base_payload["temperature"],
|
||||
"max_tokens": base_payload["max_tokens"],
|
||||
}
|
||||
|
||||
# 创建请求执行器实例
|
||||
executor = RequestExecutor(
|
||||
task_name=self.task_name,
|
||||
model_set=self.model_set,
|
||||
@@ -82,21 +106,24 @@ class RequestStrategy:
|
||||
model_info=model_info,
|
||||
model_selector=self.model_selector,
|
||||
)
|
||||
# 执行请求,并处理内部的空回复/截断重试
|
||||
response = await self._execute_and_handle_empty_retry(executor, executor_payload, prompt_processor)
|
||||
|
||||
# 4. Post-process response
|
||||
# The reasoning content is now extracted here, after a successful, de-truncated response is received.
|
||||
# 步骤 4: 后处理响应
|
||||
# 在获取到成功的、完整的响应后,提取思考过程内容
|
||||
final_content, reasoning_content = prompt_processor.extract_reasoning(response.content or "")
|
||||
response.content = final_content # Update response with cleaned content
|
||||
response.content = final_content # 使用清理后的内容更新响应对象
|
||||
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
# 检查最终内容是否为空
|
||||
if not final_content and not tool_calls:
|
||||
if raise_when_empty:
|
||||
raise RuntimeError("所选模型生成了空回复。")
|
||||
content = "生成的响应为空" # Fallback message
|
||||
logger.warning(f"模型 '{model_name}' 生成了空回复,返回默认信息。")
|
||||
|
||||
logger.debug(f"模型 '{model_name}' 成功生成了回复。")
|
||||
# 返回成功结果,包含用量和模型信息,供上层记录
|
||||
return {
|
||||
"content": response.content,
|
||||
"reasoning_content": reasoning_content,
|
||||
@@ -108,15 +135,19 @@ class RequestStrategy:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 捕获请求过程中的任何异常
|
||||
logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
|
||||
failed_models_in_this_request.add(model_info.name)
|
||||
last_exception = e
|
||||
|
||||
# 如果循环结束仍未成功
|
||||
logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
|
||||
if raise_when_empty:
|
||||
if last_exception:
|
||||
raise RuntimeError("所有模型均未能生成响应。") from last_exception
|
||||
raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
|
||||
|
||||
# 返回失败结果
|
||||
return {
|
||||
"content": "所有模型都请求失败",
|
||||
"reasoning_content": "",
|
||||
@@ -135,23 +166,43 @@ class RequestStrategy:
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
执行并发请求并从成功的结果中随机选择一个。
|
||||
以指定的并发数执行多个协程,并从所有成功的结果中随机选择一个返回。
|
||||
|
||||
Args:
|
||||
coro_callable (Callable): 要并发执行的协程函数。
|
||||
concurrency_count (int): 并发数量。
|
||||
*args: 传递给协程函数的位置参数。
|
||||
**kwargs: 传递给协程函数的关键字参数。
|
||||
|
||||
Returns:
|
||||
Any: 从成功的结果中随机选择的一个。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果所有并发任务都失败了。
|
||||
"""
|
||||
logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
|
||||
# 创建并发任务列表
|
||||
tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
successful_results = [res for res in results if not isinstance(res, Exception)]
|
||||
# 筛选出成功的结果
|
||||
successful_results = [
|
||||
res for res in results if isinstance(res, dict) and res.get("success")
|
||||
]
|
||||
|
||||
if successful_results:
|
||||
# 从成功结果中随机选择一个
|
||||
selected = random.choice(successful_results)
|
||||
logger.info(f"并发请求完成,从{len(successful_results)}个成功结果中选择了一个")
|
||||
return selected
|
||||
|
||||
# 如果没有成功的结果,记录所有异常
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")
|
||||
|
||||
# 抛出第一个遇到的异常
|
||||
first_exception = next((res for res in results if isinstance(res, Exception)), None)
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
@@ -162,11 +213,23 @@ class RequestStrategy:
|
||||
self, executor: RequestExecutor, payload: Dict[str, Any], prompt_processor: PromptProcessor
|
||||
) -> APIResponse:
|
||||
"""
|
||||
在单个模型内部处理空回复/截断的重试逻辑
|
||||
在单个模型内部处理因回复为空或被截断而触发的重试逻辑。
|
||||
|
||||
Args:
|
||||
executor (RequestExecutor): 请求执行器实例。
|
||||
payload (Dict[str, Any]): 传递给 `execute_request` 的载荷。
|
||||
prompt_processor (PromptProcessor): 提示词处理器,用于获取反截断标记。
|
||||
|
||||
Returns:
|
||||
APIResponse: 一个有效的、非空且完整的API响应。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的回复。
|
||||
"""
|
||||
empty_retry_count = 0
|
||||
max_empty_retry = executor.api_provider.max_retry
|
||||
empty_retry_interval = executor.api_provider.retry_interval
|
||||
# 检查模型是否启用了反截断功能
|
||||
use_anti_truncation = getattr(executor.model_info, "use_anti_truncation", False)
|
||||
end_marker = prompt_processor.end_marker
|
||||
|
||||
@@ -176,15 +239,20 @@ class RequestStrategy:
|
||||
content = response.content or ""
|
||||
tool_calls = response.tool_calls
|
||||
|
||||
# 判断是否为空回复
|
||||
is_empty_reply = not tool_calls and (not content or content.strip() == "")
|
||||
is_truncated = False
|
||||
|
||||
# 如果启用了反截断,检查回复是否被截断
|
||||
if use_anti_truncation and end_marker:
|
||||
if content.endswith(end_marker):
|
||||
# 移除结束标记
|
||||
# 如果包含结束标记,说明回复完整,移除标记
|
||||
response.content = content[: -len(end_marker)].strip()
|
||||
else:
|
||||
# 否则,认为回复被截断
|
||||
is_truncated = True
|
||||
|
||||
# 如果是空回复或截断,则进行重试
|
||||
if is_empty_reply or is_truncated:
|
||||
empty_retry_count += 1
|
||||
if empty_retry_count <= max_empty_retry:
|
||||
@@ -194,13 +262,14 @@ class RequestStrategy:
|
||||
)
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
continue
|
||||
continue # 继续下一次循环重试
|
||||
else:
|
||||
# 达到最大重试次数,抛出异常
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
raise RuntimeError(f"模型 '{executor.model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
|
||||
|
||||
# 成功获取响应
|
||||
# 成功获取到有效响应,返回结果
|
||||
return response
|
||||
|
||||
# 此处理论上不会到达,因为循环要么返回要么抛异常
|
||||
raise RuntimeError("空回复/截断重Test逻辑出现未知错误")
|
||||
# 此处理论上不会到达,因为循环要么返回要么抛出异常
|
||||
raise RuntimeError("空回复/截断重试逻辑出现未知错误")
|
||||
|
||||
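A self-contained illustration of the end-marker check used by this retry logic (the marker value comes from `PromptProcessor.__init__` earlier in the diff; the function name here is illustrative):

```python
END_MARKER = "###MAI_RESPONSE_END###"

def check_and_strip_marker(content: str) -> tuple[str, bool]:
    """Return (cleaned content, is_truncated) based on the anti-truncation marker."""
    if content.endswith(END_MARKER):
        # Marker present: the reply is complete, strip the marker before returning it.
        return content[: -len(END_MARKER)].strip(), False
    # Marker missing: treat the reply as truncated and let the caller retry.
    return content, True

print(check_and_strip_marker("这是完整回复。\n###MAI_RESPONSE_END###"))  # ('这是完整回复。', False)
print(check_and_strip_marker("这是被截断的回"))                           # ('这是被截断的回', True)
```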
@@ -10,7 +10,7 @@ import time
from typing import Tuple, List, Dict, Optional, Any

from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig, ModelInfo
from src.config.api_ada_configs import TaskConfig, ModelInfo, UsageRecord
from .llm_utils import build_tool_options, normalize_image_format
from .model_selector import ModelSelector
from .payload_content.message import MessageBuilder
@@ -22,9 +22,20 @@ from .utils import llm_usage_recorder
logger = get_logger("model_utils")

class LLMRequest:
"""LLM请求协调器"""
"""
LLM请求协调器。
封装了模型选择、Prompt处理、请求执行和高级策略(如故障转移、并发)的完整流程。
为上层业务逻辑提供统一的、简化的接口来与大语言模型交互。
"""

def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
"""
初始化LLM请求协调器。

Args:
model_set (TaskConfig): 特定任务的模型配置集合。
request_type (str, optional): 请求类型或任务名称,用于日志和用量记录。 Defaults to "".
"""
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
@@ -40,16 +51,33 @@ class LLMRequest:
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""为图像生成响应"""
"""
为包含图像的多模态输入生成文本响应。

Args:
prompt (str): 文本提示。
image_base64 (str): Base64编码的图像数据。
image_format (str): 图像格式 (例如, "png", "jpeg")。
temperature (Optional[float], optional): 控制生成文本的随机性。 Defaults to None.
max_tokens (Optional[int], optional): 生成响应的最大长度。 Defaults to None.

Returns:
Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
- 清理后的响应内容。
- 一个元组,包含思考过程、模型名称和工具调用列表。
"""
start_time = time.time()

# 1. 选择模型
# 步骤 1: 选择一个支持图像处理的模型
model_info, api_provider, client = self.model_selector.select_model()

# 2. 准备消息体
# 步骤 2: 准备消息体
# 预处理文本提示
processed_prompt = self.prompt_processor.process_prompt(prompt, model_info, api_provider, self.task_name)
# 规范化图像格式
normalized_format = normalize_image_format(image_format)

# 使用MessageBuilder构建多模态消息
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
message_builder.add_image_content(
@@ -59,7 +87,7 @@ class LLMRequest:
)
messages = [message_builder.build()]

# 3. 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行)
# 步骤 3: 执行请求 (图像请求通常不走复杂的故障转移策略,直接执行)
from .request_executor import RequestExecutor
executor = RequestExecutor(
task_name=self.task_name,
@@ -76,20 +104,31 @@ class LLMRequest:
max_tokens=max_tokens,
)

# 4. 处理响应
# 步骤 4: 处理响应
content, reasoning_content = self.prompt_processor.extract_reasoning(response.content or "")
tool_calls = response.tool_calls

# 记录用量
if usage := response.usage:
await self._record_usage(model_info, usage, time.time() - start_time)

return content, (reasoning_content, model_info.name, tool_calls)

async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
"""为语音生成响应"""
"""
将语音数据转换为文本(语音识别)。

Args:
voice_base64 (str): Base64编码的语音数据。

Returns:
Optional[str]: 识别出的文本内容,如果失败则返回None。
"""
# 选择一个支持语音识别的模型
model_info, api_provider, client = self.model_selector.select_model()

from .request_executor import RequestExecutor
# 创建请求执行器
executor = RequestExecutor(
task_name=self.task_name,
model_set=self.model_for_task,
@@ -98,6 +137,7 @@ class LLMRequest:
model_info=model_info,
model_selector=self.model_selector,
)
# 执行语音转文本请求
response = await executor.execute_request(
request_type="audio",
audio_base64=voice_base64,
@@ -112,9 +152,24 @@ class LLMRequest:
tools: Optional[List[Dict[str, Any]]] = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""异步生成响应,支持并发和故障转移"""
"""
异步生成文本响应,支持并发和故障转移等高级策略。

Args:
prompt (str): 用户输入的提示。
temperature (Optional[float], optional): 控制生成文本的随机性。 Defaults to None.
max_tokens (Optional[int], optional): 生成响应的最大长度。 Defaults to None.
tools (Optional[List[Dict[str, Any]]], optional): 可供模型调用的工具列表。 Defaults to None.
raise_when_empty (bool, optional): 如果最终响应为空,是否抛出异常。 Defaults to True.

Returns:
Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
- 清理后的响应内容。
- 一个元组,包含思考过程、最终使用的模型名称和工具调用列表。
"""
start_time = time.time()

# 1. 准备基础请求载荷
# 步骤 1: 准备基础请求载荷
tool_built = build_tool_options(tools)
base_payload = {
"prompt": prompt,
@@ -124,7 +179,7 @@ class LLMRequest:
"prompt_processor": self.prompt_processor,
}

# 2. 根据配置选择执行策略
# 步骤 2: 根据配置选择执行策略 (并发或单次带故障转移)
concurrency_count = getattr(self.model_for_task, "concurrency_count", 1)

if concurrency_count <= 1:
@@ -138,23 +193,43 @@ class LLMRequest:
self.request_strategy.execute_with_fallback,
concurrency_count,
base_payload,
raise_when_empty=False,
raise_when_empty=False, # 在并发模式下,单个任务失败不应立即抛出异常
)

# 3. 处理最终结果
content, (reasoning_content, model_name, tool_calls) = result
# 步骤 3: 处理最终结果
content = result.get("content", "")
reasoning_content = result.get("reasoning_content", "")
model_name = result.get("model_name", "unknown")
tool_calls = result.get("tool_calls")

# 4. 记录用量 (需要从策略中获取最终使用的模型信息和用量)
# TODO: 改造策略以返回最终模型信息和用量, 此处暂时省略
# 步骤 4: 记录用量 (从策略返回的结果中获取最终使用的模型信息和用量)
final_model_info = result.get("model_info")
usage = result.get("usage")

if final_model_info and usage:
await self._record_usage(final_model_info, usage, time.time() - start_time)

return content, (reasoning_content, model_name, tool_calls)

async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量"""
"""
获取给定文本的嵌入向量 (Embedding)。

Args:
embedding_input (str): 需要进行嵌入的文本。

Returns:
Tuple[List[float], str]: 嵌入向量列表和所使用的模型名称。

Raises:
RuntimeError: 如果获取embedding失败。
"""
start_time = time.time()
# 选择一个支持embedding的模型
model_info, api_provider, client = self.model_selector.select_model()

from .request_executor import RequestExecutor
# 创建请求执行器
executor = RequestExecutor(
task_name=self.task_name,
model_set=self.model_for_task,
@@ -163,6 +238,7 @@ class LLMRequest:
model_info=model_info,
model_selector=self.model_selector,
)
# 执行embedding请求
response = await executor.execute_request(
request_type="embedding",
embedding_input=embedding_input,
@@ -172,17 +248,26 @@ class LLMRequest:
if not embedding:
raise RuntimeError("获取embedding失败")

# 记录用量
if usage := response.usage:
await self._record_usage(model_info, usage, time.time() - start_time, "/embeddings")

return embedding, model_info.name

async def _record_usage(self, model_info: ModelInfo, usage, time_cost, endpoint="/chat/completions"):
"""记录模型用量"""
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord, time_cost: float, endpoint: str = "/chat/completions"):
"""
记录模型API的调用用量到数据库。

Args:
model_info (ModelInfo): 使用的模型信息。
usage (UsageRecord): 包含token用量信息的对象。
time_cost (float): 本次请求的总耗时(秒)。
endpoint (str, optional): 请求的API端点。 Defaults to "/chat/completions".
"""
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
user_id="system", # 当前所有请求都以系统用户身份记录
time_cost=time_cost,
request_type=self.request_type,
endpoint=endpoint,