优化代码格式和异常处理

- 修复异常处理链,使用from语法保留原始异常
- 格式化代码以符合项目规范
- 优化导入模块的顺序

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
春河晴
2025-03-19 20:27:34 +09:00
parent a829dfdb77
commit fdc098d0db
52 changed files with 3156 additions and 2778 deletions

View File

@@ -26,11 +26,11 @@ class LLM_request:
"o1-mini",
"o1-preview",
"o1-2024-12-17",
"o1-preview-2024-09-12",
"o1-preview-2024-09-12",
"o3-mini-2025-01-31",
"o1-mini-2024-09-12",
]
def __init__(self, model, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
@@ -52,9 +52,6 @@ class LLM_request:
# 从 kwargs 中提取 request_type如果没有提供则默认为 "default"
self.request_type = kwargs.pop("request_type", "default")
@staticmethod
def _init_database():
"""初始化数据库集合"""
@@ -180,7 +177,7 @@ class LLM_request:
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
# 判断是否为流式
stream_mode = self.params.get("stream", False)
logger_msg = "进入流式输出模式," if stream_mode else ""
# logger_msg = "进入流式输出模式," if stream_mode else ""
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
# logger.info(f"使用模型: {self.model_name}")
@@ -229,7 +226,8 @@ class LLM_request:
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
f"服务器错误详情: 代码={error_code}, 状态={error_status}, "
f"消息={error_message}"
)
elif isinstance(error_json, dict) and "error" in error_json:
# 处理单个错误对象的情况
@@ -282,7 +280,7 @@ class LLM_request:
flag_delta_content_finished = False
accumulated_content = ""
usage = None # 初始化usage变量避免未定义错误
async for line_bytes in response.content:
line = line_bytes.decode("utf-8").strip()
if not line:
@@ -294,7 +292,7 @@ class LLM_request:
try:
chunk = json.loads(data_str)
if flag_delta_content_finished:
chunk_usage = chunk.get("usage",None)
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage # 获取token用量
else:
@@ -306,7 +304,7 @@ class LLM_request:
# 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason")
if finish_reason == "stop":
chunk_usage = chunk.get("usage",None)
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage
break
@@ -355,12 +353,16 @@ class LLM_request:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
logger.error(
f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
f"服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
elif isinstance(error_json, dict) and "error" in error_json:
error_obj = error_json.get("error", {})
logger.error(
f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
f"服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
else:
logger.error(f"服务器错误响应: {error_json}")
@@ -373,15 +375,22 @@ class LLM_request:
else:
logger.critical(f"HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
# 安全地检查和记录请求详情
if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}")
raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}") from e
except Exception as e:
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
@@ -390,15 +399,22 @@ class LLM_request:
else:
logger.critical(f"请求失败: {str(e)}")
# 安全地检查和记录请求详情
if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}")
raise RuntimeError(f"API请求失败: {str(e)}") from e
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
@@ -411,7 +427,7 @@ class LLM_request:
"""
# 复制一份参数,避免直接修改原始数据
new_params = dict(params)
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
# 删除 'temperature' 参数(如果存在)
new_params.pop("temperature", None)
@@ -479,7 +495,7 @@ class LLM_request:
completion_tokens=completion_tokens,
total_tokens=total_tokens,
user_id=user_id,
request_type = request_type if request_type is not None else self.request_type,
request_type=request_type if request_type is not None else self.request_type,
endpoint=endpoint,
)
@@ -546,13 +562,14 @@ class LLM_request:
list: embedding向量如果失败则返回None
"""
if(len(text) < 1):
if len(text) < 1:
logger.debug("该消息没有长度不再发送获取embedding向量的请求")
return None
def embedding_handler(result):
"""处理响应"""
if "data" in result and len(result["data"]) > 0:
# 提取 token 使用信息
# 提取 token 使用信息
usage = result.get("usage", {})
if usage:
prompt_tokens = usage.get("prompt_tokens", 0)
@@ -565,7 +582,7 @@ class LLM_request:
total_tokens=total_tokens,
user_id="system", # 可以根据需要修改 user_id
request_type="embedding", # 请求类型为 embedding
endpoint="/embeddings" # API 端点
endpoint="/embeddings", # API 端点
)
return result["data"][0].get("embedding", None)
return result["data"][0].get("embedding", None)