Files
Mofox-Core/src/plugins/models/utils_model.py
2025-03-20 17:08:53 +08:00

675 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import re
from datetime import datetime
from typing import Tuple, Union
import aiohttp
from src.common.logger import get_module_logger
from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import db
from ..chat.config import global_config
driver = get_driver()
config = driver.config
logger = get_module_logger("model_utils")
class LLM_request:
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o3-mini",
"o1-mini",
"o1-preview",
"o1-2024-12-17",
"o1-preview-2024-09-12",
"o3-mini-2025-01-31",
"o1-mini-2024-09-12",
]
def __init__(self, model, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = getattr(config, model["key"])
self.base_url = getattr(config, model["base_url"])
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"]
self.params = kwargs
self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0)
# 获取数据库实例
self._init_database()
# 从 kwargs 中提取 request_type如果没有提供则默认为 "default"
self.request_type = kwargs.pop("request_type", "default")
@staticmethod
def _init_database():
"""初始化数据库集合"""
try:
# 创建llm_usage集合的索引
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}")
def _record_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
user_id: str = "system",
request_type: str = None,
endpoint: str = "/chat/completions",
):
"""记录模型使用情况到数据库
Args:
prompt_tokens: 输入token数
completion_tokens: 输出token数
total_tokens: 总token数
user_id: 用户ID默认为system
request_type: 请求类型(chat/embedding/image/topic/schedule)
endpoint: API端点
"""
# 如果 request_type 为 None则使用实例变量中的值
if request_type is None:
request_type = self.request_type
try:
usage_data = {
"model_name": self.model_name,
"user_id": user_id,
"request_type": request_type,
"endpoint": endpoint,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
"status": "success",
"timestamp": datetime.now(),
}
db.llm_usage.insert_one(usage_data)
logger.debug(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
使用模型的pri_in和pri_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
Returns:
float: 总成本(元)
"""
# 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
async def _execute_request(
self,
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
user_id: str = "system",
request_type: str = None,
):
"""统一请求执行入口
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
response_handler: 自定义响应处理器
user_id: 用户ID
request_type: 请求类型
"""
if request_type is None:
request_type = self.request_type
# 合并重试策略
default_retry = {
"max_retries": 3,
"base_wait": 15,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403],
}
policy = {**default_retry, **(retry_policy or {})}
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env.prod中的配置是否正确哦~",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
# 判断是否为流式
stream_mode = self.params.get("stream", False)
# logger_msg = "进入流式输出模式," if stream_mode else ""
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
# logger.info(f"使用模型: {self.model_name}")
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None:
payload = await self._build_payload(prompt)
for retry in range(policy["max_retries"]):
try:
# 使用上下文管理器处理会话
headers = await self._build_headers()
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if stream_mode:
headers["Accept"] = "text/event-stream"
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=payload) as response:
# 处理需要重试的状态码
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry)
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试")
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
elif response.status in policy["abort_codes"]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
# 尝试获取并记录服务器返回的详细错误信息
try:
error_json = await response.json()
if error_json and isinstance(error_json, list) and len(error_json) > 0:
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"服务器错误详情: 代码={error_code}, 状态={error_status}, "
f"消息={error_message}"
)
elif isinstance(error_json, dict) and "error" in error_json:
# 处理单个错误对象的情况
error_obj = error_json.get("error", {})
error_code = error_obj.get("code")
error_message = error_obj.get("message")
error_status = error_obj.get("status")
logger.error(
f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
)
else:
# 记录原始错误响应内容
logger.error(f"服务器错误响应: {error_json}")
except Exception as e:
logger.warning(f"无法解析服务器错误响应: {str(e)}")
if response.status == 403:
# 只针对硅基流动的V3和R1进行降级处理
if (
self.model_name.startswith("Pro/deepseek-ai")
and self.base_url == "https://api.siliconflow.cn/v1/"
):
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
# 更新payload中的模型名
if payload and "model" in payload:
payload["model"] = self.model_name
# 重新尝试请求
retry -= 1 # 不计入重试次数
continue
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status()
# 将流式输出转化为非流式输出
if stream_mode:
flag_delta_content_finished = False
accumulated_content = ""
usage = None # 初始化usage变量避免未定义错误
async for line_bytes in response.content:
line = line_bytes.decode("utf-8").strip()
if not line:
continue
if line.startswith("data:"):
data_str = line[5:].strip()
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
if flag_delta_content_finished:
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage # 获取token用量
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
# 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason")
if finish_reason == "stop":
chunk_usage = chunk.get("usage", None)
if chunk_usage:
usage = chunk_usage
break
# 部分平台在文本输出结束前不会返回token用量此时需要再获取一次chunk
flag_delta_content_finished = True
except Exception as e:
logger.exception(f"解析流式输出错误: {str(e)}")
content = accumulated_content
reasoning_content = ""
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}],
"usage": usage,
}
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
else:
result = await response.json()
# 使用自定义处理器或默认处理
return (
response_handler(result)
if response_handler
else self._default_response_handler(result, user_id, request_type, endpoint)
)
except aiohttp.ClientResponseError as e:
# 处理aiohttp抛出的响应错误
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
logger.error(f"HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
try:
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
error_text = await e.response.text()
try:
error_json = json.loads(error_text)
if isinstance(error_json, list) and len(error_json) > 0:
for error_item in error_json:
if "error" in error_item and isinstance(error_item["error"], dict):
error_obj = error_item["error"]
logger.error(
f"服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
elif isinstance(error_json, dict) and "error" in error_json:
error_obj = error_json.get("error", {})
logger.error(
f"服务器错误详情: 代码={error_obj.get('code')}, "
f"状态={error_obj.get('status')}, "
f"消息={error_obj.get('message')}"
)
else:
logger.error(f"服务器错误响应: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err:
logger.warning(f"响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
except (AttributeError, TypeError, ValueError) as parse_err:
logger.warning(f"无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
else:
logger.critical(f"HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
# 安全地检查和记录请求详情
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: 状态码 {e.status}, {e.message}") from e
except Exception as e:
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
logger.error(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}")
# 安全地检查和记录请求详情
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}") from e
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
async def _transform_parameters(self, params: dict) -> dict:
"""
根据模型名称转换参数:
- 对于需要转换的OpenAI CoT系列模型例如 "o3-mini"),删除 'temperature' 参数,
并将 'max_tokens' 重命名为 'max_completion_tokens'
"""
# 复制一份参数,避免直接修改原始数据
new_params = dict(params)
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
# 删除 'temperature' 参数(如果存在)
new_params.pop("temperature", None)
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
if "max_tokens" in new_params:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体"""
# 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params)
if image_base64:
payload = {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
},
],
}
],
"max_tokens": global_config.max_response_length,
**params_copy,
}
else:
payload = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
**params_copy,
}
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload
def _default_response_handler(
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
) -> Tuple:
"""默认响应解析"""
if "choices" in result and result["choices"]:
message = result["choices"][0]["message"]
content = message.get("content", "")
content, reasoning = self._extract_reasoning(content)
reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
if not reasoning_content:
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
reasoning_content = reasoning
# 记录token使用情况
usage = result.get("usage", {})
if usage:
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", 0)
self._record_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
user_id=user_id,
request_type=request_type if request_type is not None else self.request_type,
endpoint=endpoint,
)
return content, reasoning_content
return "没有返回结果", ""
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取"""
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
if match:
reasoning = match.group(1).strip()
else:
reasoning = ""
return content, reasoning
async def _build_headers(self, no_key: bool = False) -> dict:
"""构建请求头"""
if no_key:
return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
else:
return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
# 防止小朋友们截图自己的key
async def generate_response(self, prompt: str) -> Tuple[str, str, str]:
"""根据输入的提示生成模型的异步响应"""
content, reasoning_content = await self._execute_request(endpoint="/chat/completions", prompt=prompt)
return content, reasoning_content, self.model_name
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
"""根据输入的提示和图片生成模型的异步响应"""
content, reasoning_content = await self._execute_request(
endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
)
return content, reasoning_content
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length,
**self.params,
**kwargs,
}
content, reasoning_content = await self._execute_request(
endpoint="/chat/completions", payload=data, prompt=prompt
)
return content, reasoning_content
async def get_embedding(self, text: str) -> Union[list, None]:
"""异步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
Returns:
list: embedding向量如果失败则返回None
"""
if len(text) < 1:
logger.debug("该消息没有长度不再发送获取embedding向量的请求")
return None
def embedding_handler(result):
"""处理响应"""
if "data" in result and len(result["data"]) > 0:
# 提取 token 使用信息
usage = result.get("usage", {})
if usage:
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", 0)
# 记录 token 使用情况
self._record_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
user_id="system", # 可以根据需要修改 user_id
request_type="embedding", # 请求类型为 embedding
endpoint="/embeddings", # API 端点
)
return result["data"][0].get("embedding", None)
return result["data"][0].get("embedding", None)
return None
embedding = await self._execute_request(
endpoint="/embeddings",
prompt=text,
payload={"model": self.model_name, "input": text, "encoding_format": "float"},
retry_policy={"max_retries": 2, "base_wait": 6},
response_handler=embedding_handler,
)
return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2 * 1024 * 1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format="GIF",
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get("duration", 100),
loop=img.info.get("loop", 0),
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == "PNG" and img.mode in ("RGBA", "LA"):
resized_img.save(output_buffer, format="PNG", optimize=True)
else:
resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")
return base64.b64encode(compressed_data).decode("utf-8")
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data