refactor(llm_models): restructure LLMRequest into a modular, strategy-driven architecture

This refactor breaks up the previously monolithic `LLMRequest` class to improve maintainability, extensibility, and robustness. By introducing several internal helper classes that each follow the single-responsibility principle, the stages of the request lifecycle are now cleanly separated.

Key changes:

- **Introduce `_ModelSelector`**: dedicated to dynamic model selection, load balancing, and the failure-penalty policy. The policy now applies heavier penalties to severe problems such as network and server errors (a sketch of this weighting follows the list below).

- **Introduce `_PromptProcessor`**: encapsulates all prompt-related handling, including content obfuscation, anti-truncation instruction injection, and post-processing of responses (e.g., extracting the reasoning section).

- **Introduce `_RequestExecutor`**: executes the low-level API requests, with automatic retries, exception classification, and message-body compression.

- **Introduce `_RequestStrategy`**: implements higher-level request strategies such as failover across models, so that a single model's failure does not fail the whole request.
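
As a rough illustration of the penalty-weighted selection described above, the sketch below reuses the scoring formula visible in the pre-refactor `_select_model` (`total_tokens + penalty * 300 + usage_penalty * 1000`); the function name, the shape of the usage table, and the weights are illustrative assumptions rather than the exact `_ModelSelector` code.

```python
# Hypothetical sketch of penalty-weighted model selection; the usage-table layout
# and the 300/1000 weights are taken from the pre-refactor _select_model and may
# not match the new _ModelSelector exactly.
from typing import Dict, Optional, Set, Tuple

def pick_least_loaded_model(
    model_usage: Dict[str, Tuple[int, int, int]],  # name -> (total_tokens, penalty, usage_penalty)
    failed_models: Set[str],
) -> Optional[str]:
    """Return the cheapest available model by weighted score, or None if all failed."""
    candidates = {name: stats for name, stats in model_usage.items() if name not in failed_models}
    if not candidates:
        return None
    # Heavier penalties (e.g. for network/server errors) push a model's score up,
    # steering traffic toward healthier, less-used models.
    return min(
        candidates,
        key=lambda name: (
            candidates[name][0]            # accumulated tokens (load)
            + candidates[name][1] * 300    # failure penalty
            + candidates[name][2] * 1000   # recent-use penalty to avoid hammering one model
        ),
    )
```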

`LLMRequest` itself now acts as a facade that coordinates these internal components and exposes a simpler, more stable interface to callers; the failover flow it delegates to is sketched below.
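
A minimal sketch of that failover loop, under stated assumptions: `execute_with_failover`, `select_best_available_model`, and `execute_request` do appear in the diff, but the parameter lists, return shapes, and error handling here are simplified guesses, not the actual `_RequestStrategy` implementation.

```python
# Simplified, assumed shape of the failover strategy: try the best available model,
# blacklist it on failure, and move on until no candidates remain.
async def execute_with_failover(selector, executor, request_type, **request_kwargs):
    failed_models: set = set()
    last_error: Exception | None = None
    while True:
        picked = selector.select_best_available_model(failed_models, request_type)
        if not picked:
            # No model left to try: surface the last failure to the caller.
            raise RuntimeError("All models failed for this request") from last_error
        model_info, api_provider, client = picked
        try:
            # The executor owns per-model retries, exception classification, etc.
            response = await executor.execute_request(
                api_provider, client, request_type, model_info, **request_kwargs
            )
            return response, model_info
        except Exception as exc:
            failed_models.add(model_info.name)
            last_error = exc
```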
Author: minecraft1024a
Date: 2025-09-26 21:17:34 +08:00
Committed by: Windpicker-owo
Parent: f2d02572fb
Commit: ac9321ff80

@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
-# -*- coding: utf-8 -*-
"""
@desc: 该模块封装了与大语言模型LLM交互的所有核心逻辑。
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
@@ -92,7 +91,6 @@ async def execute_concurrently(
""" """
logger.info(f"启用并发请求模式,并发数: {concurrency_count}") logger.info(f"启用并发请求模式,并发数: {concurrency_count}")
tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)] tasks = [coro_callable(*args, **kwargs) for _ in range(concurrency_count)]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
successful_results = [res for res in results if not isinstance(res, Exception)] successful_results = [res for res in results if not isinstance(res, Exception)]
@@ -765,7 +763,8 @@ class LLMRequest:
        max_tokens: Optional[int] = None,
    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
        """
        为图像生成响应
        Args:
            prompt (str): 提示词
            image_base64 (str): 图像的Base64编码字符串
@@ -774,70 +773,47 @@ class LLMRequest:
        Returns:
            (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
        """
-        # 标准化图片格式以确保API兼容性
-        normalized_format = _normalize_image_format(image_format)
-        # 模型选择
        start_time = time.time()
-        model_info, api_provider, client = self._select_model()
-        # 请求体构建
-        message_builder = MessageBuilder()
-        message_builder.add_text_content(prompt)
-        message_builder.add_image_content(
+        # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
+        selection_result = self._model_selector.select_best_available_model(set(), "response")
+        if not selection_result:
+            raise RuntimeError("无法为图像响应选择可用模型。")
+        model_info, api_provider, client = selection_result
+        normalized_format = _normalize_image_format(image_format)
+        message = MessageBuilder().add_text_content(prompt).add_image_content(
            image_base64=image_base64,
            image_format=normalized_format,
            support_formats=client.get_support_image_formats(),
-        )
-        messages = [message_builder.build()]
-        # 请求并处理返回值
-        response = await self._execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.RESPONSE,
-            model_info=model_info,
-            message_list=messages,
+        ).build()
+        response = await self._executor.execute_request(
+            api_provider, client, RequestType.RESPONSE, model_info,
+            message_list=[message],
            temperature=temperature,
            max_tokens=max_tokens,
        )
-        content = response.content or ""
-        reasoning_content = response.reasoning_content or ""
-        tool_calls = response.tool_calls
-        # 从内容中提取<think>标签的推理内容(向后兼容)
-        if not reasoning_content and content:
-            content, extracted_reasoning = self._extract_reasoning(content)
-            reasoning_content = extracted_reasoning
-        if usage := response.usage:
-            llm_usage_recorder.record_usage_to_database(
-                model_info=model_info,
-                model_usage=usage,
-                user_id="system",
-                time_cost=time.time() - start_time,
-                request_type=self.request_type,
-                endpoint="/chat/completions",
-            )
-        return content, (reasoning_content, model_info.name, tool_calls)
+        self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
+        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
+        reasoning = response.reasoning_content or reasoning
+        return content, (reasoning, model_info.name, response.tool_calls)
    async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
        """
-        为语音生成响应
+        为语音生成响应(语音转文字)。
+        使用故障转移策略来确保即使主模型失败也能获得结果。
        Args:
            voice_base64 (str): 语音的Base64编码字符串。
        Returns:
-            (Optional[str]): 生成的文本描述或None
+            Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None
        """
-        # 模型选择
-        model_info, api_provider, client = self._select_model()
-        # 请求并处理返回值
-        response = await self._execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.AUDIO,
-            model_info=model_info,
-            audio_base64=voice_base64,
+        response, _ = await self._strategy.execute_with_failover(
+            RequestType.AUDIO, audio_base64=voice_base64
        )
        return response.content or None
@@ -879,7 +855,7 @@ class LLMRequest:
                raise e
        return "所有并发请求都失败了", ("", "unknown", None)
-    async def _execute_single_request(
+    async def _execute_single_text_request(
        self,
        prompt: str,
        temperature: Optional[float] = None,
@@ -888,323 +864,100 @@ class LLMRequest:
        raise_when_empty: bool = True,
    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
        """
-        执行单次请求,并在模型失败时按顺序切换到下一个可用模型
+        执行单次文本生成请求的内部方法
+        这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期,
+        包括工具构建、故障转移执行和用量记录。
+        Args:
+            prompt (str): 用户的提示。
+            temperature (Optional[float]): 生成温度。
+            max_tokens (Optional[int]): 最大生成令牌数。
+            tools (Optional[List[Dict[str, Any]]]): 可用工具列表。
+            raise_when_empty (bool): 如果响应为空是否引发异常。
+        Returns:
+            Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+                (响应内容, (推理过程, 模型名称, 工具调用))
        """
-        failed_models = set()
-        last_exception: Optional[Exception] = None
-        model_scheduler = self._model_scheduler(failed_models)
-        for model_info, api_provider, client in model_scheduler:
-            start_time = time.time()
-            model_name = model_info.name
-            logger.debug(f"正在尝试使用模型: {model_name}")  # 你不许刷屏
-            try:
-                # 检查是否启用反截断
-                # 检查是否为该模型启用反截断
-                use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
-                processed_prompt = prompt
-                if use_anti_truncation:
-                    processed_prompt += self.anti_truncation_instruction
-                    logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。")
-                processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
-                message_builder = MessageBuilder()
-                message_builder.add_text_content(processed_prompt)
-                messages = [message_builder.build()]
-                tool_built = self._build_tool_options(tools)
-                # 针对当前模型的空回复/截断重试逻辑
-                empty_retry_count = 0
-                max_empty_retry = api_provider.max_retry
-                empty_retry_interval = api_provider.retry_interval
-                while empty_retry_count <= max_empty_retry:
-                    response = await self._execute_request(
-                        api_provider=api_provider,
-                        client=client,
-                        request_type=RequestType.RESPONSE,
-                        model_info=model_info,
-                        message_list=messages,
-                        tool_options=tool_built,
-                        temperature=temperature,
-                        max_tokens=max_tokens,
-                    )
-                    content = response.content or ""
-                    reasoning_content = response.reasoning_content or ""
-                    tool_calls = response.tool_calls
-                    if not reasoning_content and content:
-                        content, extracted_reasoning = self._extract_reasoning(content)
-                        reasoning_content = extracted_reasoning
-                    is_empty_reply = not tool_calls and (not content or content.strip() == "")
-                    is_truncated = False
-                    if use_anti_truncation:
-                        if content.endswith(self.end_marker):
-                            content = content[: -len(self.end_marker)].strip()
-                        else:
-                            is_truncated = True
-                    if is_empty_reply or is_truncated:
-                        empty_retry_count += 1
-                        if empty_retry_count <= max_empty_retry:
-                            reason = "空回复" if is_empty_reply else "截断"
-                            logger.warning(
-                                f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..."
-                            )
-                            if empty_retry_interval > 0:
-                                await asyncio.sleep(empty_retry_interval)
-                            continue  # 继续使用当前模型重试
-                        else:
-                            # 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
-                            reason = "空回复" if is_empty_reply else "截断"
-                            logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
-                            raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
-                    # 成功获取响应
-                    if usage := response.usage:
-                        llm_usage_recorder.record_usage_to_database(
-                            model_info=model_info,
-                            model_usage=usage,
-                            time_cost=time.time() - start_time,
-                            user_id="system",
-                            request_type=self.request_type,
-                            endpoint="/chat/completions",
-                        )
-                    if not content and not tool_calls:
-                        if raise_when_empty:
-                            raise RuntimeError("生成空回复")
-                        content = "生成的响应为空"
-                    logger.debug(f"模型 '{model_name}' 成功生成回复。")  # 你也不许刷屏
-                    return content, (reasoning_content, model_name, tool_calls)
-            except RespNotOkException as e:
-                if e.status_code in [401, 403]:
-                    logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
-                    failed_models.add(model_name)
-                    last_exception = e
-                    continue  # 切换到下一个模型
-                else:
-                    logger.error(f"模型 '{model_name}' 请求失败HTTP状态码: {e.status_code}")
-                    if raise_when_empty:
-                        raise
-                    # 对于其他HTTP错误直接抛出不再尝试其他模型
-                    return f"请求失败: {e}", ("", model_name, None)
-            except RuntimeError as e:
-                # 捕获所有重试失败(包括空回复和网络问题)
-                logger.error(f"模型 '{model_name}' 在所有重试后仍然失败: {e},将尝试下一个模型。")
-                failed_models.add(model_name)
-                last_exception = e
-                continue  # 切换到下一个模型
-            except Exception as e:
-                logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
-                failed_models.add(model_name)
-                last_exception = e
-                continue  # 切换到下一个模型
-        # 所有模型都尝试失败
-        logger.error("所有可用模型都已尝试失败。")
-        if raise_when_empty:
-            if last_exception:
-                raise RuntimeError("所有模型都请求失败") from last_exception
-            raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
-        return "所有模型都请求失败", ("", "unknown", None)
+        start_time = time.time()
+        tool_options = self._build_tool_options(tools)
+        response, model_info = await self._strategy.execute_with_failover(
+            RequestType.RESPONSE,
+            raise_when_empty=raise_when_empty,
+            prompt=prompt,  # 传递原始prompt由strategy处理
+            tool_options=tool_options,
+            temperature=self.model_for_task.temperature if temperature is None else temperature,
+            max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
+        )
+        self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
+        if not response.content and not response.tool_calls:
+            if raise_when_empty:
+                raise RuntimeError("所选模型生成了空回复。")
+            response.content = "生成的响应为空"
+        return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)
    async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
-        """获取嵌入向量
+        """
+        获取嵌入向量。
        Args:
            embedding_input (str): 获取嵌入的目标
        Returns:
            (Tuple[List[float], str]): (嵌入向量,使用的模型名称)
        """
-        # 无需构建消息体,直接使用输入文本
        start_time = time.time()
-        model_info, api_provider, client = self._select_model()
-        # 请求并处理返回值
-        response = await self._execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.EMBEDDING,
-            model_info=model_info,
-            embedding_input=embedding_input,
+        response, model_info = await self._strategy.execute_with_failover(
+            RequestType.EMBEDDING,
+            embedding_input=embedding_input
        )
-        embedding = response.embedding
-        if usage := response.usage:
-            llm_usage_recorder.record_usage_to_database(
-                model_info=model_info,
-                time_cost=time.time() - start_time,
-                model_usage=usage,
-                user_id="system",
-                request_type=self.request_type,
-                endpoint="/embeddings",
-            )
-        if not embedding:
+        self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
+        if not response.embedding:
            raise RuntimeError("获取embedding失败")
-        return embedding, model_info.name
-    def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
-        """
-        一个模型调度器,按顺序提供模型,并跳过已失败的模型
-        """
-        for model_name in self.model_for_task.model_list:
-            if model_name in failed_models:
-                continue
-            model_info = model_config.get_model_info(model_name)
-            api_provider = model_config.get_provider(model_info.api_provider)
-            force_new_client = self.request_type == "embedding"
-            client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-            yield model_info, api_provider, client
-    def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
-        """
-        根据总tokens和惩罚值选择的模型 (负载均衡)
-        """
-        least_used_model_name = min(
-            self.model_usage,
-            key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
-        )
-        model_info = model_config.get_model_info(least_used_model_name)
-        api_provider = model_config.get_provider(model_info.api_provider)
-        # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
-        force_new_client = self.request_type == "embedding"
-        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
-        logger.debug(f"选择请求模型: {model_info.name}")
-        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # 增加使用惩罚值防止连续使用
-        return model_info, api_provider, client
+        return response.embedding, model_info.name
+    def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
+        """
+        记录模型使用情况
+        此方法首先在内存中更新模型的累计token使用量然后创建一个异步任务
+        将详细的用量数据包括模型信息、token数、耗时等写入数据库。
+        Args:
+            model_info (ModelInfo): 使用的模型信息。
+            usage (Optional[UsageRecord]): API返回的用量记录。
+            time_cost (float): 本次请求的总耗时。
+            endpoint (str): 请求的API端点 (e.g., "/chat/completions")。
+        """
+        if usage:
+            # 步骤1: 更新内存中的token计数用于负载均衡
+            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+            self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
+            # 步骤2: 创建一个后台任务,将用量数据异步写入数据库
+            asyncio.create_task(llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
+                user_id="system",  # 此处可根据业务需求修改
+                time_cost=time_cost,
+                request_type=self.task_name,
+                endpoint=endpoint,
+            ))
-    async def _execute_request(
-        self,
-        api_provider: APIProvider,
-        client: BaseClient,
-        request_type: RequestType,
-        model_info: ModelInfo,
-        message_list: List[Message] | None = None,
-        tool_options: list[ToolOption] | None = None,
-        response_format: RespFormat | None = None,
-        stream_response_handler: Optional[Callable] = None,
-        async_response_parser: Optional[Callable] = None,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        embedding_input: str = "",
-        audio_base64: str = "",
-    ) -> APIResponse:
-        """
-        实际执行请求的方法
-        包含了重试和异常处理逻辑
-        """
-        retry_remain = api_provider.max_retry
-        compressed_messages: Optional[List[Message]] = None
-        while retry_remain > 0:
-            try:
-                if request_type == RequestType.RESPONSE:
-                    assert message_list is not None, "message_list cannot be None for response requests"
-                    return await client.get_response(
-                        model_info=model_info,
-                        message_list=(compressed_messages or message_list),
-                        tool_options=tool_options,
-                        max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
-                        temperature=self.model_for_task.temperature if temperature is None else temperature,
-                        response_format=response_format,
-                        stream_response_handler=stream_response_handler,
-                        async_response_parser=async_response_parser,
-                        extra_params=model_info.extra_params,
-                    )
-                elif request_type == RequestType.EMBEDDING:
-                    assert embedding_input, "embedding_input cannot be empty for embedding requests"
-                    return await client.get_embedding(
-                        model_info=model_info,
-                        embedding_input=embedding_input,
-                        extra_params=model_info.extra_params,
-                    )
-                elif request_type == RequestType.AUDIO:
-                    assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
-                    return await client.get_audio_transcriptions(
-                        model_info=model_info,
-                        audio_base64=audio_base64,
-                        extra_params=model_info.extra_params,
-                    )
-            except Exception as e:
-                logger.debug(f"请求失败: {str(e)}")
-                # 处理异常
-                total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-                self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
-                wait_interval, compressed_messages = self._default_exception_handler(
-                    e,
-                    self.task_name,
-                    model_info=model_info,
-                    api_provider=api_provider,
-                    remain_try=retry_remain,
-                    retry_interval=api_provider.retry_interval,
-                    messages=(message_list, compressed_messages is not None) if message_list else None,
-                )
-                if wait_interval == -1:
-                    retry_remain = 0  # 不再重试
-                elif wait_interval > 0:
-                    logger.info(f"等待 {wait_interval} 秒后重试...")
-                    await asyncio.sleep(wait_interval)
-            finally:
-                # 放在finally防止死循环
-                retry_remain -= 1
-        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
-        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)  # 使用结束,减少使用惩罚值
-        logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
-        raise RuntimeError("请求失败,已达到最大重试次数")
    @staticmethod
    def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
        """
        根据输入的字典列表构建并验证 `ToolOption` 对象列表。
-        if isinstance(e, NetworkConnectionError):  # 网络连接错误
-            return self._check_retry(
-                remain_try,
-                retry_interval,
-                can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
-                cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常超过最大重试次数请检查网络连接状态或URL是否正确",
-            )
-        elif isinstance(e, ReqAbortException):
-            logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
-            return -1, None  # 不再重试请求该模型
-        elif isinstance(e, RespNotOkException):
-            return self._handle_resp_not_ok(
-                e,
-                task_name,
-                model_info,
-                api_provider,
-                remain_try,
-                retry_interval,
-                messages,
-            )
-        elif isinstance(e, RespParseException):
-            # 响应解析错误
-            logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
-            logger.debug(f"附加内容: {str(e.ext_info)}")
-            return -1, None  # 不再重试请求该模型
-        else:
-            logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
-            return -1, None  # 不再重试请求该模型
+        此方法将标准化的工具定义(字典格式)转换为内部使用的 `ToolOption` 对象,
+        同时会验证参数格式的正确性。
        Args:
            tools (Optional[List[Dict[str, Any]]]): 工具定义的列表。
@@ -1214,72 +967,7 @@ class LLMRequest:
        Returns:
            Optional[List[ToolOption]]: 构建好的 `ToolOption` 对象列表,如果输入为空则返回 None。
        """
-        # 响应错误
-        if e.status_code in [400, 401, 402, 403, 404]:
-            model_name = model_info.name
-            if (
-                e.status_code == 403
-                and model_name.startswith("Pro/deepseek-ai")
-                and api_provider.base_url == "https://api.siliconflow.cn/v1/"
-            ):
-                old_model_name = model_name
-                new_model_name = model_name[4:]
-                model_info.name = new_model_name
-                logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {new_model_name}")
-                # 更新任务配置中的模型列表
-                for i, m_name in enumerate(self.model_for_task.model_list):
-                    if m_name == old_model_name:
-                        self.model_for_task.model_list[i] = new_model_name
-                        logger.warning(
-                            f"将任务 {self.task_name} 的模型列表中的 {old_model_name} 临时降级至 {new_model_name}"
-                        )
-                        break
-                return 0, None  # 立即重试
-            # 客户端错误
-            logger.warning(
-                f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
-            )
-            return -1, None  # 不再重试请求该模型
-        elif e.status_code == 413:
-            if messages and not messages[1]:
-                # 消息列表不为空且未压缩,尝试压缩消息
-                return self._check_retry(
-                    remain_try,
-                    0,
-                    can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
-                    cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
-                    can_retry_callable=compress_messages,
-                    messages=messages[0],
-                )
-            # 没有消息可压缩
-            logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
-            return -1, None
-        elif e.status_code == 429:
-            # 请求过于频繁
-            return self._check_retry(
-                remain_try,
-                retry_interval,
-                can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
-                cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
-            )
-        elif e.status_code >= 500:
-            # 服务器错误
-            return self._check_retry(
-                remain_try,
-                retry_interval,
-                can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
-                cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
-            )
-        else:
-            # 未知错误
-            logger.warning(
-                f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
-            )
-            return -1, None
-    def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
-        # sourcery skip: extract-method
-        """构建工具选项列表"""
+        # 如果没有提供工具,直接返回 None
        if not tools:
            return None