Optimize code formatting and exception handling
- Fix the exception handling chain by using the `from` syntax to preserve the original exception
- Format the code to match project conventions
- Reorder module imports

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
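For reference, here is a minimal sketch of the exception-chaining pattern this commit applies. It is illustrative only: the helper name `call_api` and its signature are assumptions, not code from this repository; only the `raise ... from e` idiom mirrors the change in the diff below.

```python
import aiohttp


async def call_api(session: aiohttp.ClientSession, url: str, payload: dict) -> dict:
    """Hypothetical helper showing how `from e` preserves the original exception."""
    try:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            return await response.json()
    except aiohttp.ClientResponseError as e:
        # `from e` stores the original error in __cause__, so the traceback shows
        # both the wrapper RuntimeError and the underlying HTTP failure.
        raise RuntimeError(f"API request failed: status {e.status}, {e.message}") from e
```

Without the `from` clause Python only keeps the original error as implicit context; with it, the `ClientResponseError` is recorded as the direct cause (`__cause__`), which is what the commit message refers to.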
@@ -26,11 +26,11 @@ class LLM_request:
         "o1-mini",
         "o1-preview",
         "o1-2024-12-17",
-        "o1-preview-2024-09-12",
+        "o1-preview-2024-09-12",
         "o3-mini-2025-01-31",
         "o1-mini-2024-09-12",
     ]

     def __init__(self, model, **kwargs):
         # 将大写的配置键转换为小写并从config中获取实际值
         try:
@@ -52,9 +52,6 @@ class LLM_request:
         # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
         self.request_type = kwargs.pop("request_type", "default")

-
-
-
     @staticmethod
     def _init_database():
         """初始化数据库集合"""
@@ -180,7 +177,7 @@ class LLM_request:
         api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
         # 判断是否为流式
         stream_mode = self.params.get("stream", False)
-        logger_msg = "进入流式输出模式," if stream_mode else ""
+        # logger_msg = "进入流式输出模式," if stream_mode else ""
         # logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
         # logger.info(f"使用模型: {self.model_name}")

@@ -229,7 +226,8 @@ class LLM_request:
                 error_message = error_obj.get("message")
                 error_status = error_obj.get("status")
                 logger.error(
-                    f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
+                    f"服务器错误详情: 代码={error_code}, 状态={error_status}, "
+                    f"消息={error_message}"
                 )
             elif isinstance(error_json, dict) and "error" in error_json:
                 # 处理单个错误对象的情况
@@ -282,7 +280,7 @@ class LLM_request:
                 flag_delta_content_finished = False
                 accumulated_content = ""
                 usage = None # 初始化usage变量,避免未定义错误

                 async for line_bytes in response.content:
                     line = line_bytes.decode("utf-8").strip()
                     if not line:
@@ -294,7 +292,7 @@ class LLM_request:
                     try:
                         chunk = json.loads(data_str)
                         if flag_delta_content_finished:
-                            chunk_usage = chunk.get("usage",None)
+                            chunk_usage = chunk.get("usage", None)
                             if chunk_usage:
                                 usage = chunk_usage # 获取token用量
                         else:
@@ -306,7 +304,7 @@ class LLM_request:
                             # 检测流式输出文本是否结束
                             finish_reason = chunk["choices"][0].get("finish_reason")
                             if finish_reason == "stop":
-                                chunk_usage = chunk.get("usage",None)
+                                chunk_usage = chunk.get("usage", None)
                                 if chunk_usage:
                                     usage = chunk_usage
                                 break
@@ -355,12 +353,16 @@ class LLM_request:
                             if "error" in error_item and isinstance(error_item["error"], dict):
                                 error_obj = error_item["error"]
                                 logger.error(
-                                    f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
+                                    f"服务器错误详情: 代码={error_obj.get('code')}, "
+                                    f"状态={error_obj.get('status')}, "
+                                    f"消息={error_obj.get('message')}"
                                 )
                     elif isinstance(error_json, dict) and "error" in error_json:
                         error_obj = error_json.get("error", {})
                         logger.error(
-                            f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
+                            f"服务器错误详情: 代码={error_obj.get('code')}, "
+                            f"状态={error_obj.get('status')}, "
+                            f"消息={error_obj.get('message')}"
                         )
                     else:
                         logger.error(f"服务器错误响应: {error_json}")
@@ -373,15 +375,22 @@ class LLM_request:
                 else:
                     logger.critical(f"HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
                     # 安全地检查和记录请求详情
-                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                    if (
+                        image_base64
+                        and payload
+                        and isinstance(payload, dict)
+                        and "messages" in payload
+                        and len(payload["messages"]) > 0
+                    ):
                         if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
                             content = payload["messages"][0]["content"]
                             if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
                                 payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
+                                    f"{image_base64[:10]}...{image_base64[-10:]}"
                                 )
                         logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
-                    raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}")
+                    raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}") from e
             except Exception as e:
                 if retry < policy["max_retries"] - 1:
                     wait_time = policy["base_wait"] * (2**retry)
@@ -390,15 +399,22 @@ class LLM_request:
                 else:
                     logger.critical(f"请求失败: {str(e)}")
                     # 安全地检查和记录请求详情
-                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                    if (
+                        image_base64
+                        and payload
+                        and isinstance(payload, dict)
+                        and "messages" in payload
+                        and len(payload["messages"]) > 0
+                    ):
                         if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
                             content = payload["messages"][0]["content"]
                             if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
                                 payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
+                                    f"{image_base64[:10]}...{image_base64[-10:]}"
                                 )
                         logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
-                    raise RuntimeError(f"API请求失败: {str(e)}")
+                    raise RuntimeError(f"API请求失败: {str(e)}") from e

         logger.error("达到最大重试次数,请求仍然失败")
         raise RuntimeError("达到最大重试次数,API请求仍然失败")
@@ -411,7 +427,7 @@ class LLM_request:
         """
         # 复制一份参数,避免直接修改原始数据
         new_params = dict(params)

         if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
             # 删除 'temperature' 参数(如果存在)
             new_params.pop("temperature", None)
@@ -479,7 +495,7 @@ class LLM_request:
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             user_id=user_id,
-            request_type = request_type if request_type is not None else self.request_type,
+            request_type=request_type if request_type is not None else self.request_type,
             endpoint=endpoint,
         )

@@ -546,13 +562,14 @@ class LLM_request:
             list: embedding向量,如果失败则返回None
         """

-        if(len(text) < 1):
+        if len(text) < 1:
             logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
             return None

         def embedding_handler(result):
             """处理响应"""
             if "data" in result and len(result["data"]) > 0:
-                # 提取 token 使用信息
+                # 提取 token 使用信息
                 usage = result.get("usage", {})
                 if usage:
                     prompt_tokens = usage.get("prompt_tokens", 0)
@@ -565,7 +582,7 @@ class LLM_request:
                         total_tokens=total_tokens,
                         user_id="system", # 可以根据需要修改 user_id
                         request_type="embedding", # 请求类型为 embedding
-                        endpoint="/embeddings" # API 端点
+                        endpoint="/embeddings", # API 端点
                     )
-                return result["data"][0].get("embedding", None)
+                return result["data"][0].get("embedding", None)

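As an editorial aside, the streaming hunks above follow a common SSE pattern: accumulate the `delta` content from each `data:` line and pick up the final `usage` object once `finish_reason == "stop"` (or from a trailing usage-only chunk). Below is a self-contained sketch of that pattern, assuming an OpenAI-style chat-completions response schema; the helper name `stream_chat` is an assumption, not code from this commit.

```python
import json

import aiohttp


async def stream_chat(session: aiohttp.ClientSession, url: str, payload: dict):
    """Assumed helper: return (accumulated content, usage dict or None) from an SSE stream."""
    content, usage = "", None
    async with session.post(url, json=payload) as response:
        response.raise_for_status()
        async for line_bytes in response.content:
            line = line_bytes.decode("utf-8").strip()
            if not line or not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                break
            chunk = json.loads(data_str)
            choices = chunk.get("choices", [])
            if choices:
                delta = choices[0].get("delta", {})
                content += delta.get("content") or ""
                if choices[0].get("finish_reason") == "stop":
                    usage = chunk.get("usage")  # some providers attach usage here
            elif chunk.get("usage"):
                usage = chunk["usage"]  # others send a final usage-only chunk
    return content, usage
```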