fix: address latent issues in utils_model.py

- Extract the duplicated model list into a class variable
- Fix an uninitialized-variable bug in stream handling
- Make error-response handling safer
- Fix type annotations
- Simplify a duplicated conditional
- Make _init_database a static method

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
@@ -18,6 +18,17 @@ config = driver.config


 class LLM_request:
+    # Models whose parameters need transformation; a class variable avoids duplicating the list
+    MODELS_NEEDING_TRANSFORMATION = [
+        "o3-mini",
+        "o1-mini",
+        "o1-preview",
+        "o1-2024-12-17",
+        "o1-preview-2024-09-12",
+        "o3-mini-2025-01-31",
+        "o1-mini-2024-09-12",
+    ]
+
     def __init__(self, model, **kwargs):
         # Lowercase the uppercase config keys and read the actual values from config
         try:
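The patch keeps the extracted list as a plain list; since it is only ever used for membership tests against `self.model_name.lower()`, a `frozenset` variant would make the O(1) lookup explicit. A minimal sketch of that alternative (`needs_transformation` is a hypothetical helper, not a method from this file):

```python
class LLM_request:
    # frozenset: immutable, O(1) membership checks; the patch's list behaves the same for `in`
    MODELS_NEEDING_TRANSFORMATION = frozenset({
        "o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17",
        "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12",
    })

    def __init__(self, model_name: str):
        self.model_name = model_name

    def needs_transformation(self) -> bool:
        # Hypothetical helper for illustration only
        return self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION

print(LLM_request("O3-Mini").needs_transformation())  # True
```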
@@ -36,7 +47,8 @@ class LLM_request:
         # Get the database instance
         self._init_database()

-    def _init_database(self):
+    @staticmethod
+    def _init_database():
         """Initialize the database collections"""
         try:
             # Create indexes for the llm_usage collection
@@ -44,8 +56,8 @@ class LLM_request:
             db.llm_usage.create_index([("model_name", 1)])
             db.llm_usage.create_index([("user_id", 1)])
             db.llm_usage.create_index([("request_type", 1)])
-        except Exception:
-            logger.error("Failed to create database indexes")
+        except Exception as e:
+            logger.error(f"Failed to create database indexes: {str(e)}")

     def _record_usage(
         self,
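`_init_database` can be static because it only touches the module-level `db` handle, never `self`. A self-contained sketch of the same index-creation pattern, assuming a pymongo-style client (the connection string is a placeholder):

```python
from pymongo import MongoClient

db = MongoClient("mongodb://localhost:27017")["bot_db"]  # placeholder connection

def init_database() -> None:
    """Create llm_usage indexes; create_index is idempotent, so reruns are safe."""
    try:
        db.llm_usage.create_index([("model_name", 1)])
        db.llm_usage.create_index([("user_id", 1)])
        db.llm_usage.create_index([("request_type", 1)])
    except Exception as e:
        # Binding the exception preserves the reason, matching the patched handler
        print(f"Failed to create database indexes: {e}")
```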
@@ -85,8 +97,8 @@ class LLM_request:
                 f"prompt: {prompt_tokens}, completion: {completion_tokens}, "
                 f"total: {total_tokens}"
             )
-        except Exception:
-            logger.error("Failed to record token usage")
+        except Exception as e:
+            logger.error(f"Failed to record token usage: {str(e)}")

     def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
         """Calculate the cost of an API call
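The `_calculate_cost` docstring is cut off at the hunk boundary. As a rough illustration of what such a method computes, here is a sketch with invented per-million-token prices (the defaults below are not values from this file):

```python
def calculate_cost(prompt_tokens: int, completion_tokens: int,
                   price_in_per_m: float = 2.0, price_out_per_m: float = 8.0) -> float:
    """Estimate the cost of one call; prices are hypothetical USD per million tokens."""
    cost = prompt_tokens / 1_000_000 * price_in_per_m
    cost += completion_tokens / 1_000_000 * price_out_per_m
    return round(cost, 6)

print(calculate_cost(1200, 300))  # 0.0048
```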
@@ -152,10 +164,8 @@ class LLM_request:
         api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
         # Determine whether this is a streaming request
-        if self.params.get("stream", False) is True:
-            logger.debug(f"Entering streaming mode, sending request to URL: {api_url}")
-        else:
-            logger.debug(f"Sending request to URL: {api_url}")
+        stream_mode = self.params.get("stream", False)
+        logger_msg = "Entering streaming mode, " if stream_mode else ""
+        logger.debug(f"{logger_msg}sending request to URL: {api_url}")
         logger.info(f"Using model: {self.model_name}")

         # Build the request body
@@ -255,6 +265,8 @@ class LLM_request:
             if stream_mode:
                 flag_delta_content_finished = False
                 accumulated_content = ""
+                usage = None  # Initialize usage so it is never referenced before assignment
+
                 async for line_bytes in response.content:
                     line = line_bytes.decode("utf-8").strip()
                     if not line:
@@ -266,7 +278,9 @@ class LLM_request:
                     try:
                         chunk = json.loads(data_str)
                         if flag_delta_content_finished:
-                            usage = chunk.get("usage", None)  # Get the tokn usage
+                            chunk_usage = chunk.get("usage", None)
+                            if chunk_usage:
+                                usage = chunk_usage  # Get the token usage
                         else:
                             delta = chunk["choices"][0]["delta"]
                             delta_content = delta.get("content")
@@ -276,14 +290,15 @@ class LLM_request:
                             # Detect whether the streamed text has finished
                             finish_reason = chunk["choices"][0].get("finish_reason")
                             if finish_reason == "stop":
-                                usage = chunk.get("usage", None)
-                                if usage:
+                                chunk_usage = chunk.get("usage", None)
+                                if chunk_usage:
+                                    usage = chunk_usage
                                     break
                                 # Some platforms only return token usage after the text ends, so fetch one more chunk
                                 flag_delta_content_finished = True

-                    except Exception:
-                        logger.exception("Error parsing streamed output")
+                    except Exception as e:
+                        logger.exception(f"Error parsing streamed output: {str(e)}")
                 content = accumulated_content
                 reasoning_content = ""
                 think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
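The three streaming hunks above share one theme: `usage` may never arrive, so it must be pre-initialized and only overwritten when a chunk actually carries it. A self-contained sketch of that loop, assuming OpenAI-style SSE chunks (synchronous here for brevity, while the patched code iterates an aiohttp response):

```python
import json

def consume_stream(lines: list[str]) -> tuple[str, dict | None]:
    """Hypothetical standalone version of the patched accumulation loop."""
    accumulated, usage = "", None  # pre-initialize: some providers never send usage
    for line in lines:
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue
        chunk = json.loads(line[len("data: "):])
        if chunk.get("usage"):            # only overwrite when a chunk carries usage
            usage = chunk["usage"]
        for choice in chunk.get("choices", []):
            accumulated += choice.get("delta", {}).get("content") or ""
    return accumulated, usage

# Example: a provider that ends the text before reporting usage.
demo = [
    'data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}',
    'data: {"choices":[{"delta":{},"finish_reason":"stop"}]}',
    'data: {"choices":[],"usage":{"total_tokens":5}}',
    "data: [DONE]",
]
print(consume_stream(demo))  # ('Hi', {'total_tokens': 5})
```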
@@ -315,8 +330,9 @@ class LLM_request:
                     wait_time = policy["base_wait"] * (2**retry)
                     logger.error(f"HTTP response error, retrying in {wait_time} seconds... status code: {e.status}, error: {e.message}")
                     try:
-                        if hasattr(e, "history") and e.history and hasattr(e.history[0], "text"):
-                            error_text = await e.history[0].text()
+                        if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
+                            error_text = await e.response.text()
                         try:
                             error_json = json.loads(error_text)
                             if isinstance(error_json, list) and len(error_json) > 0:
                                 for error_item in error_json:
@@ -332,15 +348,21 @@ class LLM_request:
                                         )
                                 else:
                                     logger.error(f"Server error response: {error_json}")
-                        except Exception as parse_err:
+                        except (json.JSONDecodeError, TypeError) as json_err:
+                            logger.warning(f"Response is not valid JSON: {str(json_err)}, raw content: {error_text[:200]}")
+                        except (AttributeError, TypeError, ValueError) as parse_err:
                             logger.warning(f"Could not parse the error response: {str(parse_err)}")

                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"HTTP response error reached the maximum number of retries: status code: {e.status}, error: {e.message}")
-                    if image_base64:
+                    # Safely inspect and log the request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
                                 payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                                    f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
                                 )
                     logger.critical(f"Request headers: {await self._build_headers(no_key=True)} request body: {payload}")
                     raise RuntimeError(f"API request failed: status code {e.status}, {e.message}")
@@ -351,9 +373,13 @@ class LLM_request:
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"Request failed: {str(e)}")
-                    if image_base64:
+                    # Safely inspect and log the request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
                                 payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                                    f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
                                 )
                     logger.critical(f"Request headers: {await self._build_headers(no_key=True)} request body: {payload}")
                     raise RuntimeError(f"API request failed: {str(e)}")
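Both retry branches now repeat the same defensive payload checks before logging. Factored into a helper, the logic reads as below; `redact_image_url` is our name for it, and the message layout it assumes is the OpenAI-style image message implied by the guards:

```python
def redact_image_url(payload, image_format):
    """Truncate a base64 data URL in-place so critical logs stay readable (hypothetical helper)."""
    if not (payload and isinstance(payload, dict) and payload.get("messages")):
        return
    first = payload["messages"][0]
    if not (isinstance(first, dict) and "content" in first):
        return
    content = first["content"]
    if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
        b64 = content[1]["image_url"]["url"].split("base64,", 1)[-1]
        fmt = image_format.lower() if image_format else "jpeg"  # same fallback as the patch
        content[1]["image_url"]["url"] = f"data:image/{fmt};base64,{b64[:10]}...{b64[-10:]}"

msg = {"messages": [{"content": [{"type": "text"},
                                 {"image_url": {"url": "data:image/png;base64," + "A" * 64}}]}]}
redact_image_url(msg, None)
print(msg["messages"][0]["content"][1]["image_url"]["url"])
# data:image/jpeg;base64,AAAAAAAAAA...AAAAAAAAAA
```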
@@ -364,23 +390,14 @@ class LLM_request:
     async def _transform_parameters(self, params: dict) -> dict:
         """
         Transform parameters according to the model name:
-        - For OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temprature' parameter
+        - For OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temperature' parameter
           and rename 'max_tokens' to 'max_completion_tokens'
         """
         # Copy the parameters to avoid mutating the caller's data
         new_params = dict(params)
-        # Models that need transformation
-        models_needing_transformation = [
-            "o3-mini",
-            "o1-mini",
-            "o1-preview",
-            "o1-2024-12-17",
-            "o1-preview-2024-09-12",
-            "o3-mini-2025-01-31",
-            "o1-mini-2024-09-12",
-        ]
-        if self.model_name.lower() in models_needing_transformation:
-            # Remove the 'temprature' parameter (if present)
+
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+            # Remove the 'temperature' parameter (if present)
             new_params.pop("temperature", None)
             # Rename 'max_tokens' to 'max_completion_tokens' if present
             if "max_tokens" in new_params:
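End to end, the transformation behaves like this standalone snippet (a free-function mirror of the method, with the class constant inlined; the re-check in the next hunk applies the same rename to the final payload):

```python
MODELS_NEEDING_TRANSFORMATION = [
    "o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17",
    "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12",
]

def transform_parameters(model_name: str, params: dict) -> dict:
    new_params = dict(params)  # copy so the caller's dict is not mutated
    if model_name.lower() in MODELS_NEEDING_TRANSFORMATION:
        new_params.pop("temperature", None)  # CoT models do not accept temperature
        if "max_tokens" in new_params:
            new_params["max_completion_tokens"] = new_params.pop("max_tokens")
    return new_params

print(transform_parameters("o3-mini", {"temperature": 0.7, "max_tokens": 256}))
# {'max_completion_tokens': 256}
```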
@@ -417,19 +434,7 @@ class LLM_request:
             **params_copy,
         }
         # If max_tokens is still present in the payload and needs conversion, check once more here
-        if (
-            self.model_name.lower()
-            in [
-                "o3-mini",
-                "o1-mini",
-                "o1-preview",
-                "o1-2024-12-17",
-                "o1-preview-2024-09-12",
-                "o3-mini-2025-01-31",
-                "o1-mini-2024-09-12",
-            ]
-            and "max_tokens" in payload
-        ):
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
             payload["max_completion_tokens"] = payload.pop("max_tokens")
         return payload
@@ -466,7 +471,8 @@ class LLM_request:

         return "No result returned", ""

-    def _extract_reasoning(self, content: str) -> tuple[str, str]:
+    @staticmethod
+    def _extract_reasoning(content: str) -> Tuple[str, str]:
         """Extract the CoT reasoning chain"""
         match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
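A quick, runnable check of what the two regexes extract. The `(?:<think>)?` makes the opening tag optional, so replies whose opening tag was lost still match; the return order shown (content first, reasoning second) is an assumption of the sketch:

```python
import re
from typing import Tuple

def extract_reasoning(content: str) -> Tuple[str, str]:
    """Standalone mirror of the patched static method (return order assumed)."""
    match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
    reasoning = match.group(1).strip() if match else ""
    content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
    return content, reasoning

print(extract_reasoning("<think>check the units first</think>It is 42 km."))
# ('It is 42 km.', 'check the units first')
print(extract_reasoning("no reasoning tags here"))
# ('no reasoning tags here', '')
```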
@@ -506,6 +512,7 @@ class LLM_request:
             "messages": [{"role": "user", "content": prompt}],
             "max_tokens": global_config.max_response_length,
             **self.params,
+            **kwargs,
         }

         content, reasoning_content = await self._execute_request(
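The added `**kwargs` matters because later entries in a dict literal overwrite earlier ones, so per-call arguments now override the instance defaults from `self.params`:

```python
params = {"temperature": 0.7, "max_tokens": 512}   # instance defaults
kwargs = {"temperature": 0.2}                      # per-call override
payload = {"model": "demo-model", **params, **kwargs}
print(payload)  # {'model': 'demo-model', 'temperature': 0.2, 'max_tokens': 512}
```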