fix: Fix potential issues in utils_model.py

- Extract the duplicated model list into a class variable
- Fix an uninitialized-variable bug in the streaming path
- Make error-response handling more defensive
- Fix a type annotation
- Simplify a duplicated conditional
- Make _init_database a static method

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
春河晴
2025-03-14 14:28:56 +09:00
parent be7997e1b7
commit e17f3276a4


@@ -18,6 +18,17 @@ config = driver.config
 
 
 class LLM_request:
+    # Models whose parameters need transformation, kept as a class variable to avoid duplication
+    MODELS_NEEDING_TRANSFORMATION = [
+        "o3-mini",
+        "o1-mini",
+        "o1-preview",
+        "o1-2024-12-17",
+        "o1-preview-2024-09-12",
+        "o3-mini-2025-01-31",
+        "o1-mini-2024-09-12",
+    ]
+
     def __init__(self, model, **kwargs):
         # Convert upper-case config keys to lower case and read the actual values from config
         try:
@@ -36,7 +47,8 @@ class LLM_request:
         # Get the database instance
         self._init_database()
 
-    def _init_database(self):
+    @staticmethod
+    def _init_database():
        """Initialize the database collections"""
         try:
             # Create indexes for the llm_usage collection
@@ -44,8 +56,8 @@ class LLM_request:
             db.llm_usage.create_index([("model_name", 1)])
             db.llm_usage.create_index([("user_id", 1)])
             db.llm_usage.create_index([("request_type", 1)])
-        except Exception:
-            logger.error("创建数据库索引失败")
+        except Exception as e:
+            logger.error(f"创建数据库索引失败: {str(e)}")
 
     def _record_usage(
         self,
@@ -85,8 +97,8 @@ class LLM_request:
                 f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
                 f"总计: {total_tokens}"
             )
-        except Exception:
-            logger.error("记录token使用情况失败")
+        except Exception as e:
+            logger.error(f"记录token使用情况失败: {str(e)}")
 
     def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
         """Calculate the cost of an API call
@@ -152,10 +164,8 @@ class LLM_request:
         api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
         # Check whether streaming mode is enabled
         stream_mode = self.params.get("stream", False)
-        if self.params.get("stream", False) is True:
-            logger.debug(f"进入流式输出模式,发送请求到URL: {api_url}")
-        else:
-            logger.debug(f"发送请求到URL: {api_url}")
+        logger_msg = "进入流式输出模式," if stream_mode else ""
+        logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
         logger.info(f"使用模型: {self.model_name}")
 
         # Build the request body
@@ -255,6 +265,8 @@ class LLM_request:
                 if stream_mode:
                     flag_delta_content_finished = False
                     accumulated_content = ""
+                    usage = None  # Initialize usage so it is never read while unbound
+
                     async for line_bytes in response.content:
                         line = line_bytes.decode("utf-8").strip()
                         if not line:
@@ -266,7 +278,9 @@ class LLM_request:
                         try:
                             chunk = json.loads(data_str)
                             if flag_delta_content_finished:
-                                usage = chunk.get("usage", None)  # fetch token usage
+                                chunk_usage = chunk.get("usage", None)
+                                if chunk_usage:
+                                    usage = chunk_usage  # fetch token usage
                             else:
                                 delta = chunk["choices"][0]["delta"]
                                 delta_content = delta.get("content")
@@ -276,14 +290,15 @@ class LLM_request:
                                 # Detect whether the streamed text has finished
                                 finish_reason = chunk["choices"][0].get("finish_reason")
                                 if finish_reason == "stop":
-                                    usage = chunk.get("usage", None)
-                                    if usage:
+                                    chunk_usage = chunk.get("usage", None)
+                                    if chunk_usage:
+                                        usage = chunk_usage
                                         break
                                     # Some platforms report token usage only after the text ends, so read one more chunk
                                     flag_delta_content_finished = True
-                        except Exception:
-                            logger.exception("解析流式输出错误")
+                        except Exception as e:
+                            logger.exception(f"解析流式输出错误: {str(e)}")
 
                     content = accumulated_content
                     reasoning_content = ""
                     think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
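
Note: the new `usage = None` initialization matters because a stream can end without ever delivering a usage chunk, in which case the old code read an unbound local after the loop. A minimal sketch of the failure mode, using made-up chunk data:

    # Made-up chunks: the stream ends without a "usage" payload.
    chunks = [{"choices": [{"delta": {"content": "hi"}, "finish_reason": "stop"}]}]

    def consume(chunks):
        accumulated = ""
        usage = None  # without this line, usage may never be bound
        for chunk in chunks:
            accumulated += chunk["choices"][0]["delta"].get("content", "")
            if chunk.get("usage"):
                usage = chunk["usage"]
        # Without the initialization, this raises UnboundLocalError when no usage chunk arrives.
        return accumulated, usage

    print(consume(chunks))  # ('hi', None)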
@@ -315,33 +330,40 @@ class LLM_request:
                     wait_time = policy["base_wait"] * (2**retry)
                     logger.error(f"HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
                     try:
-                        if hasattr(e, "history") and e.history and hasattr(e.history[0], "text"):
-                            error_text = await e.history[0].text()
-                            error_json = json.loads(error_text)
-                            if isinstance(error_json, list) and len(error_json) > 0:
-                                for error_item in error_json:
-                                    if "error" in error_item and isinstance(error_item["error"], dict):
-                                        error_obj = error_item["error"]
-                                        logger.error(
-                                            f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
-                                        )
-                            elif isinstance(error_json, dict) and "error" in error_json:
-                                error_obj = error_json.get("error", {})
-                                logger.error(
-                                    f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
-                                )
-                            else:
-                                logger.error(f"服务器错误响应: {error_json}")
-                    except Exception as parse_err:
+                        if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
+                            error_text = await e.response.text()
+                            try:
+                                error_json = json.loads(error_text)
+                                if isinstance(error_json, list) and len(error_json) > 0:
+                                    for error_item in error_json:
+                                        if "error" in error_item and isinstance(error_item["error"], dict):
+                                            error_obj = error_item["error"]
+                                            logger.error(
+                                                f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
+                                            )
+                                elif isinstance(error_json, dict) and "error" in error_json:
+                                    error_obj = error_json.get("error", {})
+                                    logger.error(
+                                        f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
+                                    )
+                                else:
+                                    logger.error(f"服务器错误响应: {error_json}")
+                            except (json.JSONDecodeError, TypeError) as json_err:
+                                logger.warning(f"响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
+                    except (AttributeError, TypeError, ValueError) as parse_err:
                         logger.warning(f"无法解析响应错误内容: {str(parse_err)}")
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
-                    if image_base64:
-                        payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                            f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
-                        )
+                    # Safely inspect the payload before logging the request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+                                payload["messages"][0]["content"][1]["image_url"]["url"] = (
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                )
                     logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
                     raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}")
             except Exception as e:
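
Note: the retry delay here is standard exponential backoff, wait_time = base_wait * 2**retry. A quick sketch of the schedule under an assumed policy of base_wait=10 and three retries (the real values come from the module's `policy` dict):

    policy = {"base_wait": 10, "max_retries": 3}  # assumed values for illustration

    for retry in range(policy["max_retries"]):
        wait_time = policy["base_wait"] * (2**retry)
        print(f"retry {retry}: wait {wait_time}s")  # 10s, 20s, 40s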
@@ -351,10 +373,14 @@ class LLM_request:
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"请求失败: {str(e)}")
-                    if image_base64:
-                        payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                            f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
-                        )
+                    # Safely inspect the payload before logging the request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+                                payload["messages"][0]["content"][1]["image_url"]["url"] = (
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                )
                     logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
                     raise RuntimeError(f"API请求失败: {str(e)}")
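
Note: this sanitization block is now duplicated across the two exception handlers. A hypothetical helper (not part of this commit) could factor it out, for example:

    from typing import Optional

    # Hypothetical helper, not in the diff: truncate the base64 body before logging.
    def _sanitize_image_payload(payload: dict, image_base64: str, image_format: Optional[str]) -> dict:
        if image_base64 and isinstance(payload, dict) and payload.get("messages"):
            first = payload["messages"][0]
            if isinstance(first, dict) and isinstance(first.get("content"), list):
                content = first["content"]
                if len(content) > 1 and "image_url" in content[1]:
                    fmt = image_format.lower() if image_format else "jpeg"
                    content[1]["image_url"]["url"] = (
                        f"data:image/{fmt};base64,{image_base64[:10]}...{image_base64[-10:]}"
                    )
        return payload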
@@ -364,23 +390,14 @@ class LLM_request:
     async def _transform_parameters(self, params: dict) -> dict:
         """
         Transform parameters according to the model name:
-        - For the OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temprature' parameter
+        - For the OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temperature' parameter
           and rename 'max_tokens' to 'max_completion_tokens'
         """
         # Copy the parameters to avoid mutating the original data
         new_params = dict(params)
 
-        # Define the list of models that need transformation
-        models_needing_transformation = [
-            "o3-mini",
-            "o1-mini",
-            "o1-preview",
-            "o1-2024-12-17",
-            "o1-preview-2024-09-12",
-            "o3-mini-2025-01-31",
-            "o1-mini-2024-09-12",
-        ]
-        if self.model_name.lower() in models_needing_transformation:
-            # Remove the 'temprature' parameter (if present)
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+            # Remove the 'temperature' parameter (if present)
             new_params.pop("temperature", None)
             # If 'max_tokens' is present, rename it to 'max_completion_tokens'
             if "max_tokens" in new_params:
@@ -417,19 +434,7 @@ class LLM_request:
             **params_copy,
         }
 
         # If max_tokens is still present in the payload and needs transformation, check again here
-        if (
-            self.model_name.lower()
-            in [
-                "o3-mini",
-                "o1-mini",
-                "o1-preview",
-                "o1-2024-12-17",
-                "o1-preview-2024-09-12",
-                "o3-mini-2025-01-31",
-                "o1-mini-2024-09-12",
-            ]
-            and "max_tokens" in payload
-        ):
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
             payload["max_completion_tokens"] = payload.pop("max_tokens")
 
         return payload
@@ -466,7 +471,8 @@ class LLM_request:
             return "没有返回结果", ""
 
-    def _extract_reasoning(self, content: str) -> tuple[str, str]:
+    @staticmethod
+    def _extract_reasoning(content: str) -> Tuple[str, str]:
         """Extract the CoT reasoning chain"""
         match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
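
Note: the opening <think> tag is optional in these patterns because some providers emit only the closing tag. A standalone check of the same regexes on a made-up reply (the return order (content, reasoning) is assumed for this sketch):

    import re

    def extract_reasoning(content: str):
        # Same patterns as the diff: the opening <think> tag is optional.
        match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
        reasoning = match.group(1).strip() if match else ""
        content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
        return content, reasoning

    print(extract_reasoning("<think>first compute 2+2</think>The answer is 4."))
    # -> ('The answer is 4.', 'first compute 2+2')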
@@ -506,6 +512,7 @@ class LLM_request:
             "messages": [{"role": "user", "content": prompt}],
             "max_tokens": global_config.max_response_length,
             **self.params,
+            **kwargs,
         }
 
         content, reasoning_content = await self._execute_request(
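
Note: placing **kwargs after **self.params means per-call arguments override the instance defaults, since later keys win when a dict literal is built:

    params = {"temperature": 0.7, "stream": False}   # instance defaults
    kwargs = {"temperature": 0.2}                    # per-call override
    payload = {"max_tokens": 256, **params, **kwargs}
    print(payload)  # {'max_tokens': 256, 'temperature': 0.2, 'stream': False}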