调整对应的调用
This commit is contained in:
@@ -2,16 +2,19 @@ import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
from .model_client.base_client import UsageRecord
|
||||
|
||||
logger = get_logger("消息压缩工具")
|
||||
|
||||
|
||||
def compress_messages(
|
||||
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
|
||||
) -> list[Message]:
|
||||
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
|
||||
"""
|
||||
压缩消息列表中的图片
|
||||
:param messages: 消息列表
|
||||
@@ -28,14 +31,10 @@ def compress_messages(
|
||||
try:
|
||||
image = Image.open(image_data)
|
||||
|
||||
if image.format and (
|
||||
image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
|
||||
):
|
||||
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
|
||||
# 静态图像,转换为JPEG格式
|
||||
reformated_image_data = io.BytesIO()
|
||||
image.save(
|
||||
reformated_image_data, format="JPEG", quality=95, optimize=True
|
||||
)
|
||||
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
|
||||
image_data = reformated_image_data.getvalue()
|
||||
|
||||
return image_data
|
||||
@@ -43,9 +42,7 @@ def compress_messages(
|
||||
logger.error(f"图片转换格式失败: {str(e)}")
|
||||
return image_data
|
||||
|
||||
def rescale_image(
|
||||
image_data: bytes, scale: float
|
||||
) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
|
||||
def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
|
||||
"""
|
||||
缩放图片
|
||||
:param image_data: 图片数据
|
||||
@@ -86,9 +83,7 @@ def compress_messages(
|
||||
else:
|
||||
# 静态图片,直接缩放保存
|
||||
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
resized_image.save(
|
||||
output_buffer, format="JPEG", quality=95, optimize=True
|
||||
)
|
||||
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
|
||||
|
||||
return output_buffer.getvalue(), original_size, new_size
|
||||
|
||||
@@ -99,9 +94,7 @@ def compress_messages(
|
||||
logger.error(traceback.format_exc())
|
||||
return image_data, None, None
|
||||
|
||||
def compress_base64_image(
|
||||
base64_data: str, target_size: int = 1 * 1024 * 1024
|
||||
) -> str:
|
||||
def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
|
||||
original_b64_data_size = len(base64_data) # 计算原始数据大小
|
||||
|
||||
image_data = base64.b64decode(base64_data)
|
||||
@@ -111,9 +104,7 @@ def compress_messages(
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
if len(base64_data) <= target_size:
|
||||
# 如果转换后小于目标大小,直接返回
|
||||
logger.info(
|
||||
f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB"
|
||||
)
|
||||
logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB")
|
||||
return base64_data
|
||||
|
||||
# 如果转换后仍然大于目标大小,进行尺寸压缩
|
||||
@@ -139,9 +130,7 @@ def compress_messages(
|
||||
# 图片,进行压缩
|
||||
message_builder.add_image_content(
|
||||
content_item[0],
|
||||
compress_base64_image(
|
||||
content_item[1], target_size=img_target_size
|
||||
),
|
||||
compress_base64_image(content_item[1], target_size=img_target_size),
|
||||
)
|
||||
else:
|
||||
message_builder.add_text_content(content_item)
|
||||
@@ -150,3 +139,48 @@ def compress_messages(
|
||||
compressed_messages.append(message)
|
||||
|
||||
return compressed_messages
|
||||
|
||||
|
||||
class LLMUsageRecorder:
|
||||
"""
|
||||
LLM使用情况记录器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
# logger.debug("LLMUsage 表已初始化/确保存在。")
|
||||
except Exception as e:
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def record_usage_to_database(
|
||||
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
|
||||
):
|
||||
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||
total_cost = round(input_cost + output_cost, 6)
|
||||
try:
|
||||
# 使用 Peewee 模型创建记录
|
||||
LLMUsage.create(
|
||||
model_name=model_info.model_identifier,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
endpoint=endpoint,
|
||||
prompt_tokens=model_usage.prompt_tokens or 0,
|
||||
completion_tokens=model_usage.completion_tokens or 0,
|
||||
total_tokens=model_usage.total_tokens or 0,
|
||||
cost=total_cost or 0.0,
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
)
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||
f"用户: {user_id}, 类型: {request_type}, "
|
||||
f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
|
||||
f"总计: {model_usage.total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
Reference in New Issue
Block a user