529 lines
24 KiB
Python
529 lines
24 KiB
Python
import asyncio
|
||
import json
|
||
import re
|
||
from datetime import datetime
|
||
from typing import Tuple, Union
|
||
|
||
import aiohttp
|
||
from loguru import logger
|
||
from nonebot import get_driver
|
||
import base64
|
||
from PIL import Image
|
||
import io
|
||
from ...common.database import db
|
||
from ..chat.config import global_config
|
||
|
||
driver = get_driver()
|
||
config = driver.config
|
||
|
||
|
||
class LLM_request:
|
||
def __init__(self, model, **kwargs):
|
||
# 将大写的配置键转换为小写并从config中获取实际值
|
||
try:
|
||
self.api_key = getattr(config, model["key"])
|
||
self.base_url = getattr(config, model["base_url"])
|
||
except AttributeError as e:
|
||
logger.error(f"原始 model dict 信息:{model}")
|
||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
||
self.model_name = model["name"]
|
||
self.params = kwargs
|
||
|
||
self.pri_in = model.get("pri_in", 0)
|
||
self.pri_out = model.get("pri_out", 0)
|
||
|
||
# 获取数据库实例
|
||
self._init_database()
|
||
|
||
def _init_database(self):
|
||
"""初始化数据库集合"""
|
||
try:
|
||
# 创建llm_usage集合的索引
|
||
db.llm_usage.create_index([("timestamp", 1)])
|
||
db.llm_usage.create_index([("model_name", 1)])
|
||
db.llm_usage.create_index([("user_id", 1)])
|
||
db.llm_usage.create_index([("request_type", 1)])
|
||
except Exception:
|
||
logger.error("创建数据库索引失败")
|
||
|
||
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
|
||
user_id: str = "system", request_type: str = "chat",
|
||
endpoint: str = "/chat/completions"):
|
||
"""记录模型使用情况到数据库
|
||
Args:
|
||
prompt_tokens: 输入token数
|
||
completion_tokens: 输出token数
|
||
total_tokens: 总token数
|
||
user_id: 用户ID,默认为system
|
||
request_type: 请求类型(chat/embedding/image等)
|
||
endpoint: API端点
|
||
"""
|
||
try:
|
||
usage_data = {
|
||
"model_name": self.model_name,
|
||
"user_id": user_id,
|
||
"request_type": request_type,
|
||
"endpoint": endpoint,
|
||
"prompt_tokens": prompt_tokens,
|
||
"completion_tokens": completion_tokens,
|
||
"total_tokens": total_tokens,
|
||
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
|
||
"status": "success",
|
||
"timestamp": datetime.now()
|
||
}
|
||
db.llm_usage.insert_one(usage_data)
|
||
logger.info(
|
||
f"Token使用情况 - 模型: {self.model_name}, "
|
||
f"用户: {user_id}, 类型: {request_type}, "
|
||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||
f"总计: {total_tokens}"
|
||
)
|
||
except Exception:
|
||
logger.error("记录token使用情况失败")
|
||
|
||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||
"""计算API调用成本
|
||
使用模型的pri_in和pri_out价格计算输入和输出的成本
|
||
|
||
Args:
|
||
prompt_tokens: 输入token数量
|
||
completion_tokens: 输出token数量
|
||
|
||
Returns:
|
||
float: 总成本(元)
|
||
"""
|
||
# 使用模型的pri_in和pri_out计算成本
|
||
input_cost = (prompt_tokens / 1000000) * self.pri_in
|
||
output_cost = (completion_tokens / 1000000) * self.pri_out
|
||
return round(input_cost + output_cost, 6)
|
||
|
||
async def _execute_request(
|
||
self,
|
||
endpoint: str,
|
||
prompt: str = None,
|
||
image_base64: str = None,
|
||
image_format: str = None,
|
||
payload: dict = None,
|
||
retry_policy: dict = None,
|
||
response_handler: callable = None,
|
||
user_id: str = "system",
|
||
request_type: str = "chat"
|
||
):
|
||
"""统一请求执行入口
|
||
Args:
|
||
endpoint: API端点路径 (如 "chat/completions")
|
||
prompt: prompt文本
|
||
image_base64: 图片的base64编码
|
||
image_format: 图片格式
|
||
payload: 请求体数据
|
||
retry_policy: 自定义重试策略
|
||
response_handler: 自定义响应处理器
|
||
user_id: 用户ID
|
||
request_type: 请求类型
|
||
"""
|
||
# 合并重试策略
|
||
default_retry = {
|
||
"max_retries": 3, "base_wait": 15,
|
||
"retry_codes": [429, 413, 500, 503],
|
||
"abort_codes": [400, 401, 402, 403]}
|
||
policy = {**default_retry, **(retry_policy or {})}
|
||
|
||
# 常见Error Code Mapping
|
||
error_code_mapping = {
|
||
400: "参数不正确",
|
||
401: "API key 错误,认证失败",
|
||
402: "账号余额不足",
|
||
403: "需要实名,或余额不足",
|
||
404: "Not Found",
|
||
429: "请求过于频繁,请稍后再试",
|
||
500: "服务器内部故障",
|
||
503: "服务器负载过高"
|
||
}
|
||
|
||
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||
# 判断是否为流式
|
||
stream_mode = self.params.get("stream", False)
|
||
if self.params.get("stream", False) is True:
|
||
logger.debug(f"进入流式输出模式,发送请求到URL: {api_url}")
|
||
else:
|
||
logger.debug(f"发送请求到URL: {api_url}")
|
||
logger.info(f"使用模型: {self.model_name}")
|
||
|
||
# 构建请求体
|
||
if image_base64:
|
||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||
elif payload is None:
|
||
payload = await self._build_payload(prompt)
|
||
|
||
for retry in range(policy["max_retries"]):
|
||
try:
|
||
# 使用上下文管理器处理会话
|
||
headers = await self._build_headers()
|
||
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||
if stream_mode:
|
||
headers["Accept"] = "text/event-stream"
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(api_url, headers=headers, json=payload) as response:
|
||
# 处理需要重试的状态码
|
||
if response.status in policy["retry_codes"]:
|
||
wait_time = policy["base_wait"] * (2 ** retry)
|
||
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试")
|
||
if response.status == 413:
|
||
logger.warning("请求体过大,尝试压缩...")
|
||
image_base64 = compress_base64_image_by_scale(image_base64)
|
||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||
elif response.status in [500, 503]:
|
||
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
||
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||
else:
|
||
logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
|
||
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
elif response.status in policy["abort_codes"]:
|
||
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
||
if response.status == 403:
|
||
#只针对硅基流动的V3和R1进行降级处理
|
||
if self.model_name.startswith(
|
||
"Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
|
||
old_model_name = self.model_name
|
||
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
||
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
||
|
||
# 对全局配置进行更新
|
||
if global_config.llm_normal.get('name') == old_model_name:
|
||
global_config.llm_normal['name'] = self.model_name
|
||
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
|
||
|
||
if global_config.llm_reasoning.get('name') == old_model_name:
|
||
global_config.llm_reasoning['name'] = self.model_name
|
||
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
|
||
|
||
# 更新payload中的模型名
|
||
if payload and 'model' in payload:
|
||
payload['model'] = self.model_name
|
||
|
||
# 重新尝试请求
|
||
retry -= 1 # 不计入重试次数
|
||
continue
|
||
|
||
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
|
||
|
||
response.raise_for_status()
|
||
|
||
# 将流式输出转化为非流式输出
|
||
if stream_mode:
|
||
flag_delta_content_finished = False
|
||
accumulated_content = ""
|
||
async for line_bytes in response.content:
|
||
line = line_bytes.decode("utf-8").strip()
|
||
if not line:
|
||
continue
|
||
if line.startswith("data:"):
|
||
data_str = line[5:].strip()
|
||
if data_str == "[DONE]":
|
||
break
|
||
try:
|
||
chunk = json.loads(data_str)
|
||
if flag_delta_content_finished:
|
||
usage = chunk.get("usage", None) # 获取tokn用量
|
||
else:
|
||
delta = chunk["choices"][0]["delta"]
|
||
delta_content = delta.get("content")
|
||
if delta_content is None:
|
||
delta_content = ""
|
||
accumulated_content += delta_content
|
||
# 检测流式输出文本是否结束
|
||
finish_reason = chunk["choices"][0].get("finish_reason")
|
||
if finish_reason == "stop":
|
||
usage = chunk.get("usage", None)
|
||
if usage:
|
||
break
|
||
# 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
|
||
flag_delta_content_finished = True
|
||
|
||
except Exception:
|
||
logger.exception("解析流式输出错误")
|
||
content = accumulated_content
|
||
reasoning_content = ""
|
||
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||
if think_match:
|
||
reasoning_content = think_match.group(1).strip()
|
||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||
# 构造一个伪result以便调用自定义响应处理器或默认处理器
|
||
result = {
|
||
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}], "usage": usage}
|
||
return response_handler(result) if response_handler else self._default_response_handler(
|
||
result, user_id, request_type, endpoint)
|
||
else:
|
||
result = await response.json()
|
||
# 使用自定义处理器或默认处理
|
||
return response_handler(result) if response_handler else self._default_response_handler(
|
||
result, user_id, request_type, endpoint)
|
||
|
||
except Exception as e:
|
||
if retry < policy["max_retries"] - 1:
|
||
wait_time = policy["base_wait"] * (2 ** retry)
|
||
logger.error(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||
await asyncio.sleep(wait_time)
|
||
else:
|
||
logger.critical(f"请求失败: {str(e)}")
|
||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}")
|
||
raise RuntimeError(f"API请求失败: {str(e)}")
|
||
|
||
logger.error("达到最大重试次数,请求仍然失败")
|
||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||
|
||
async def _transform_parameters(self, params: dict) -> dict:
|
||
"""
|
||
根据模型名称转换参数:
|
||
- 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temprature' 参数,
|
||
并将 'max_tokens' 重命名为 'max_completion_tokens'
|
||
"""
|
||
# 复制一份参数,避免直接修改原始数据
|
||
new_params = dict(params)
|
||
# 定义需要转换的模型列表
|
||
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
|
||
"o3-mini-2025-01-31", "o1-mini-2024-09-12"]
|
||
if self.model_name.lower() in models_needing_transformation:
|
||
# 删除 'temprature' 参数(如果存在)
|
||
new_params.pop("temperature", None)
|
||
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
|
||
if "max_tokens" in new_params:
|
||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||
return new_params
|
||
|
||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||
"""构建请求体"""
|
||
# 复制一份参数,避免直接修改 self.params
|
||
params_copy = await self._transform_parameters(self.params)
|
||
if image_base64:
|
||
payload = {
|
||
"model": self.model_name,
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{"type": "text", "text": prompt},
|
||
{"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}}
|
||
]
|
||
}
|
||
],
|
||
"max_tokens": global_config.max_response_length,
|
||
**params_copy
|
||
}
|
||
else:
|
||
payload = {
|
||
"model": self.model_name,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"max_tokens": global_config.max_response_length,
|
||
**params_copy
|
||
}
|
||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
|
||
"o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
|
||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||
return payload
|
||
|
||
def _default_response_handler(self, result: dict, user_id: str = "system",
|
||
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
|
||
"""默认响应解析"""
|
||
if "choices" in result and result["choices"]:
|
||
message = result["choices"][0]["message"]
|
||
content = message.get("content", "")
|
||
content, reasoning = self._extract_reasoning(content)
|
||
reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
|
||
if not reasoning_content:
|
||
reasoning_content = message.get("reasoning_content", "")
|
||
if not reasoning_content:
|
||
reasoning_content = reasoning
|
||
|
||
# 记录token使用情况
|
||
usage = result.get("usage", {})
|
||
if usage:
|
||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||
completion_tokens = usage.get("completion_tokens", 0)
|
||
total_tokens = usage.get("total_tokens", 0)
|
||
self._record_usage(
|
||
prompt_tokens=prompt_tokens,
|
||
completion_tokens=completion_tokens,
|
||
total_tokens=total_tokens,
|
||
user_id=user_id,
|
||
request_type=request_type,
|
||
endpoint=endpoint
|
||
)
|
||
|
||
return content, reasoning_content
|
||
|
||
return "没有返回结果", ""
|
||
|
||
def _extract_reasoning(self, content: str) -> tuple[str, str]:
|
||
"""CoT思维链提取"""
|
||
match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
|
||
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
|
||
if match:
|
||
reasoning = match.group(1).strip()
|
||
else:
|
||
reasoning = ""
|
||
return content, reasoning
|
||
|
||
async def _build_headers(self, no_key: bool = False) -> dict:
|
||
"""构建请求头"""
|
||
if no_key:
|
||
return {
|
||
"Authorization": "Bearer **********",
|
||
"Content-Type": "application/json"
|
||
}
|
||
else:
|
||
return {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
# 防止小朋友们截图自己的key
|
||
|
||
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||
"""根据输入的提示生成模型的异步响应"""
|
||
|
||
content, reasoning_content = await self._execute_request(
|
||
endpoint="/chat/completions",
|
||
prompt=prompt
|
||
)
|
||
return content, reasoning_content
|
||
|
||
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
|
||
"""根据输入的提示和图片生成模型的异步响应"""
|
||
|
||
content, reasoning_content = await self._execute_request(
|
||
endpoint="/chat/completions",
|
||
prompt=prompt,
|
||
image_base64=image_base64,
|
||
image_format=image_format
|
||
)
|
||
return content, reasoning_content
|
||
|
||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]:
|
||
"""异步方式根据输入的提示生成模型的响应"""
|
||
# 构建请求体
|
||
data = {
|
||
"model": self.model_name,
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"max_tokens": global_config.max_response_length,
|
||
**self.params
|
||
}
|
||
|
||
content, reasoning_content = await self._execute_request(
|
||
endpoint="/chat/completions",
|
||
payload=data,
|
||
prompt=prompt
|
||
)
|
||
return content, reasoning_content
|
||
|
||
async def get_embedding(self, text: str) -> Union[list, None]:
|
||
"""异步方法:获取文本的embedding向量
|
||
|
||
Args:
|
||
text: 需要获取embedding的文本
|
||
|
||
Returns:
|
||
list: embedding向量,如果失败则返回None
|
||
"""
|
||
|
||
def embedding_handler(result):
|
||
"""处理响应"""
|
||
if "data" in result and len(result["data"]) > 0:
|
||
return result["data"][0].get("embedding", None)
|
||
return None
|
||
|
||
embedding = await self._execute_request(
|
||
endpoint="/embeddings",
|
||
prompt=text,
|
||
payload={
|
||
"model": self.model_name,
|
||
"input": text,
|
||
"encoding_format": "float"
|
||
},
|
||
retry_policy={
|
||
"max_retries": 2,
|
||
"base_wait": 6
|
||
},
|
||
response_handler=embedding_handler
|
||
)
|
||
return embedding
|
||
|
||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||
"""压缩base64格式的图片到指定大小
|
||
Args:
|
||
base64_data: base64编码的图片数据
|
||
target_size: 目标文件大小(字节),默认0.8MB
|
||
Returns:
|
||
str: 压缩后的base64图片数据
|
||
"""
|
||
try:
|
||
# 将base64转换为字节数据
|
||
image_data = base64.b64decode(base64_data)
|
||
|
||
# 如果已经小于目标大小,直接返回原图
|
||
if len(image_data) <= 2*1024*1024:
|
||
return base64_data
|
||
|
||
# 将字节数据转换为图片对象
|
||
img = Image.open(io.BytesIO(image_data))
|
||
|
||
# 获取原始尺寸
|
||
original_width, original_height = img.size
|
||
|
||
# 计算缩放比例
|
||
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
|
||
|
||
# 计算新的尺寸
|
||
new_width = int(original_width * scale)
|
||
new_height = int(original_height * scale)
|
||
|
||
# 创建内存缓冲区
|
||
output_buffer = io.BytesIO()
|
||
|
||
# 如果是GIF,处理所有帧
|
||
if getattr(img, "is_animated", False):
|
||
frames = []
|
||
for frame_idx in range(img.n_frames):
|
||
img.seek(frame_idx)
|
||
new_frame = img.copy()
|
||
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
|
||
frames.append(new_frame)
|
||
|
||
# 保存到缓冲区
|
||
frames[0].save(
|
||
output_buffer,
|
||
format='GIF',
|
||
save_all=True,
|
||
append_images=frames[1:],
|
||
optimize=True,
|
||
duration=img.info.get('duration', 100),
|
||
loop=img.info.get('loop', 0)
|
||
)
|
||
else:
|
||
# 处理静态图片
|
||
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||
|
||
# 保存到缓冲区,保持原始格式
|
||
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||
else:
|
||
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||
|
||
# 获取压缩后的数据并转换为base64
|
||
compressed_data = output_buffer.getvalue()
|
||
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
|
||
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
|
||
|
||
return base64.b64encode(compressed_data).decode('utf-8')
|
||
|
||
except Exception as e:
|
||
logger.error(f"压缩图片失败: {str(e)}")
|
||
import traceback
|
||
logger.error(traceback.format_exc())
|
||
return base64_data
|
||
|