diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py
index d4d57a93d..c5782a923 100644
--- a/src/plugins/models/utils_model.py
+++ b/src/plugins/models/utils_model.py
@@ -18,6 +18,17 @@ config = driver.config
 
 
 class LLM_request:
+    # List of models whose parameters need transformation, kept as a class variable to avoid duplication
+    MODELS_NEEDING_TRANSFORMATION = [
+        "o3-mini",
+        "o1-mini",
+        "o1-preview",
+        "o1-2024-12-17",
+        "o1-preview-2024-09-12",
+        "o3-mini-2025-01-31",
+        "o1-mini-2024-09-12",
+    ]
+
     def __init__(self, model, **kwargs):
         # Convert uppercase config keys to lowercase and read the actual values from config
         try:
@@ -36,7 +47,8 @@ class LLM_request:
         # Get the database instance
         self._init_database()
 
-    def _init_database(self):
+    @staticmethod
+    def _init_database():
         """Initialize the database collections"""
         try:
             # Create indexes for the llm_usage collection
@@ -44,8 +56,8 @@ class LLM_request:
             db.llm_usage.create_index([("model_name", 1)])
             db.llm_usage.create_index([("user_id", 1)])
             db.llm_usage.create_index([("request_type", 1)])
-        except Exception:
-            logger.error("Failed to create database indexes")
+        except Exception as e:
+            logger.error(f"Failed to create database indexes: {str(e)}")
 
     def _record_usage(
         self,
@@ -85,8 +97,8 @@ class LLM_request:
                 f"Prompt: {prompt_tokens}, completion: {completion_tokens}, "
                 f"total: {total_tokens}"
             )
-        except Exception:
-            logger.error("Failed to record token usage")
+        except Exception as e:
+            logger.error(f"Failed to record token usage: {str(e)}")
 
     def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
         """Calculate the cost of an API call
@@ -152,10 +164,8 @@ class LLM_request:
         api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
         # Determine whether streaming mode is enabled
         stream_mode = self.params.get("stream", False)
-        if self.params.get("stream", False) is True:
-            logger.debug(f"Entering streaming mode, sending request to URL: {api_url}")
-        else:
-            logger.debug(f"Sending request to URL: {api_url}")
+        logger_msg = "Entering streaming mode, " if stream_mode else ""
+        logger.debug(f"{logger_msg}sending request to URL: {api_url}")
         logger.info(f"Using model: {self.model_name}")
 
         # Build the request body
@@ -255,6 +265,8 @@ class LLM_request:
         if stream_mode:
             flag_delta_content_finished = False
             accumulated_content = ""
+            usage = None  # Initialize usage so it is always defined
+
             async for line_bytes in response.content:
                 line = line_bytes.decode("utf-8").strip()
                 if not line:
@@ -266,7 +278,9 @@ class LLM_request:
                     try:
                         chunk = json.loads(data_str)
                         if flag_delta_content_finished:
-                            usage = chunk.get("usage", None)  # fetch token usage
+                            chunk_usage = chunk.get("usage", None)
+                            if chunk_usage:
+                                usage = chunk_usage  # fetch token usage
                         else:
                             delta = chunk["choices"][0]["delta"]
                             delta_content = delta.get("content")
@@ -276,14 +290,15 @@ class LLM_request:
                             # Detect whether the streamed text has finished
                             finish_reason = chunk["choices"][0].get("finish_reason")
                             if finish_reason == "stop":
-                                usage = chunk.get("usage", None)
-                                if usage:
+                                chunk_usage = chunk.get("usage", None)
+                                if chunk_usage:
+                                    usage = chunk_usage
                                     break
                                 # Some platforms do not return token usage before the text output ends; fetch one more chunk in that case
                                 flag_delta_content_finished = True
-                    except Exception:
-                        logger.exception("Error while parsing streaming output")
+                    except Exception as e:
+                        logger.exception(f"Error while parsing streaming output: {str(e)}")
             content = accumulated_content
             reasoning_content = ""
             think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
@@ -315,33 +330,40 @@ class LLM_request:
                     wait_time = policy["base_wait"] * (2**retry)
                     logger.error(f"HTTP response error, waiting {wait_time}s before retrying... status: {e.status}, error: {e.message}")
                     try:
-                        if hasattr(e, "history") and e.history and hasattr(e.history[0], "text"):
-                            error_text = await e.history[0].text()
-                            error_json = json.loads(error_text)
-                            if isinstance(error_json, list) and len(error_json) > 0:
-                                for error_item in error_json:
-                                    if "error" in error_item and isinstance(error_item["error"], dict):
-                                        error_obj = error_item["error"]
-                                        logger.error(
-                                            f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
-                                        )
-                            elif isinstance(error_json, dict) and "error" in error_json:
-                                error_obj = error_json.get("error", {})
-                                logger.error(
-                                    f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
-                                )
-                            else:
-                                logger.error(f"Server error response: {error_json}")
-                    except Exception as parse_err:
+                        if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
+                            error_text = await e.response.text()
+                            try:
+                                error_json = json.loads(error_text)
+                                if isinstance(error_json, list) and len(error_json) > 0:
+                                    for error_item in error_json:
+                                        if "error" in error_item and isinstance(error_item["error"], dict):
+                                            error_obj = error_item["error"]
+                                            logger.error(
+                                                f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
+                                            )
+                                elif isinstance(error_json, dict) and "error" in error_json:
+                                    error_obj = error_json.get("error", {})
+                                    logger.error(
+                                        f"Server error details: code={error_obj.get('code')}, status={error_obj.get('status')}, message={error_obj.get('message')}"
+                                    )
+                                else:
+                                    logger.error(f"Server error response: {error_json}")
+                            except (json.JSONDecodeError, TypeError) as json_err:
+                                logger.warning(f"Response is not valid JSON: {str(json_err)}, raw content: {error_text[:200]}")
+                    except (AttributeError, TypeError, ValueError) as parse_err:
                         logger.warning(f"Unable to parse the error response content: {str(parse_err)}")
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"HTTP response error reached the maximum number of retries: status: {e.status}, error: {e.message}")
-                    if image_base64:
-                        payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                            f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
-                        )
+                    # Safely inspect the payload before logging request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+                                payload["messages"][0]["content"][1]["image_url"]["url"] = (
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                )
                     logger.critical(f"Request headers: {await self._build_headers(no_key=True)} Request body: {payload}")
                     raise RuntimeError(f"API request failed: status {e.status}, {e.message}")
             except Exception as e:
@@ -351,10 +373,14 @@ class LLM_request:
                     await asyncio.sleep(wait_time)
                 else:
                     logger.critical(f"Request failed: {str(e)}")
-                    if image_base64:
-                        payload["messages"][0]["content"][1]["image_url"]["url"] = (
-                            f"data:image/{image_format.lower()};base64,{image_base64[:10]}...{image_base64[-10:]}"
-                        )
+                    # Safely inspect the payload before logging request details
+                    if image_base64 and payload and isinstance(payload, dict) and "messages" in payload and len(payload["messages"]) > 0:
+                        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
+                            content = payload["messages"][0]["content"]
+                            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
+                                payload["messages"][0]["content"][1]["image_url"]["url"] = (
+                                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,{image_base64[:10]}...{image_base64[-10:]}"
+                                )
                     logger.critical(f"Request headers: {await self._build_headers(no_key=True)} Request body: {payload}")
                     raise RuntimeError(f"API request failed: {str(e)}")
@@ -364,23 +390,14 @@ class LLM_request:
     async def _transform_parameters(self, params: dict) -> dict:
         """
         Transform parameters based on the model name:
-        - For the OpenAI CoT-series models that need it (e.g. "o3-mini"), remove the 'temprature' parameter
+        - For the OpenAI CoT-series models that need it (e.g. "o3-mini"), remove the 'temperature' parameter
           and rename 'max_tokens' to 'max_completion_tokens'
         """
         # Copy the parameters to avoid mutating the original data
         new_params = dict(params)
-        # List of models whose parameters need transformation
-        models_needing_transformation = [
-            "o3-mini",
-            "o1-mini",
-            "o1-preview",
-            "o1-2024-12-17",
-            "o1-preview-2024-09-12",
-            "o3-mini-2025-01-31",
-            "o1-mini-2024-09-12",
-        ]
-        if self.model_name.lower() in models_needing_transformation:
-            # Remove the 'temprature' parameter (if present)
+
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
+            # Remove the 'temperature' parameter (if present)
             new_params.pop("temperature", None)
             # If 'max_tokens' is present, rename it to 'max_completion_tokens'
             if "max_tokens" in new_params:
@@ -417,19 +434,7 @@ class LLM_request:
             **params_copy,
         }
         # If the payload still contains max_tokens and needs transformation, check once more here
-        if (
-            self.model_name.lower()
-            in [
-                "o3-mini",
-                "o1-mini",
-                "o1-preview",
-                "o1-2024-12-17",
-                "o1-preview-2024-09-12",
-                "o3-mini-2025-01-31",
-                "o1-mini-2024-09-12",
-            ]
-            and "max_tokens" in payload
-        ):
+        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
             payload["max_completion_tokens"] = payload.pop("max_tokens")
         return payload
 
@@ -466,7 +471,8 @@ class LLM_request:
 
         return "No result returned", ""
 
-    def _extract_reasoning(self, content: str) -> tuple[str, str]:
+    @staticmethod
+    def _extract_reasoning(content: str) -> tuple[str, str]:
        """Extract the CoT reasoning chain"""
        match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
        content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
@@ -506,6 +512,7 @@ class LLM_request:
             "messages": [{"role": "user", "content": prompt}],
             "max_tokens": global_config.max_response_length,
             **self.params,
+            **kwargs,
         }
         content, reasoning_content = await self._execute_request(
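
For reference, here is a minimal, self-contained sketch of the transformation that the new class-level `MODELS_NEEDING_TRANSFORMATION` list centralizes. The free-standing `transform_parameters` function and the sample values are illustrative only; the real logic lives in `LLM_request._transform_parameters`:

```python
# Illustrative sketch, not part of the patch; the list mirrors
# LLM_request.MODELS_NEEDING_TRANSFORMATION from the diff above.
MODELS_NEEDING_TRANSFORMATION = [
    "o3-mini",
    "o1-mini",
    "o1-preview",
    "o1-2024-12-17",
    "o1-preview-2024-09-12",
    "o3-mini-2025-01-31",
    "o1-mini-2024-09-12",
]


def transform_parameters(model_name: str, params: dict) -> dict:
    """Drop 'temperature' and rename 'max_tokens' for OpenAI CoT-series models."""
    new_params = dict(params)  # copy so the caller's dict is not mutated
    if model_name.lower() in MODELS_NEEDING_TRANSFORMATION:
        new_params.pop("temperature", None)  # these models reject 'temperature'
        if "max_tokens" in new_params:
            new_params["max_completion_tokens"] = new_params.pop("max_tokens")
    return new_params


assert transform_parameters("o3-mini", {"temperature": 0.7, "max_tokens": 256}) == {
    "max_completion_tokens": 256
}
assert transform_parameters("gpt-4o", {"temperature": 0.7}) == {"temperature": 0.7}
```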
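The streaming fix (initializing `usage = None` up front and only overwriting it with a chunk that actually carries a usage payload) can be exercised in isolation. Below is a sketch in which made-up chunk dicts stand in for the decoded SSE stream and `awaiting_usage` plays the role of `flag_delta_content_finished`:

```python
# Illustrative sketch of the usage-accumulation fix; the chunk dicts are
# fabricated stand-ins for decoded SSE events, not real API output.
chunks = [
    {"choices": [{"delta": {"content": "Hel"}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]},
    {"usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}},
]

usage = None  # initialized up front, as in the patch, so it is always defined
accumulated_content = ""
awaiting_usage = False  # counterpart of flag_delta_content_finished

for chunk in chunks:
    if awaiting_usage:
        chunk_usage = chunk.get("usage")
        if chunk_usage:  # keep only a chunk that really carries usage
            usage = chunk_usage
        continue
    choice = chunk["choices"][0]
    delta_content = choice["delta"].get("content")
    if delta_content:
        accumulated_content += delta_content
    if choice.get("finish_reason") == "stop":
        chunk_usage = chunk.get("usage")
        if chunk_usage:
            usage = chunk_usage
            break
        # some platforms send usage only in a trailing chunk
        awaiting_usage = True

print(accumulated_content, usage)  # Hello {'prompt_tokens': 5, ...}
```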
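Finally, the behaviour of the now-static `_extract_reasoning`, assuming the reasoning chain is delimited by `<think>...</think>` tags as in the regexes above. The free function below is an illustrative stand-in, not the method itself:

```python
import re


def extract_reasoning(content: str) -> tuple[str, str]:
    """Split a '<think>...</think>' reasoning block from the visible reply."""
    match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
    reasoning = match.group(1).strip() if match else ""
    # remove only the first reasoning block from the visible content
    content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
    return content, reasoning


print(extract_reasoning("<think>The user greeted me.</think>Hello!"))
# -> ('Hello!', 'The user greeted me.')
```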