Revert "refactor(llm): 重构 LLM 请求处理,引入通用故障转移执行器"

This reverts commit 6ed9349933.
minecraft1024a
2025-09-24 21:28:42 +08:00
parent 98212bb938
commit 4e3ab4003c


@@ -159,7 +159,7 @@ class LLMRequest:
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
Generate a response for an image (failover integrated)
Generate a response for an image
Args:
prompt (str): the prompt text
image_base64 (str): Base64-encoded string of the image
@@ -167,78 +167,71 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, tool call list
"""
# Normalize the image format to ensure API compatibility
normalized_format = _normalize_image_format(image_format)
async def request_logic(
model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
start_time = time.time()
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
)
messages = [message_builder.build()]
# Model selection
start_time = time.time()
model_info, api_provider, client = self._select_model()
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
# Build the request body
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
)
messages = [message_builder.build()]
# Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
# Extract reasoning inside <think> tags from the content (backward compatibility)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
message_list=messages,
temperature=temperature,
max_tokens=max_tokens,
model_usage=usage,
user_id="system",
time_cost=time.time() - start_time,
request_type=self.request_type,
endpoint="/chat/completions",
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
time_cost=time.time() - start_time,
request_type=self.request_type,
endpoint="/chat/completions",
)
return content, (reasoning_content, model_info.name, tool_calls)
result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True)
if result:
return result
# Theoretically unreachable: raise_on_failure=True raises instead of returning
raise RuntimeError("Image response generation failed; all models were tried and failed.")
return content, (reasoning_content, model_info.name, tool_calls)
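The `_normalize_image_format` helper referenced above sits outside this hunk. A minimal sketch of what such a helper plausibly does, assuming it only canonicalizes common aliases (the mapping here is illustrative, not the project's actual table):

```python
# Hypothetical sketch: the real _normalize_image_format is not shown in this diff.
_FORMAT_ALIASES = {"jpg": "jpeg", "tif": "tiff"}

def _normalize_image_format(image_format: str) -> str:
    """Map common aliases to the canonical format names most APIs accept."""
    fmt = image_format.lower()
    return _FORMAT_ALIASES.get(fmt, fmt)
```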
async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
"""
Generate a response for a voice clip (failover integrated)
Generate a response for a voice clip
Args:
voice_base64 (str): Base64-encoded string of the audio
Returns:
(Optional[str]): the generated text description, or None
"""
# Model selection
model_info, api_provider, client = self._select_model()
async def request_logic(model_info: ModelInfo, api_provider: APIProvider, client: BaseClient) -> Optional[str]:
"""定义单次请求的具体逻辑"""
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.AUDIO,
model_info=model_info,
audio_base64=voice_base64,
)
return response.content or None
# For speech recognition we would rather return None than crash if every model fails
result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=False)
return result
# Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.AUDIO,
model_info=model_info,
audio_base64=voice_base64,
)
return response.content or None
async def generate_response_async(
self,
@@ -286,75 +279,6 @@ class LLMRequest:
raise e
return "所有并发请求都失败了", ("", "unknown", None)
async def _execute_with_failover(
self,
request_callable: Callable[[ModelInfo, APIProvider, BaseClient], Coroutine[Any, Any, Any]],
raise_on_failure: bool = True,
) -> Any:
"""
通用的故障转移执行器。
它会使用智能模型调度器按最优顺序尝试模型,直到请求成功或所有模型都失败。
Args:
request_callable: 一个接收 (model_info, api_provider, client) 并返回协程的函数,
用于执行实际的请求逻辑。
raise_on_failure: 如果所有模型都失败,是否抛出异常。
Returns:
请求成功时的返回结果。
Raises:
RuntimeError: 如果所有模型都失败且 raise_on_failure 为 True。
"""
failed_models = set()
last_exception: Optional[Exception] = None
# model_scheduler now sorts dynamically, so the loop only needs to track failed models
while True:
model_scheduler = self._model_scheduler(failed_models)
try:
model_info, api_provider, client = next(model_scheduler)
except StopIteration:
# No more models available
break
model_name = model_info.name
logger.debug(f"正在尝试使用模型: {model_name} (剩余可用: {len(self.model_for_task.model_list) - len(failed_models)})")
try:
# Run the supplied request function
result = await request_callable(model_info, api_provider, client)
logger.debug(f"模型 '{model_name}' 成功生成回复。")
return result
except RespNotOkException as e:
# For fatal HTTP errors such as auth failures, fail fast or mark the model as permanently failed
if e.status_code in [401, 403]:
logger.error(f"Model '{model_name}' hit an auth/permission error (Code: {e.status_code}); disabling it for the rest of this request.")
else:
logger.warning(f"Model '{model_name}' request failed, HTTP status: {e.status_code}; trying the next model.")
failed_models.add(model_name)
last_exception = e
continue
except Exception as e:
# Catch all other exceptions (timeouts, parse errors, runtime errors, etc.)
logger.error(f"Exception while using model '{model_name}': {e}; trying the next model.")
failed_models.add(model_name)
last_exception = e
continue
# Every model has been tried and failed
logger.error("All available models have been tried and failed.")
if raise_on_failure:
if last_exception:
raise RuntimeError("所有模型都请求失败") from last_exception
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
# 根据需要返回一个默认的错误结果
return None
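Stripped of the project's scheduler, client registry, and logging, the failover pattern removed by this revert reduces to a small self-contained loop. A sketch under those assumptions (all names here are illustrative):

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable, Optional

async def run_with_failover(
    candidates: Iterable[str],
    request: Callable[[str], Awaitable[Any]],
    raise_on_failure: bool = True,
) -> Any:
    """Try each candidate in order and return the first successful result."""
    last_exc: Optional[Exception] = None
    for name in candidates:
        try:
            return await request(name)
        except Exception as exc:  # failover must catch broadly
            last_exc = exc
    if raise_on_failure:
        raise RuntimeError("all candidates failed") from last_exc
    return None

async def demo() -> None:
    async def request(name: str) -> str:
        if name == "model-a":  # first model fails, second succeeds
            raise TimeoutError("model-a timed out")
        return f"{name}: ok"
    print(await run_with_failover(["model-a", "model-b"], request))

asyncio.run(demo())
```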
async def _execute_single_request(
self,
prompt: str,
@@ -364,67 +288,83 @@ class LLMRequest:
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
Execute a single text-generation request via the generic failover executor
Execute a single request, switching to the next available model in order when one fails
"""
failed_models = set()
last_exception: Optional[Exception] = None
async def request_logic(
model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""定义单次请求的具体逻辑"""
model_scheduler = self._model_scheduler(failed_models)
for model_info, api_provider, client in model_scheduler:
start_time = time.time()
model_name = model_info.name
logger.debug(f"正在尝试使用模型: {model_name}") # 你不许刷屏
# Check whether anti-truncation is enabled
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
processed_prompt += self.anti_truncation_instruction
logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
# Empty-reply/truncation retry logic for the current model
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
is_empty_reply = False
is_truncated = False
while empty_retry_count <= max_empty_retry:
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = not tool_calls and (not content or content.strip() == "")
is_truncated = False
try:
# Check whether anti-truncation is enabled
# Check whether anti-truncation is enabled for this model
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
processed_prompt = prompt
if use_anti_truncation:
if content.endswith(self.end_marker):
content = content[: -len(self.end_marker)].strip()
else:
is_truncated = True
processed_prompt += self.anti_truncation_instruction
logger.info(f"模型 '{model_name}' (任务: '{self.task_name}') 已启用反截断功能。")
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
message_builder = MessageBuilder()
message_builder.add_text_content(processed_prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
# Empty-reply/truncation retry logic for the current model
empty_retry_count = 0
max_empty_retry = api_provider.max_retry
empty_retry_interval = api_provider.retry_interval
while empty_retry_count <= max_empty_retry:
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
tool_options=tool_built,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
is_empty_reply = not tool_calls and (not content or content.strip() == "")
is_truncated = False
if use_anti_truncation:
if content.endswith(self.end_marker):
content = content[: -len(self.end_marker)].strip()
else:
is_truncated = True
if is_empty_reply or is_truncated:
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(
f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..."
)
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue # 继续使用当前模型重试
else:
# 当前模型重试次数用尽,跳出内层循环,触发外层循环切换模型
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_name}' 经过 {max_empty_retry} 次重试后仍然是{reason}的回复。")
raise RuntimeError(f"模型 '{model_name}' 达到最大空回复/截断重试次数")
if not is_empty_reply and not is_truncated:
# Response obtained successfully
if usage := response.usage:
await llm_usage_recorder.record_usage_to_database(
@@ -441,115 +381,115 @@ class LLMRequest:
raise RuntimeError("生成空回复")
content = "生成的响应为空"
logger.debug(f"模型 '{model_name}' 成功生成回复。") # 你也不许刷屏
return content, (reasoning_content, model_name, tool_calls)
# Reaching this point means an empty or truncated reply; retry
empty_retry_count += 1
if empty_retry_count <= max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(
f"模型 '{model_name}' 检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成..."
)
if empty_retry_interval > 0:
await asyncio.sleep(empty_retry_interval)
continue # keep retrying with the current model
except RespNotOkException as e:
if e.status_code in [401, 403]:
logger.error(f"模型 '{model_name}' 遇到认证/权限错误 (Code: {e.status_code}),将尝试下一个模型。")
failed_models.add(model_name)
last_exception = e
continue # switch to the next model
else:
logger.error(f"模型 '{model_name}' 请求失败HTTP状态码: {e.status_code}")
if raise_when_empty:
raise
# For other HTTP errors, raise directly instead of trying other models
return f"Request failed: {e}", ("", model_name, None)
# If the loop ends here, the retry budget is exhausted
reason = "empty reply" if is_empty_reply else "truncation"
logger.error(f"Model '{model_name}' still returned {reason} after {max_empty_retry} retries")
raise RuntimeError(f"Model '{model_name}' reached the maximum empty-reply/truncation retry count")
except RuntimeError as e:
# Catch all retry failures (empty replies and network problems alike)
logger.error(f"Model '{model_name}' still failed after all retries: {e}; trying the next model")
failed_models.add(model_name)
last_exception = e
continue # 切换到下一个模型
# Invoke the generic failover executor
result = await self._execute_with_failover(
request_callable=request_logic, raise_on_failure=raise_when_empty
)
except Exception as e:
logger.error(f"使用模型 '{model_name}' 时发生未知异常: {e}")
failed_models.add(model_name)
last_exception = e
continue # switch to the next model
if result:
return result
# Every model has been tried and failed
logger.error("All available models have been tried and failed.")
if raise_when_empty:
if last_exception:
raise RuntimeError("所有模型都请求失败") from last_exception
raise RuntimeError("所有模型都请求失败,且没有具体的异常信息")
# 如果所有模型都失败了,并且不抛出异常,返回一个默认的错误信息
return "所有模型都请求失败", ("", "unknown", None)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量(已集成故障转移)
"""获取嵌入向量
Args:
embedding_input (str): the text to embed
Returns:
(Tuple[List[float], str]): (embedding vector, name of the model used)
"""
# No message body needed; use the input text directly
start_time = time.time()
model_info, api_provider, client = self._select_model()
async def request_logic(
model_info: ModelInfo, api_provider: APIProvider, client: BaseClient
) -> Tuple[List[float], str]:
"""定义单次请求的具体逻辑"""
start_time = time.time()
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.EMBEDDING,
model_info=model_info,
embedding_input=embedding_input,
)
embedding = response.embedding
if not embedding:
raise RuntimeError(f"模型 '{model_info.name}'未能返回 embedding。")
if usage := response.usage:
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
time_cost=time.time() - start_time,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings",
)
return embedding, model_info.name
result = await self._execute_with_failover(request_callable=request_logic, raise_on_failure=True)
if result:
return result
# Theoretically unreachable: raise_on_failure=True raises instead of returning
raise RuntimeError("Failed to get embedding; all models were tried and failed.")
def _model_scheduler(
self, failed_models: set | None = None
) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
"""
A smart model scheduler: sorts by real-time load, yields models in that order, and skips models that have already failed.
"""
# sourcery skip: class-extract-method
if failed_models is None:
failed_models = set()
# 1. Filter to the available models that have not failed
available_models = [name for name in self.model_for_task.model_list if name not in failed_models]
# 2. Sort the available models by the load-balancing key
# key: total_tokens + penalty * 300 + usage_penalty * 1000
sorted_models = sorted(
available_models,
key=lambda name: self.model_usage[name][0]
+ self.model_usage[name][1] * 300
+ self.model_usage[name][2] * 1000,
# Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.EMBEDDING,
model_info=model_info,
embedding_input=embedding_input,
)
if not sorted_models:
logger.warning("所有模型都已失败或不可用,调度器无法提供任何模型。")
return
embedding = response.embedding
logger.debug(f"模型调度顺序: {', '.join(sorted_models)}")
if usage := response.usage:
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
time_cost=time.time() - start_time,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings",
)
if not embedding:
raise RuntimeError("获取embedding失败")
return embedding, model_info.name
def _model_scheduler(self, failed_models: set) -> Generator[Tuple[ModelInfo, APIProvider, BaseClient], None, None]:
"""
A model scheduler that yields models in order, skipping ones that have failed.
"""
for model_name in self.model_for_task.model_list:
if model_name in failed_models:
continue
# 3. Yield model info in optimal order
for model_name in sorted_models:
model_info = model_config.get_model_info(model_name)
api_provider = model_config.get_provider(model_info.api_provider)
force_new_client = self.request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
yield model_info, api_provider, client
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
Select a model based on total tokens and penalty values (load balancing)
"""
least_used_model_name = min(
self.model_usage,
key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
# For embedding tasks, force a new client instance to avoid event-loop issues
force_new_client = self.request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"选择请求模型: {model_info.name}")
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # bump the usage penalty to discourage back-to-back use
return model_info, api_provider, client
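Both the scheduler's sort and `_select_model`'s `min` use the same weighted key, `total_tokens + penalty * 300 + usage_penalty * 1000`, over `(total_tokens, penalty, usage_penalty)` tuples. A worked sketch of that selection with made-up usage numbers:

```python
# model_usage maps model name -> (total_tokens, penalty, usage_penalty)
model_usage = {
    "model-a": (50_000, 0, 0),  # score 50_000
    "model-b": (10_000, 1, 0),  # score 10_000 + 1 * 300  = 10_300 (lowest)
    "model-c": (12_000, 0, 1),  # score 12_000 + 1 * 1000 = 13_000
}

def load_score(usage: tuple[int, int, int]) -> int:
    total_tokens, penalty, usage_penalty = usage
    return total_tokens + penalty * 300 + usage_penalty * 1000

chosen = min(model_usage, key=lambda name: load_score(model_usage[name]))
assert chosen == "model-b"
```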
async def _execute_request(
self,
api_provider: APIProvider,
@@ -573,73 +513,63 @@ class LLMRequest:
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None
# Bump the usage penalty to mark this model as currently in use
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)
try:
while retry_remain > 0:
try:
if request_type == RequestType.RESPONSE:
assert message_list is not None, "message_list cannot be None for response requests"
return await client.get_response(
model_info=model_info,
message_list=(compressed_messages or message_list),
tool_options=tool_options,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
temperature=self.model_for_task.temperature if temperature is None else temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.EMBEDDING:
assert embedding_input, "embedding_input cannot be empty for embedding requests"
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.AUDIO:
assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
return await client.get_audio_transcriptions(
model_info=model_info,
audio_base64=audio_base64,
extra_params=model_info.extra_params,
)
except Exception as e:
logger.debug(f"请求失败: {str(e)}")
# 处理异常
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
wait_interval, compressed_messages = self._default_exception_handler(
e,
self.task_name,
while retry_remain > 0:
try:
if request_type == RequestType.RESPONSE:
assert message_list is not None, "message_list cannot be None for response requests"
return await client.get_response(
model_info=model_info,
api_provider=api_provider,
remain_try=retry_remain,
retry_interval=api_provider.retry_interval,
messages=(message_list, compressed_messages is not None) if message_list else None,
message_list=(compressed_messages or message_list),
tool_options=tool_options,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
temperature=self.model_for_task.temperature if temperature is None else temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.EMBEDDING:
assert embedding_input, "embedding_input cannot be empty for embedding requests"
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.AUDIO:
assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
return await client.get_audio_transcriptions(
model_info=model_info,
audio_base64=audio_base64,
extra_params=model_info.extra_params,
)
except Exception as e:
logger.debug(f"请求失败: {str(e)}")
# 处理异常
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
if wait_interval == -1:
retry_remain = 0  # stop retrying
elif wait_interval > 0:
logger.info(f"等待 {wait_interval} 秒后重试...")
await asyncio.sleep(wait_interval)
finally:
# In finally, so an exception path cannot cause an infinite loop
retry_remain -= 1
wait_interval, compressed_messages = self._default_exception_handler(
e,
self.task_name,
model_info=model_info,
api_provider=api_provider,
remain_try=retry_remain,
retry_interval=api_provider.retry_interval,
messages=(message_list, compressed_messages is not None) if message_list else None,
)
# Whether the request succeeded or all retries failed, it ends up here
logger.error(f"Model '{model_info.name}' request failed after reaching the maximum retry count {api_provider.max_retry}")
raise RuntimeError("Request failed; maximum retry count reached")
finally:
# Whether the request succeeds or fails, undo the usage-penalty bump at the end
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)
if wait_interval == -1:
retry_remain = 0  # stop retrying
elif wait_interval > 0:
logger.info(f"等待 {wait_interval} 秒后重试...")
await asyncio.sleep(wait_interval)
finally:
# In finally, so an exception path cannot cause an infinite loop
retry_remain -= 1
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)  # done; decrement the usage penalty
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
def _default_exception_handler(
self,