diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index d4d57a93d..c5782a923 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -18,6 +18,17 @@ config = driver.config
class LLM_request:
+ # List of models whose parameters need transformation, kept as a class variable to avoid duplication
+ MODELS_NEEDING_TRANSFORMATION = [
+ "o3-mini",
+ "o1-mini",
+ "o1-preview",
+ "o1-2024-12-17",
+ "o1-preview-2024-09-12",
+ "o3-mini-2025-01-31",
+ "o1-mini-2024-09-12",
+ ]
+
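Moving the list to a class variable means the membership test in `_transform_parameters` reads it via `self`. A minimal sketch of that check, with an abbreviated list and an illustrative helper name (not part of this diff):

```python
# Minimal sketch; the list is abbreviated and needs_transformation is
# an illustrative name, not a function in this codebase.
MODELS_NEEDING_TRANSFORMATION = ["o3-mini", "o1-mini", "o1-preview"]

def needs_transformation(model_name: str) -> bool:
    # Entries must stay lowercase: the caller compares model_name.lower().
    return model_name.lower() in MODELS_NEEDING_TRANSFORMATION

assert needs_transformation("O3-Mini")
assert not needs_transformation("gpt-4o")
```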
def __init__(self, model, **kwargs):
# Convert uppercase config keys to lowercase and fetch the actual values from config
try:
@@ -36,7 +47,8 @@ class LLM_request:
# Get the database instance
self._init_database()
- def _init_database(self):
+ @staticmethod
+ def _init_database():
"""初始化数据库集合"""
try:
# Create indexes for the llm_usage collection
@@ -44,8 +56,8 @@ class LLM_request:
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
- except Exception:
- logger.error("Failed to create database indexes")
+ except Exception as e:
+ logger.error(f"Failed to create database indexes: {str(e)}")
def _record_usage(
self,
@@ -85,8 +97,8 @@ class LLM_request:
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
- except Exception:
- logger.error("Failed to record token usage")
+ except Exception as e:
+ logger.error(f"Failed to record token usage: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
@@ -152,10 +164,8 @@ class LLM_request:
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
# Check whether streaming mode is enabled
stream_mode = self.params.get("stream", False)
- if self.params.get("stream", False) is True:
- logger.debug(f"Entering streaming mode, sending request to URL: {api_url}")
- else:
- logger.debug(f"Sending request to URL: {api_url}")
+ logger_msg = "Entering streaming mode, " if stream_mode else ""
+ logger.debug(f"{logger_msg}sending request to URL: {api_url}")
logger.info(f"使用模型: {self.model_name}")
# Build the request body
@@ -255,6 +265,8 @@ class LLM_request:
if stream_mode:
flag_delta_content_finished = False
accumulated_content = ""
+ usage = None  # Initialize usage so it is always defined
+
async for line_bytes in response.content:
line = line_bytes.decode("utf-8").strip()
if not line:
@@ -266,7 +278,9 @@ class LLM_request:
try:
chunk = json.loads(data_str)
if flag_delta_content_finished:
- usage = chunk.get("usage", None)  # get tokn usage
+ chunk_usage = chunk.get("usage", None)
+ if chunk_usage:
+ usage = chunk_usage  # get token usage
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
@@ -276,14 +290,15 @@ class LLM_request:
# Detect whether the streamed text has finished
finish_reason = chunk["choices"][0].get("finish_reason")
if finish_reason == "stop":
- usage = chunk.get("usage", None)
- if usage:
+ chunk_usage = chunk.get("usage", None)
+ if chunk_usage:
+ usage = chunk_usage
break
# Some platforms do not return token usage until the text output has finished, so one more chunk must be fetched
flag_delta_content_finished = True
- except Exception:
- logger.exception("Error parsing streaming output")
+ except Exception as e:
+ logger.exception(f"Error parsing streaming output: {str(e)}")
content = accumulated_content
reasoning_content = ""
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
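The streaming branch above accumulates delta content and, with `usage` now pre-initialized, tolerates providers that send token counts only in a trailing chunk after `finish_reason == "stop"`. A hedged sketch of that capture pattern, using plain dicts in place of real SSE chunks (the chunk shapes are assumptions based on the OpenAI-style format being parsed):

```python
# Sketch of the usage-capture loop; chunk shapes are illustrative.
chunks = [
    {"choices": [{"delta": {"content": "Hel"}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]},
    {"usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}},
]

usage = None        # always defined, as the diff now guarantees
accumulated = ""
done = False
for chunk in chunks:
    if done:
        # Some platforms send token usage only in a trailing chunk.
        usage = chunk.get("usage") or usage
        continue
    choice = chunk["choices"][0]
    accumulated += choice["delta"].get("content") or ""
    if choice.get("finish_reason") == "stop":
        usage = chunk.get("usage") or usage
        done = True

assert accumulated == "Hello" and usage["total_tokens"] == 7
```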
@@ -315,33 +330,40 @@ class LLM_request:
wait_time = policy["base_wait"] * (2**retry)
logger.error(f"HTTP响应错误,等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
try:
- if hasattr(e, "history") and e.history and hasattr(e.history[0], "text"):
- error_text = await e.history[0].text()
- error_json = json.loads(error_text)
- if isinstance(error_json, list) and len(error_json) > 0:
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj = error_item["error"]
- logger.error(
- f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- error_obj = error_json.get("error", {})
- logger.error(
- f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
- )
- else:
- logger.error(f"服务器错误响应: {error_json}")
- except Exception as parse_err:
+ if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
+ error_text = await e.response.text()
+ try:
+ error_json = json.loads(error_text)
+ if isinstance(error_json, list) and len(error_json) > 0:
+ for error_item in error_json:
+ if "error" in error_item and isinstance(error_item["error"], dict):
+ error_obj = error_item["error"]
+ logger.error(
+ f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
+ )
+ elif isinstance(error_json, dict) and "error" in error_json:
+ error_obj = error_json.get("error", {})
+ logger.error(
+ f"服务器错误详情: 代码={error_obj.get('code')}, 状态={error_obj.get('status')}, 消息={error_obj.get('message')}"
+ )
+ else:
+ logger.error(f"服务器错误响应: {error_json}")
+ except (json.JSONDecodeError, TypeError) as json_err:
+ logger.warning(f"响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
+ except (AttributeError, TypeError, ValueError) as parse_err:
logger.warning(f"无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
else:
logger.critical(f"HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
- if image_base64:
- payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
- )
+ # Safely inspect and log the request details
+ if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+ if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+ content = payload["messages"][0]["content"]
+ if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+ payload["messages"][0]["content"][1]["image_url"]["url"] = (
+ f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+ )
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}")
except Exception as e:
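Both retry paths wait `policy["base_wait"] * (2**retry)` seconds, i.e. classic exponential backoff. A worked example with an assumed policy (the values are not taken from this codebase):

```python
# Illustrative backoff schedule; policy values are assumptions.
policy = {"base_wait": 5, "max_retries": 3}

for retry in range(policy["max_retries"]):
    wait_time = policy["base_wait"] * (2 ** retry)
    print(f"retry {retry}: waiting {wait_time}s")  # 5s, 10s, 20s
```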
@@ -351,10 +373,14 @@ class LLM_request:
await asyncio.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}")
- if image_base64:
- payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
- )
+ # Safely inspect and log the request details
+ if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+ if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+ content = payload["messages"][0]["content"]
+ if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+ payload["messages"][0]["content"][1]["image_url"]["url"] = (
+ f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+ )
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}")
@@ -364,23 +390,14 @@ class LLM_request:
async def _transform_parameters(self, params: dict) -> dict:
"""
Transform parameters based on the model name:
- - For OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temprature' parameter,
+ - For OpenAI CoT-series models that need transformation (e.g. "o3-mini"), remove the 'temperature' parameter,
and rename 'max_tokens' to 'max_completion_tokens'
"""
# Copy the parameters to avoid mutating the original data
new_params = dict(params)
- # Define the list of models that need transformation
- models_needing_transformation = [
- "o3-mini",
- "o1-mini",
- "o1-preview",
- "o1-2024-12-17",
- "o1-preview-2024-09-12",
- "o3-mini-2025-01-31",
- "o1-mini-2024-09-12",
- ]
- if self.model_name.lower() in models_needing_transformation:
- # Remove the 'temprature' parameter if present
+
+ if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+ # Remove the 'temperature' parameter if present
new_params.pop("temperature", None)
# If 'max_tokens' exists, rename it to 'max_completion_tokens'
if "max_tokens" in new_params:
@@ -417,19 +434,7 @@ class LLM_request:
**params_copy,
}
# If max_tokens still remains in the payload and needs transformation, check again here
- if (
- self.model_name.lower()
- in [
- "o3-mini",
- "o1-mini",
- "o1-preview",
- "o1-2024-12-17",
- "o1-preview-2024-09-12",
- "o3-mini-2025-01-31",
- "o1-mini-2024-09-12",
- ]
- and "max_tokens" in payload
- ):
+ if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload
@@ -466,7 +471,8 @@ class LLM_request:
return "没有返回结果", ""
- def _extract_reasoning(self, content: str) -> tuple[str, str]:
+ @staticmethod
+ def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
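With the `<think>` patterns restored, the extraction strips the first reasoning block and returns the remaining answer. A self-contained usage sketch (the sample text is illustrative):

```python
import re
from typing import Tuple

# Usage sketch mirroring the extraction above.
def extract_reasoning(content: str) -> Tuple[str, str]:
    match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
    reasoning = match.group(1).strip() if match else ""
    content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
    return content, reasoning

print(extract_reasoning("<think>step 1, step 2</think>The answer is 4."))
# ('The answer is 4.', 'step 1, step 2')
```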
@@ -506,6 +512,7 @@ class LLM_request:
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
**self.params,
+ **kwargs,
}
content, reasoning_content = await self._execute_request(
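Because `**kwargs` is unpacked after `**self.params`, per-call arguments now override the instance defaults (later keys win in dict literals). A small sketch of the merge order with assumed values:

```python
# Sketch of the merge order introduced above; all values are illustrative.
instance_params = {"temperature": 0.7, "max_tokens": 256}  # instance defaults
kwargs = {"max_tokens": 1024}                              # per-call override

payload = {
    "model": "some-model",  # illustrative name
    **instance_params,
    **kwargs,
}
assert payload["max_tokens"] == 1024 and payload["temperature"] == 0.7
```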