Split _execute_request

UnCLAS-Prommer
2025-04-19 22:48:30 +08:00
parent 41c3ab7e3b
commit 46a5b01a13
2 changed files with 452 additions and 1236 deletions


@@ -2,9 +2,11 @@ import asyncio
import json
import re
from datetime import datetime
from typing import Tuple, Union, Dict, Any
import aiohttp
from aiohttp.client import ClientResponse
from src.common.logger import get_module_logger
import base64
from PIL import Image
@@ -16,6 +18,53 @@ from ...config.config import global_config
logger = get_module_logger("model_utils")
class PayLoadTooLargeError(Exception):
"""Custom exception for oversized request bodies"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return "Request body too large; try compressing the image or reducing the input."
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
def __init__(self, message: str, response: ClientResponse):
super().__init__(message)
self.message = message
self.response = response
def __str__(self):
return self.message
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
# Mapping of common error codes to messages
error_code_mapping = {
400: "Invalid request parameters",
401: "API key error, authentication failed; please check the settings in /config/bot_config.toml and .env",
402: "Insufficient account balance",
403: "Real-name verification required, or insufficient balance",
404: "Not Found",
429: "Too many requests, please try again later",
500: "Internal server error",
503: "Server overloaded",
}
class LLMRequest:
# List of models whose parameters need transformation; a class variable to avoid duplication
MODELS_NEEDING_TRANSFORMATION = [
@@ -28,7 +77,7 @@ class LLMRequest:
"o1-mini-2024-09-12", "o1-mini-2024-09-12",
] ]
def __init__(self, model, **kwargs): def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值 # 将大写的配置键转换为小写并从config中获取实际值
try: try:
self.api_key = os.environ[model["key"]] self.api_key = os.environ[model["key"]]
@@ -37,7 +86,7 @@ class LLMRequest:
logger.error(f"原始 model dict 信息:{model}") logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"] self.model_name: str = model["name"]
self.params = kwargs self.params = kwargs
self.stream = model.get("stream", False) self.stream = model.get("stream", False)
@@ -123,6 +172,7 @@ class LLMRequest:
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
'''
async def _execute_request(
self,
endpoint: str,
@@ -509,6 +559,405 @@ class LLMRequest:
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败") logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败") raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败")
'''
async def _prepare_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
) -> Dict[str, Any]:
"""配置请求参数
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
request_type: 请求类型
"""
# Merge the retry policy with the defaults
default_retry = {
"max_retries": 3,
"base_wait": 10,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403],
}
policy = {**default_retry, **(retry_policy or {})}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
stream_mode = self.stream
# Build the request body
if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None:
payload = await self._build_payload(prompt)
if stream_mode:
payload["stream"] = stream_mode
return {
"policy": policy,
"payload": payload,
"api_url": api_url,
"stream_mode": stream_mode,
"image_base64": image_base64, # 保留必要的exception处理所需的原始数据
"image_format": image_format,
"prompt": prompt,
}
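# A minimal illustration of the policy merge above (hypothetical override):
# passing retry_policy={"max_retries": 5} yields
#   policy == {"max_retries": 5, "base_wait": 10,
#              "retry_codes": [429, 413, 500, 503],
#              "abort_codes": [400, 401, 402, 403]}
# i.e. caller-supplied keys override the defaults and all other keys are kept.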
async def _execute_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
user_id: str = "system",
request_type: str = None,
):
"""统一请求执行入口
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
response_handler: 自定义响应处理器
user_id: 用户ID
request_type: 请求类型
"""
# Assemble the request configuration
request_content = await self._prepare_request(
endpoint, prompt, image_base64, image_format, payload, retry_policy
)
if request_type is None:
request_type = self.request_type
retry = 0
while retry < request_content["policy"]["max_retries"]:
try:
# Manage the session with a context manager
headers = await self._build_headers()
# Apparently required for OpenAI streaming; adding it has no effect on Alibaba Cloud's qwq-plus
if request_content["stream_mode"]:
headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession() as session:
async with session.post(
request_content["api_url"], headers=headers, json=request_content["payload"]
) as response:
handled_result = await self._handle_response(
response, request_content, retry, response_handler, user_id, request_type, endpoint
)
return handled_result
except Exception as e:
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
retry += 1 + count_delta  # a downgrade returns -1 and therefore does not consume a retry
if handled_payload:
# Downgrade succeeded; rebuild the request body
request_content["payload"] = handled_payload
continue
logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数API请求仍然失败")
async def _handle_response(
self,
response: ClientResponse,
request_content: Dict[str, Any],
retry_count: int,
response_handler: callable,
user_id,
request_type,
endpoint,
) -> Union[Dict[str, Any], None]:
policy = request_content["policy"]
stream_mode = request_content["stream_mode"]
if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
await self._handle_error_response(response, retry_count, policy)
return
response.raise_for_status()
result = {}
if stream_mode:
# Convert the streamed output into a non-streaming result
result = await self._handle_stream_output(response)
else:
result = await response.json()
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
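# Sketch of a custom response_handler (hypothetical): it receives the parsed
# JSON body and may return any shape, bypassing _default_response_handler:
#
#   def raw_content_handler(result: dict) -> str:
#       return result["choices"][0]["message"]["content"]
#
#   text = await self._execute_request(
#       "chat/completions", prompt="hi", response_handler=raw_content_handler
#   )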
async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
flag_delta_content_finished = False
accumulated_content = ""
usage = None  # initialize usage to avoid an undefined-variable error
reasoning_content = ""
content = ""
async for line_bytes in response.content:
try:
line = line_bytes.decode("utf-8").strip()
if not line:
continue
if line.startswith("data:"):
data_str = line[5:].strip()
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if flag_delta_content_finished:
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage  # capture token usage
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
# Detect whether the streamed text has finished
finish_reason = chunk["choices"][0].get("finish_reason")
if delta.get("reasoning_content", None):
reasoning_content += delta["reasoning_content"]
if finish_reason == "stop":
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage
break
# Some platforms do not report token usage until after the text has finished, so one more chunk must be read
flag_delta_content_finished = True
except Exception as e:
logger.exception(f"Model {self.model_name} error while parsing streamed output: {str(e)}")
except Exception as e:
if isinstance(e, GeneratorExit):
log_content = f"Model {self.model_name} streamed output was interrupted; cleaning up resources..."
else:
log_content = f"Model {self.model_name} error while processing streamed output: {str(e)}"
logger.warning(log_content)
# Make sure resources are cleaned up properly
try:
await response.release()
except Exception as cleanup_error:
logger.error(f"清理资源时发生错误: {cleanup_error}")
# 返回已经累积的内容
content = accumulated_content
if not content:
content = accumulated_content
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
result = {
"choices": [
{
"message": {
"content": content,
"reasoning_content": reasoning_content,
# Streamed output may contain no tool calls; no tool_calls field is needed here
}
}
],
"usage": usage,
}
return result
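# Illustrative SSE frames this parser consumes (values are made up):
#
#   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
#   data: {"choices":[{"delta":{"content":"lo"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2}}
#   data: [DONE]
#
# The deltas accumulate into "Hello"; usage is taken from the final chunk, or,
# on platforms that send it separately, from one extra chunk after "stop".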
async def _handle_error_response(
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
) -> None:
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry_count)
logger.warning(f"Model {self.model_name} error code: {response.status}, waiting {wait_time}s before retrying")
if response.status == 413:
logger.warning("Request body too large, attempting compression...")
raise PayLoadTooLargeError("Request body too large")
elif response.status in [500, 503]:
logger.error(
f"Model {self.model_name} error code: {response.status} - {error_code_mapping.get(response.status)}"
)
raise RuntimeError("Server overloaded, model failed to recover QAQ")
else:
logger.warning(f"Model {self.model_name} rate limited (429), waiting {wait_time}s before retrying...")
raise RuntimeError("Rate limited (429)")
elif response.status in policy["abort_codes"]:
if response.status != 403:
raise RequestAbortException("Request error, aborting processing", response)
else:
raise PermissionDeniedException("Model access forbidden")
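# Backoff schedule implied above with the default base_wait of 10s
# (wait_time = base_wait * 2**retry_count): retry 0 -> 10s, 1 -> 20s, 2 -> 40s.
# Note the wait is only logged here; the raised exception is what triggers the
# actual sleep (or payload rebuild) in _handle_exception.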
async def _handle_exception(
self, exception, retry_count: int, request_content: Dict[str, Any]
) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
policy = request_content["policy"]
payload = request_content["payload"]
keep_request = False
# Compute wait_time unconditionally so the downgrade branch below cannot hit an undefined variable
wait_time = policy["base_wait"] * (2**retry_count)
if retry_count < policy["max_retries"] - 1:
keep_request = True
if isinstance(exception, RequestAbortException):
response = exception.response
logger.error(
f"Model {self.model_name} error code: {response.status} - {error_code_mapping.get(response.status)}"
)
# Try to fetch and log the detailed error information returned by the server
try:
error_json = await response.json()
if error_json and isinstance(error_json, list) and len(error_json) > 0:
# Handle the case of multiple errors
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj: dict = error_item["error"]
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"Server error details: code={error_code}, status={error_status}, message={error_message}"
)
elif isinstance(error_json, dict) and "error" in error_json:
# Handle the case of a single error object
error_obj = error_json.get("error", {})
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(f"Server error details: code={error_code}, status={error_status}, message={error_message}")
else:
# Log the raw error response
logger.error(f"Server error response: {error_json}")
except Exception as e:
logger.warning(f"Could not parse the server error response: {str(e)}")
raise RuntimeError(f"Request rejected: {error_code_mapping.get(response.status)}")
elif isinstance(exception, PermissionDeniedException):
# Downgrade handling applies only to SiliconFlow's V3 and R1 models
if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
old_model_name = self.model_name
self.model_name = self.model_name[4:]  # strip the "Pro/" prefix
logger.warning(f"403 error detected; model downgraded from {old_model_name} to {self.model_name}")
# Update the global config accordingly
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
logger.warning(f"Temporarily downgraded the llm_normal model in the global config to {self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
logger.warning(f"Temporarily downgraded the llm_reasoning model in the global config to {self.model_name}")
if payload and "model" in payload:
payload["model"] = self.model_name
await asyncio.sleep(wait_time)
return payload, -1
raise RuntimeError(f"Request rejected: {error_code_mapping.get(403)}")
elif isinstance(exception, PayLoadTooLargeError):
if keep_request:
image_base64 = request_content["image_base64"]
compressed_image_base64 = compress_base64_image_by_scale(image_base64)
new_payload = await self._build_payload(
request_content["prompt"], compressed_image_base64, request_content["image_format"]
)
return new_payload, 0
else:
return None, 0
# ClientResponseError subclasses ClientError, so exclude it here to let the dedicated branch below handle it
elif isinstance(exception, (aiohttp.ClientError, asyncio.TimeoutError)) and not isinstance(
exception, aiohttp.ClientResponseError
):
if keep_request:
logger.error(f"Model {self.model_name} network error, waiting {wait_time}s before retrying... error: {str(exception)}")
await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(f"Model {self.model_name} network error reached the maximum retry count: {str(exception)}")
raise RuntimeError(f"Network request failed: {str(exception)}")
elif isinstance(exception, aiohttp.ClientResponseError):
# Handle aiohttp response errors whose status is not covered by the policy
if keep_request:
logger.error(
f"Model {self.model_name} HTTP response error, waiting {wait_time}s before retrying... status: {exception.status}, error: {exception.message}"
)
try:
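# Caution: aiohttp.ClientResponseError does not expose a .response attribute,
# so the access below may raise AttributeError and fall through to the generic
# parse-error handler instead of logging the server's error body.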
error_text = await exception.response.text()
error_json = json.loads(error_text)
if isinstance(error_json, list) and len(error_json) > 0:
# Handle the case of multiple errors
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
logger.error(
f"Model {self.model_name} server error details: code={error_obj.get('code')}, "
f"status={error_obj.get('status')}, "
f"message={error_obj.get('message')}"
)
elif isinstance(error_json, dict) and "error" in error_json:
error_obj = error_json.get("error", {})
logger.error(
f"Model {self.model_name} server error details: code={error_obj.get('code')}, "
f"status={error_obj.get('status')}, "
f"message={error_obj.get('message')}"
)
else:
logger.error(f"Model {self.model_name} server error response: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err:
logger.warning(
f"Model {self.model_name} response is not valid JSON: {str(json_err)}, raw content: {error_text[:200]}"
)
except Exception as parse_err:
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(
f"Model {self.model_name} HTTP response error reached the maximum retry count: status: {exception.status}, error: {exception.message}"
)
# Safely inspect and log the request details
handled_payload = await self._safely_record(request_content, payload)
logger.critical(f"Request headers: {await self._build_headers(no_key=True)} Request body: {handled_payload}")
raise RuntimeError(
f"Model {self.model_name} API request failed: status {exception.status}, {exception.message}"
)
else:
if keep_request:
logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
await asyncio.sleep(wait_time)
return None, 0
else:
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
async def _safely_record(self, request_content: Dict[str, Any], payload: Dict[str, Any]):
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload
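# Illustrative effect of the truncation above (hypothetical base64 data): a
# logged image URL becomes "data:image/jpeg;base64,ABCDEFGHIJ...QRSTUVWXYZ",
# keeping the request body readable without dumping the full image.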
async def _transform_parameters(self, params: dict) -> dict:
"""

File diff suppressed because it is too large