Adjust the corresponding call sites
@@ -1,12 +0,0 @@
-import importlib
-from typing import Dict
-
-from src.config.config import model_config
-from src.common.logger import get_logger
-
-from .model_client import ModelRequestHandler, BaseClient
-
-logger = get_logger("模型管理器")
-
-class ModelManager:
-
@@ -1,92 +0,0 @@
-import importlib
-from typing import Dict
-
-from src.config.config import model_config
-from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
-from src.common.logger import get_logger
-
-from .model_client import ModelRequestHandler, BaseClient
-
-logger = get_logger("模型管理器")
-
-class ModelManager:
-    # TODO: add a read-write lock to prevent data races when the config is refreshed asynchronously
-
-    def __init__(
-        self,
-        config: ModuleConfig,
-    ):
-        self.config: ModuleConfig = config
-        """Configuration"""
-
-        self.api_client_map: Dict[str, BaseClient] = {}
-        """Map of API clients"""
-
-        self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
-        """Cache of ModelRequestHandlers, to avoid repeated construction"""
-
-        for provider_name, api_provider in self.config.api_providers.items():
-            # Initialize the API client
-            try:
-                # Dynamically load the implementation based on the config
-                client_module = importlib.import_module(
-                    f".model_client.{api_provider.client_type}_client", __package__
-                )
-                client_class = getattr(
-                    client_module, f"{api_provider.client_type.capitalize()}Client"
-                )
-                if not issubclass(client_class, BaseClient):
-                    raise TypeError(
-                        f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
-                    )
-                self.api_client_map[api_provider.name] = client_class(
-                    api_provider
-                )  # instantiate and store in api_client_map
-            except ImportError as e:
-                logger.error(f"Failed to import client module: {e}")
-                raise ImportError(
-                    f"Failed to import client module for '{provider_name}': {e}"
-                ) from e
-
-    def __getitem__(self, task_name: str) -> ModelRequestHandler:
-        """
-        Get the (wrapped) model client required by a task.
-        Uses a cache to avoid recreating ModelRequestHandler instances.
-        :param task_name: task name
-        :return: model client
-        """
-        if task_name not in self.config.task_model_arg_map:
-            raise KeyError(f"'{task_name}' not registered in ModelManager")
-
-        # Return the cached handler if one already exists
-        if task_name in self._request_handler_cache:
-            logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
-            return self._request_handler_cache[task_name]
-
-        # Create a new ModelRequestHandler and cache it
-        logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
-        handler = ModelRequestHandler(
-            task_name=task_name,
-            config=self.config,
-            api_client_map=self.api_client_map,
-        )
-        self._request_handler_cache[task_name] = handler
-        return handler
-
-    def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
-        """
-        Register the model usage configuration for a task.
-        :param task_name: task name
-        :param value: model usage configuration
-        """
-        self.config.task_model_arg_map[task_name] = value
-
-    def __contains__(self, task_name: str):
-        """
-        Check whether a task has been registered.
-        :param task_name: task name
-        :return: whether it is in the model list
-        """
-        return task_name in self.config.task_model_arg_map
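For reference, the deleted ModelManager was driven through its mapping protocol; a rough sketch of the old call pattern, not taken from this commit (the task name "chat" and the config objects are placeholders):

manager = ModelManager(config=module_config)   # module_config: a ModuleConfig built elsewhere
manager["chat"] = usage_arg_config             # __setitem__ registers a ModelUsageArgConfig
handler = manager["chat"]                      # __getitem__ returns a cached ModelRequestHandler
assert "chat" in manager                       # __contains__ checks registration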
@@ -1,169 +0,0 @@
-from datetime import datetime
-from enum import Enum
-from typing import Tuple
-
-from src.common.logger import get_logger
-from src.config.api_ada_configs import ModelInfo
-from src.common.database.database_model import LLMUsage
-
-logger = get_logger("模型使用统计")
-
-
-class ReqType(Enum):
-    """
-    Request type
-    """
-
-    CHAT = "chat"  # chat request
-    EMBEDDING = "embedding"  # embedding request
-
-
-class UsageCallStatus(Enum):
-    """
-    Task call status
-    """
-
-    PROCESSING = "processing"  # in progress
-    SUCCESS = "success"  # succeeded
-    FAILURE = "failure"  # failed
-    CANCELED = "canceled"  # canceled
-
-
-class ModelUsageStatistic:
-    """
-    Model usage statistics, backed by SQLite via Peewee
-    """
-
-    def __init__(self):
-        """
-        Initialize the statistics class.
-        Since the Peewee ORM is used, no database instance needs to be passed in.
-        """
-        # Make sure the table has been created
-        try:
-            from src.common.database.database import db
-
-            db.create_tables([LLMUsage], safe=True)
-        except Exception as e:
-            logger.error(f"创建LLMUsage表失败: {e}")
-
-    @staticmethod
-    def _calculate_cost(prompt_tokens: int, completion_tokens: int, model_info: ModelInfo) -> float:
-        """Calculate the cost of an API call.
-        Uses the model's input (pri_in) and output (pri_out) prices to compute the cost.
-
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-            model_info: model information
-
-        Returns:
-            float: total cost (CNY)
-        """
-        # Prices are per million tokens
-        input_cost = (prompt_tokens / 1000000) * model_info.price_in
-        output_cost = (completion_tokens / 1000000) * model_info.price_out
-        return round(input_cost + output_cost, 6)
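A worked example of the arithmetic above, with illustrative prices (not values from this commit) of 2.0 per million input tokens and 8.0 per million output tokens, for 1,500 prompt and 500 completion tokens:

input_cost = (1500 / 1000000) * 2.0   # 0.003
output_cost = (500 / 1000000) * 8.0   # 0.004
round(input_cost + output_cost, 6)    # 0.007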
-    def create_usage(
-        self,
-        model_name: str,
-        task_name: str = "N/A",
-        request_type: ReqType = ReqType.CHAT,
-        user_id: str = "system",
-        endpoint: str = "/chat/completions",
-    ) -> int | None:
-        """
-        Create a model usage record
-
-        Args:
-            model_name: model name
-            task_name: task name
-            request_type: request type, defaults to chat
-            user_id: user ID, defaults to "system"
-            endpoint: API endpoint
-
-        Returns:
-            int | None: the record ID, or None on failure
-        """
-        try:
-            usage_record = LLMUsage.create(
-                model_name=model_name,
-                user_id=user_id,
-                request_type=request_type.value,
-                endpoint=endpoint,
-                prompt_tokens=0,
-                completion_tokens=0,
-                total_tokens=0,
-                cost=0.0,
-                status=UsageCallStatus.PROCESSING.value,
-                timestamp=datetime.now(),
-            )
-
-            # logger.trace(
-            #     f"创建了一条模型使用情况记录 - 模型: {model_name}, "
-            #     f"子任务: {task_name}, 类型: {request_type.value}, "
-            #     f"用户: {user_id}, 记录ID: {usage_record.id}"
-            # )
-
-            return usage_record.id
-        except Exception as e:
-            logger.error(f"创建模型使用情况记录失败: {str(e)}")
-            return None
-
-    def update_usage(
-        self,
-        record_id: int | None,
-        model_info: ModelInfo,
-        usage_data: Tuple[int, int, int] | None = None,
-        stat: UsageCallStatus = UsageCallStatus.SUCCESS,
-        ext_msg: str | None = None,
-    ):
-        """
-        Update a model usage record
-
-        Args:
-            record_id: record ID
-            model_info: model information
-            usage_data: usage data as (prompt tokens, completion tokens, total tokens)
-            stat: task call status
-            ext_msg: extra message
-        """
-        if not record_id:
-            logger.error("更新模型使用情况失败: record_id不能为空")
-            return
-
-        if usage_data and len(usage_data) != 3:
-            logger.error("更新模型使用情况失败: usage_data的长度不正确,应该为3个元素")
-            return
-
-        # Unpack the usage data
-        prompt_tokens = usage_data[0] if usage_data else 0
-        completion_tokens = usage_data[1] if usage_data else 0
-        total_tokens = usage_data[2] if usage_data else 0
-
-        try:
-            # Update the record via Peewee
-            update_query = LLMUsage.update(
-                status=stat.value,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                cost=self._calculate_cost(prompt_tokens, completion_tokens, model_info) if usage_data else 0.0,
-            ).where(LLMUsage.id == record_id)  # type: ignore
-
-            updated_count = update_query.execute()
-
-            if updated_count == 0:
-                logger.warning(f"记录ID {record_id} 不存在,无法更新")
-                return
-
-            logger.debug(
-                f"Token使用情况 - 模型: {model_info.name}, "
-                f"记录ID: {record_id}, "
-                f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
-                f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
-                f"总计: {total_tokens}"
-            )
-        except Exception as e:
-            logger.error(f"记录token使用情况失败: {str(e)}")
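The intended lifecycle of this deleted class was create-then-update; a rough sketch, not taken from the commit (the model name, ModelInfo instance, and token counts are placeholders):

stats = ModelUsageStatistic()
record_id = stats.create_usage(model_name="some-model", task_name="chat")
# ... perform the request, collect token counts ...
stats.update_usage(record_id, model_info, usage_data=(1500, 500, 2000), stat=UsageCallStatus.SUCCESS)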
@@ -2,16 +2,19 @@ import base64
 import io
 
 from PIL import Image
+from datetime import datetime
 
 from src.common.logger import get_logger
+from src.common.database.database import db  # ensure db is imported for create_tables
+from src.common.database.database_model import LLMUsage
+from src.config.api_ada_configs import ModelInfo
 from .payload_content.message import Message, MessageBuilder
+from .model_client.base_client import UsageRecord
 
 logger = get_logger("消息压缩工具")
 
 
-def compress_messages(
-    messages: list[Message], img_target_size: int = 1 * 1024 * 1024
-) -> list[Message]:
+def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
     """
     Compress the images in a message list
     :param messages: message list
@@ -28,14 +31,10 @@ def compress_messages(
         try:
             image = Image.open(image_data)
 
-            if image.format and (
-                image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
-            ):
+            if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
                 # Static image: convert to JPEG format
                 reformated_image_data = io.BytesIO()
-                image.save(
-                    reformated_image_data, format="JPEG", quality=95, optimize=True
-                )
+                image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
                 image_data = reformated_image_data.getvalue()
 
             return image_data
@@ -43,9 +42,7 @@ def compress_messages(
             logger.error(f"图片转换格式失败: {str(e)}")
             return image_data
 
-    def rescale_image(
-        image_data: bytes, scale: float
-    ) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
+    def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
         """
         Rescale an image
         :param image_data: image data
@@ -86,9 +83,7 @@ def compress_messages(
             else:
                 # Static image: just resize and save
                 resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
-                resized_image.save(
-                    output_buffer, format="JPEG", quality=95, optimize=True
-                )
+                resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
 
             return output_buffer.getvalue(), original_size, new_size
 
@@ -99,9 +94,7 @@ def compress_messages(
             logger.error(traceback.format_exc())
             return image_data, None, None
 
-    def compress_base64_image(
-        base64_data: str, target_size: int = 1 * 1024 * 1024
-    ) -> str:
+    def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
         original_b64_data_size = len(base64_data)  # size of the original data
 
         image_data = base64.b64decode(base64_data)
@@ -111,9 +104,7 @@ def compress_messages(
         base64_data = base64.b64encode(image_data).decode("utf-8")
         if len(base64_data) <= target_size:
             # If the converted image is already below the target size, return it directly
-            logger.info(
-                f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB"
-            )
+            logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB")
             return base64_data
 
         # If it is still larger than the target size, downscale the image
@@ -139,9 +130,7 @@ def compress_messages(
                 # Image content: compress it
                 message_builder.add_image_content(
                     content_item[0],
-                    compress_base64_image(
-                        content_item[1], target_size=img_target_size
-                    ),
+                    compress_base64_image(content_item[1], target_size=img_target_size),
                 )
             else:
                 message_builder.add_text_content(content_item)
@@ -150,3 +139,48 @@ def compress_messages(
         compressed_messages.append(message)
 
     return compressed_messages
+
+
+class LLMUsageRecorder:
+    """
+    LLM usage recorder
+    """
+
+    def __init__(self):
+        try:
+            # Create the table via Peewee; safe=True means it is not an error if it already exists
+            db.create_tables([LLMUsage], safe=True)
+            # logger.debug("LLMUsage 表已初始化/确保存在。")
+        except Exception as e:
+            logger.error(f"创建 LLMUsage 表失败: {str(e)}")
+
+    def record_usage_to_database(
+        self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
+    ):
+        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
+        output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
+        total_cost = round(input_cost + output_cost, 6)
+        try:
+            # Create the record via the Peewee model
+            LLMUsage.create(
+                model_name=model_info.model_identifier,
+                user_id=user_id,
+                request_type=request_type,
+                endpoint=endpoint,
+                prompt_tokens=model_usage.prompt_tokens or 0,
+                completion_tokens=model_usage.completion_tokens or 0,
+                total_tokens=model_usage.total_tokens or 0,
+                cost=total_cost or 0.0,
+                status="success",
+                timestamp=datetime.now(),  # Peewee handles the DateTimeField
+            )
+            logger.debug(
+                f"Token使用情况 - 模型: {model_usage.model_name}, "
+                f"用户: {user_id}, 类型: {request_type}, "
+                f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
+                f"总计: {model_usage.total_tokens}"
+            )
+        except Exception as e:
+            logger.error(f"记录token使用情况失败: {str(e)}")
+
+llm_usage_recorder = LLMUsageRecorder()
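This singleton is what the request paths below switch to; the call shape, as used in the later hunks (usage being the UsageRecord taken from an API response):

if usage := response.usage:
    llm_usage_recorder.record_usage_to_database(
        model_info=model_info,
        model_usage=usage,
        user_id="system",
        request_type=self.request_type,
        endpoint="/chat/completions",
    )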
@@ -1,34 +1,20 @@
 import re
 import copy
 import asyncio
-from datetime import datetime
-from typing import Tuple, Union, List, Dict, Optional, Callable, Any
-from src.common.logger import get_logger
 import base64
-from PIL import Image
-from enum import Enum
 import io
-from src.common.database.database import db  # ensure db is imported for create_tables
-from src.common.database.database_model import LLMUsage  # import the LLMUsage model
-from src.config.config import global_config, model_config
-from src.config.api_ada_configs import APIProvider, ModelInfo
-from rich.traceback import install
-
+from enum import Enum
+from rich.traceback import install
+from typing import Tuple, List, Dict, Optional, Callable, Any
 
+from src.common.logger import get_logger
+from src.config.config import model_config
+from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
 from .payload_content.message import MessageBuilder, Message
 from .payload_content.resp_format import RespFormat
 from .payload_content.tool_option import ToolOption, ToolCall
-from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry
-from .utils import compress_messages
-
-from .exceptions import (
-    NetworkConnectionError,
-    ReqAbortException,
-    RespNotOkException,
-    RespParseException,
-    PayLoadTooLargeError,
-    RequestAbortException,
-    PermissionDeniedException,
-)
+from .model_client.base_client import BaseClient, APIResponse, client_registry
+from .utils import compress_messages, llm_usage_recorder
+from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
 
 install(extra_lines=3)
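The hunks below change LLMRequest to take a TaskConfig directly instead of resolving it from a task name; a rough migration sketch, not taken from this commit (the `utils` task entry is a placeholder):

# Before this commit
llm = LLMRequest(task_name="utils")

# After: the caller supplies the TaskConfig plus a request_type label
llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="utils")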
@@ -57,45 +43,15 @@ class RequestType(Enum):
 class LLMRequest:
     """LLM request class"""
 
-    # Models whose parameters need transformation, kept as a class variable to avoid duplication
-    MODELS_NEEDING_TRANSFORMATION = [
-        "o1",
-        "o1-2024-12-17",
-        "o1-mini",
-        "o1-mini-2024-09-12",
-        "o1-preview",
-        "o1-preview-2024-09-12",
-        "o1-pro",
-        "o1-pro-2025-03-19",
-        "o3",
-        "o3-2025-04-16",
-        "o3-mini",
-        "o3-mini-2025-01-31",
-        "o4-mini",
-        "o4-mini-2025-04-16",
-    ]
 
-    def __init__(self, task_name: str, request_type: str = "") -> None:
-        self.task_name = task_name
-        self.model_for_task = model_config.model_task_config.get_task(task_name)
+    def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
+        self.task_name = request_type
+        self.model_for_task = model_set
         self.request_type = request_type
         self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
         """Per-model usage record as (total_tokens, penalty), used for load balancing; the penalty lets a model be deprioritized when its requests underperform"""
 
-        self.pri_in = 0
-        self.pri_out = 0
 
-        self._init_database()
-
-    @staticmethod
-    def _init_database():
-        """Initialize the database tables"""
-        try:
-            # Create the table via Peewee; safe=True means it is not an error if it already exists
-            db.create_tables([LLMUsage], safe=True)
-            # logger.debug("LLMUsage 表已初始化/确保存在。")
-        except Exception as e:
-            logger.error(f"创建 LLMUsage 表失败: {str(e)}")
 
     async def generate_response_for_image(
         self,
@@ -104,7 +60,7 @@ class LLMRequest:
         image_format: str,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-    ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
         """
         Generate a response for an image
         Args:
@@ -112,7 +68,7 @@ class LLMRequest:
             image_base64 (str): the Base64-encoded image string
             image_format (str): the image format (e.g. 'png', 'jpeg')
         Returns:
-
+            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool calls
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -141,25 +97,25 @@ class LLMRequest:
             content, extracted_reasoning = self._extract_reasoning(content)
             reasoning_content = extracted_reasoning
         if usage := response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=usage.prompt_tokens or 0,
-                completion_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens or 0,
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/chat/completions",
             )
-        return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
+        return content, (
+            reasoning_content,
+            model_info.name,
+            self._convert_tool_calls(tool_calls) if tool_calls else None,
+        )
 
     async def generate_response_for_voice(self):
         pass
 
     async def generate_response_async(
         self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
-    ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
         """
         Generate a response asynchronously
         Args:
@@ -167,7 +123,7 @@ class LLMRequest:
             temperature (float, optional): temperature parameter
             max_tokens (int, optional): maximum number of tokens
         Returns:
-            Tuple[str, str, Optional[List[Dict[str, Any]]]]: response content, reasoning content, and tool calls
+            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool calls
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -195,13 +151,9 @@ class LLMRequest:
             content, extracted_reasoning = self._extract_reasoning(content)
             reasoning_content = extracted_reasoning
         if usage := response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=usage.prompt_tokens or 0,
-                completion_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens or 0,
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/chat/completions",
@@ -209,10 +161,19 @@ class LLMRequest:
         if not content:
             raise RuntimeError("获取LLM生成内容失败")
 
-        return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
+        return content, (
+            reasoning_content,
+            model_info.name,
+            self._convert_tool_calls(tool_calls) if tool_calls else None,
+        )
 
-    async def get_embedding(self, embedding_input: str) -> List[float]:
-        """Get an embedding vector"""
+    async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
+        """Get an embedding vector
+        Args:
+            embedding_input (str): the text to embed
+        Returns:
+            (Tuple[List[float], str]): (embedding vector, name of the model used)
+        """
         # No need to build a message payload; the input text is used directly
         model_info, api_provider, client = self._select_model()
 
@@ -227,14 +188,10 @@ class LLMRequest:
 
         embedding = response.embedding
 
-        if response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=response.usage.prompt_tokens or 0,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens or 0,
+        if usage := response.usage:
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/embeddings",
@@ -243,7 +200,7 @@ class LLMRequest:
         if not embedding:
             raise RuntimeError("获取embedding失败")
 
-        return embedding
+        return embedding, model_info.name
 
     def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
         """
@@ -305,12 +262,13 @@ class LLMRequest:
             # Handle the exception
             total_tokens, penalty = self.model_usage[model_info.name]
             self.model_usage[model_info.name] = (total_tokens, penalty + 1)
 
             wait_interval, compressed_messages = self._default_exception_handler(
                 e,
+                self.task_name,
                 model_name=model_info.name,
                 remain_try=retry_remain,
-                messages=(message_list, compressed_messages is not None),
+                messages=(message_list, compressed_messages is not None) if message_list else None,
             )
 
             if wait_interval == -1:
@@ -321,9 +279,7 @@ class LLMRequest:
         finally:
             # Decrement in finally to avoid an infinite loop
             retry_remain -= 1
-        logger.error(
-            f"任务 '{self.task_name}' 模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次"
-        )
+        logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
         raise RuntimeError("请求失败,已达到最大重试次数")
 
     def _default_exception_handler(
@@ -481,65 +437,3 @@ class LLMRequest:
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
         reasoning = match[1].strip() if match else ""
         return content, reasoning
-
-    def _record_usage(
-        self,
-        model_name: str,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        user_id: str = "system",
-        request_type: str | None = None,
-        endpoint: str = "/chat/completions",
-    ):
-        """Record model usage to the database
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-            total_tokens: total number of tokens
-            user_id: user ID, defaults to "system"
-            request_type: request type
-            endpoint: API endpoint
-        """
-        # If request_type is None, fall back to the instance attribute
-        if request_type is None:
-            request_type = self.request_type
-
-        try:
-            # Create the record via the Peewee model
-            LLMUsage.create(
-                model_name=model_name,
-                user_id=user_id,
-                request_type=request_type,
-                endpoint=endpoint,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                cost=self._calculate_cost(prompt_tokens, completion_tokens),
-                status="success",
-                timestamp=datetime.now(),  # Peewee handles the DateTimeField
-            )
-            logger.debug(
-                f"Token使用情况 - 模型: {model_name}, "
-                f"用户: {user_id}, 类型: {request_type}, "
-                f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
-                f"总计: {total_tokens}"
-            )
-        except Exception as e:
-            logger.error(f"记录token使用情况失败: {str(e)}")
-
-    def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
-        """Calculate the cost of an API call
-        Uses the model's input (pri_in) and output (pri_out) prices to compute the cost
-
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-
-        Returns:
-            float: total cost (CNY)
-        """
-        # Compute the cost from the model's pri_in and pri_out
-        input_cost = (prompt_tokens / 1000000) * self.pri_in
-        output_cost = (completion_tokens / 1000000) * self.pri_out
-        return round(input_cost + output_cost, 6)
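The context lines at the top of this hunk are the tail of _extract_reasoning, which splits a <think>...</think> block out of the response content; illustratively (an editorial example, assuming `match` captured the block's inner text):

content = "<think>first add 2 and 2</think>The answer is 4."
# after _extract_reasoning(content):
#   content   -> "The answer is 4."
#   reasoning -> "first add 2 and 2"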