fix: fix potential issues in utils_model.py
- Extract the duplicated model list into a class variable (see the sketch after this list)
- Fix an uninitialized-variable bug in streaming (sketched below)
- Improve the safety of error-response handling
- Fix type annotations
- Simplify duplicated conditional checks
- Make _init_database a static method

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
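The most subtle change is the streaming fix. Below is a minimal, hypothetical sketch of the failure mode and the fix; the names are simplified from the real `_execute_request`, and the chunk shape assumes an OpenAI-style streaming payload:

```python
import json


async def consume_stream(lines):
    """Hypothetical reduction of the fixed streaming loop in LLM_request."""
    accumulated_content = ""
    # The fix: bind usage before the loop. Without this, reading usage after a
    # stream that never carried a "usage" field raises UnboundLocalError.
    usage = None
    async for raw in lines:
        chunk = json.loads(raw)
        delta_content = chunk["choices"][0]["delta"].get("content")
        if delta_content:
            accumulated_content += delta_content
        chunk_usage = chunk.get("usage")
        if chunk_usage:  # Only overwrite when the platform actually reports usage.
            usage = chunk_usage
    return accumulated_content, usage
```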
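The deduplication change follows the same spirit: two inline copies of the model list collapse into one class attribute. A rough stand-alone sketch of the pattern (the class and method names here are hypothetical; in the diff the list lives on `LLM_request` as `MODELS_NEEDING_TRANSFORMATION`):

```python
class ParamTransformer:
    # Single source of truth, shared by every method that needs the list.
    MODELS_NEEDING_TRANSFORMATION = ["o3-mini", "o1-mini", "o1-preview"]

    def __init__(self, model_name: str):
        self.model_name = model_name

    def transform(self, params: dict) -> dict:
        new_params = dict(params)  # Copy so the caller's dict stays untouched.
        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
            new_params.pop("temperature", None)  # CoT models reject temperature.
            if "max_tokens" in new_params:
                new_params["max_completion_tokens"] = new_params.pop("max_tokens")
        return new_params
```

For example, `ParamTransformer("o3-mini").transform({"temperature": 0.7, "max_tokens": 100})` would return `{"max_completion_tokens": 100}`.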
@@ -18,6 +18,17 @@ config = driver.config
 
 
 class LLM_request:
+    # List of models that need parameter transformation, defined as a class variable to avoid duplication
+    MODELS_NEEDING_TRANSFORMATION = [
+        "o3-mini",
+        "o1-mini",
+        "o1-preview",
+        "o1-2024-12-17",
+        "o1-preview-2024-09-12",
+        "o3-mini-2025-01-31",
+        "o1-mini-2024-09-12",
+    ]
+
     def __init__(self, model, **kwargs):
         # Lowercase the uppercase config keys and fetch the actual values from config
         try:
@@ -36,7 +47,8 @@ class LLM_request:
         # Get the database instance
         self._init_database()
 
-    def _init_database(self):
+    @staticmethod
+    def _init_database():
         """Initialize the database collections"""
         try:
             # Create indexes for the llm_usage collection
@@ -44,8 +56,8 @@ class LLM_request:
             db.llm_usage.create_index([("model_name", 1)])
             db.llm_usage.create_index([("user_id", 1)])
             db.llm_usage.create_index([("request_type", 1)])
-        except Exception:
-            logger.error("Failed to create database indexes")
+        except Exception as e:
+            logger.error(f"Failed to create database indexes: {str(e)}")
 
     def _record_usage(
         self,
@@ -85,8 +97,8 @@ class LLM_request:
                 f"prompt: {prompt_tokens}, completion: {completion_tokens}, "
                 f"total: {total_tokens}"
             )
-        except Exception:
-            logger.error("Failed to record token usage")
+        except Exception as e:
+            logger.error(f"Failed to record token usage: {str(e)}")
 
     def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
         """Calculate the cost of an API call
@@ -152,10 +164,8 @@ class LLM_request:
         api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
         # Determine whether streaming is enabled
         stream_mode = self.params.get("stream", False)
-        if self.params.get("stream", False) is True:
-            logger.debug(f"Entering streaming mode, sending request to URL: {api_url}")
-        else:
-            logger.debug(f"Sending request to URL: {api_url}")
+        logger_msg = "Entering streaming mode, " if stream_mode else ""
+        logger.debug(f"{logger_msg}sending request to URL: {api_url}")
         logger.info(f"Using model: {self.model_name}")
 
         # Build the request body
@@ -255,6 +265,8 @@ class LLM_request:
                 if stream_mode:
                     flag_delta_content_finished = False
                     accumulated_content = ""
+                    usage = None  # Initialize usage up front to avoid an unbound-variable error
+
                     async for line_bytes in response.content:
                         line = line_bytes.decode("utf-8").strip()
                         if not line:
@@ -266,7 +278,9 @@ class LLM_request:
                             try:
                                 chunk = json.loads(data_str)
                                 if flag_delta_content_finished:
-                                    usage = chunk.get("usage", None)  # get token usage
+                                    chunk_usage = chunk.get("usage", None)
+                                    if chunk_usage:
+                                        usage = chunk_usage  # record token usage
                                 else:
                                     delta = chunk["choices"][0]["delta"]
                                     delta_content = delta.get("content")
@@ -276,14 +290,15 @@ class LLM_request:
                                     # Detect whether the streamed text has finished
                                     finish_reason = chunk["choices"][0].get("finish_reason")
                                     if finish_reason == "stop":
-                                        usage = chunk.get("usage", None)
-                                        if usage:
+                                        chunk_usage = chunk.get("usage", None)
+                                        if chunk_usage:
+                                            usage = chunk_usage
                                             break
                                         # Some platforms do not return token usage before the text output finishes; fetch one more chunk in that case
                                         flag_delta_content_finished = True
 
-                            except Exception:
-                                logger.exception("Error while parsing streamed output")
+                            except Exception as e:
+                                logger.exception(f"Error while parsing streamed output: {str(e)}")
                 content = accumulated_content
                 reasoning_content = ""
                 think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
@@ -315,33 +330,40 @@ class LLM_request:
                     wait_time = policy["base_wait"] * (2**retry)
                     logger.error(f"HTTP response error, waiting {wait_time}s before retrying... status: {e.status}, error: {e.message}")
                     try:
-                        if hasattr(e, "history") and e.history and hasattr(e.history[0], "text"):
-                            error_text = await e.history[0].text()
-                            error_json = json.loads(error_text)
-                            if isinstance(error_json, list) and len(error_json) > 0:
-                                for error_item in error_json:
-                                    if "error" in error_item and isinstance(error_item["error"], dict):
-                                        error_obj = error_item["error"]
-                                        logger.error(
-                                            f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
-                                        )
-                            elif isinstance(error_json, dict) and "error" in error_json:
-                                error_obj = error_json.get("error", {})
-                                logger.error(
-                                    f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
-                                )
-                            else:
-                                logger.error(f"Server error response: {error_json}")
-                    except Exception as parse_err:
+                        if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
+                            error_text = await e.response.text()
+                            try:
+                                error_json = json.loads(error_text)
+                                if isinstance(error_json, list) and len(error_json) > 0:
+                                    for error_item in error_json:
+                                        if "error" in error_item and isinstance(error_item["error"], dict):
+                                            error_obj = error_item["error"]
+                                            logger.error(
+                                                f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
+                                            )
+                                elif isinstance(error_json, dict) and "error" in error_json:
+                                    error_obj = error_json.get("error", {})
+                                    logger.error(
+                                        f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
+                                    )
+                                else:
+                                    logger.error(f"Server error response: {error_json}")
+                            except (json.JSONDecodeError, TypeError) as json_err:
+                                logger.warning(f"Response is not valid JSON: {str(json_err)}, raw content: {error_text[:200]}")
+                    except (AttributeError, TypeError, ValueError) as parse_err:
                         logger.warning(f"Could not parse the error response content: {str(parse_err)}")
 
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"HTTP response error hit the maximum number of retries: status: {e.status}, error: {e.message}")
-                    if image_base64:
-                        payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                            f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
-                        )
+                    # Safely inspect the payload before logging request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+                                payload["messages"][0]["content"][1]["image_url"]["url"] = (
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                )
                     logger.critical(f"Request headers: {await self._build_headers(no_key=True)} Request body: {payload}")
                     raise RuntimeError(f"API request failed: status code {e.status}, {e.message}")
             except Exception as e:
@@ -351,10 +373,14 @@ class LLM_request:
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"Request failed: {str(e)}")
-                    if image_base64:
-                        payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                            f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
-                        )
+                    # Safely inspect the payload before logging request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+                                payload["messages"][0]["content"][1]["image_url"]["url"] = (
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                )
                     logger.critical(f"Request headers: {await self._build_headers(no_key=True)} Request body: {payload}")
                     raise RuntimeError(f"API request failed: {str(e)}")
 
@@ -364,23 +390,14 @@ class LLM_request:
     async def _transform_parameters(self, params: dict) -> dict:
         """
         Transform parameters according to the model name:
-        - For the OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temprature' parameter
+        - For the OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temperature' parameter
           and rename 'max_tokens' to 'max_completion_tokens'
         """
         # Copy the parameters to avoid mutating the caller's data
         new_params = dict(params)
-        # Define the list of models that need transformation
-        models_needing_transformation = [
-            "o3-mini",
-            "o1-mini",
-            "o1-preview",
-            "o1-2024-12-17",
-            "o1-preview-2024-09-12",
-            "o3-mini-2025-01-31",
-            "o1-mini-2024-09-12",
-        ]
-        if self.model_name.lower() in models_needing_transformation:
-            # Remove the 'temprature' parameter (if present)
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+            # Remove the 'temperature' parameter (if present)
             new_params.pop("temperature", None)
             # Rename 'max_tokens' to 'max_completion_tokens' if present
             if "max_tokens" in new_params:
@@ -417,19 +434,7 @@ class LLM_request:
             **params_copy,
         }
         # If max_tokens is still present in the payload and needs transformation, check again here
-        if (
-            self.model_name.lower()
-            in [
-                "o3-mini",
-                "o1-mini",
-                "o1-preview",
-                "o1-2024-12-17",
-                "o1-preview-2024-09-12",
-                "o3-mini-2025-01-31",
-                "o1-mini-2024-09-12",
-            ]
-            and "max_tokens" in payload
-        ):
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
             payload["max_completion_tokens"] = payload.pop("max_tokens")
         return payload
 
@@ -466,7 +471,8 @@ class LLM_request:
 
         return "No result returned", ""
 
-    def _extract_reasoning(self, content: str) -> tuple[str, str]:
+    @staticmethod
+    def _extract_reasoning(content: str) -> Tuple[str, str]:
         """Extract the CoT reasoning chain"""
         match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
@@ -506,6 +512,7 @@ class LLM_request:
             "messages": [{"role": "user", "content": prompt}],
             "max_tokens": global_config.max_response_length,
             **self.params,
+            **kwargs,
         }
 
         content, reasoning_content = await self._execute_request(
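One caveat: the new `_extract_reasoning` signature uses `Tuple[str, str]`, which requires the `typing` import at the top of utils_model.py; the diff does not show whether it is already present:

```python
# Assumed to already exist (or be added) at the top of utils_model.py;
# the new Tuple[str, str] annotation raises NameError without it.
from typing import Tuple
```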