Adjust the corresponding call sites
@@ -1,34 +1,20 @@
-import re
-import copy
-import asyncio
-from datetime import datetime
-from typing import Tuple, Union, List, Dict, Optional, Callable, Any
-from src.common.logger import get_logger
-import base64
-from PIL import Image
-from enum import Enum
-import io
-from src.common.database.database import db  # make sure db is imported for create_tables
-from src.common.database.database_model import LLMUsage  # import the LLMUsage model
-from src.config.config import global_config, model_config
-from src.config.api_ada_configs import APIProvider, ModelInfo
-from rich.traceback import install
-
+from enum import Enum
+from rich.traceback import install
+from typing import Tuple, List, Dict, Optional, Callable, Any
+
+from src.common.logger import get_logger
+from src.config.config import model_config
+from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
 from .payload_content.message import MessageBuilder, Message
 from .payload_content.resp_format import RespFormat
 from .payload_content.tool_option import ToolOption, ToolCall
-from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry
-from .utils import compress_messages
-
-from .exceptions import (
-    NetworkConnectionError,
-    ReqAbortException,
-    RespNotOkException,
-    RespParseException,
-    PayLoadTooLargeError,
-    RequestAbortException,
-    PermissionDeniedException,
-)
+from .model_client.base_client import BaseClient, APIResponse, client_registry
+from .utils import compress_messages, llm_usage_recorder
+from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
 
 install(extra_lines=3)
 
@@ -57,45 +43,15 @@ class RequestType(Enum):
 class LLMRequest:
     """LLM request class"""
 
-    # Models whose requests need transformation, kept as a class variable to avoid duplication
-    MODELS_NEEDING_TRANSFORMATION = [
-        "o1",
-        "o1-2024-12-17",
-        "o1-mini",
-        "o1-mini-2024-09-12",
-        "o1-preview",
-        "o1-preview-2024-09-12",
-        "o1-pro",
-        "o1-pro-2025-03-19",
-        "o3",
-        "o3-2025-04-16",
-        "o3-mini",
-        "o3-mini-2025-01-31",
-        "o4-mini",
-        "o4-mini-2025-04-16",
-    ]
-
-    def __init__(self, task_name: str, request_type: str = "") -> None:
-        self.task_name = task_name
-        self.model_for_task = model_config.model_task_config.get_task(task_name)
+    def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
+        self.task_name = request_type
+        self.model_for_task = model_set
         self.request_type = request_type
         self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
         """Per-model usage record used for load balancing, stored as (total_tokens, penalty); the penalty lets the selector back off a model whose requests keep failing"""
-
-        self.pri_in = 0
-        self.pri_out = 0
-
-        self._init_database()
-
-    @staticmethod
-    def _init_database():
-        """Initialize the database table"""
-        try:
-            # Create the table with Peewee; safe=True means an existing table raises no error
-            db.create_tables([LLMUsage], safe=True)
-            # logger.debug("LLMUsage table initialized/ensured.")
-        except Exception as e:
-            logger.error(f"Failed to create the LLMUsage table: {str(e)}")
 
     async def generate_response_for_image(
         self,
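Note on the (total_tokens, penalty) bookkeeping the new constructor keeps: the selection logic lives in _select_model, which this diff only touches further down, so the following is just a minimal sketch of how a penalty-weighted pick could work. The helper name pick_model and the weighting constant are illustrative assumptions, not code from this repository.

    from typing import Dict, Tuple

    def pick_model(model_usage: Dict[str, Tuple[int, int]], penalty_weight: int = 1000) -> str:
        # Hypothetical sketch: lowest score wins, where the score is the tokens
        # already spent plus each recorded failure counted as penalty_weight tokens.
        return min(
            model_usage,
            key=lambda name: model_usage[name][0] + model_usage[name][1] * penalty_weight,
        )

    usage = {"model-a": (1200, 0), "model-b": (300, 2)}
    assert pick_model(usage) == "model-a"  # model-b's two failures outweigh its lower token count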
@@ -104,7 +60,7 @@ class LLMRequest:
         image_format: str,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-    ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
         """
         Generate a response for an image
         Args:
@@ -112,7 +68,7 @@ class LLMRequest:
             image_base64 (str): Base64-encoded image string
             image_format (str): image format (e.g. 'png', 'jpeg')
         Returns:
-            (Tuple[str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, tool call list
+            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool call list
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -141,25 +97,25 @@ class LLMRequest:
         content, extracted_reasoning = self._extract_reasoning(content)
         reasoning_content = extracted_reasoning
         if usage := response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=usage.prompt_tokens or 0,
-                completion_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens or 0,
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/chat/completions",
             )
-        return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
+        return content, (
+            reasoning_content,
+            model_info.name,
+            self._convert_tool_calls(tool_calls) if tool_calls else None,
+        )
 
     async def generate_response_for_voice(self):
         pass
 
     async def generate_response_async(
         self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
-    ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
         """
         Generate a response asynchronously
         Args:
@@ -167,7 +123,7 @@ class LLMRequest:
             temperature (float, optional): sampling temperature
             max_tokens (int, optional): maximum number of tokens
         Returns:
-            Tuple[str, str, Optional[List[Dict[str, Any]]]]: response content, reasoning content, and tool call list
+            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool call list
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -195,13 +151,9 @@ class LLMRequest:
         content, extracted_reasoning = self._extract_reasoning(content)
         reasoning_content = extracted_reasoning
         if usage := response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=usage.prompt_tokens or 0,
-                completion_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens or 0,
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/chat/completions",
@@ -209,10 +161,19 @@ class LLMRequest:
         if not content:
             raise RuntimeError("Failed to get LLM-generated content")
 
-        return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
+        return content, (
+            reasoning_content,
+            model_info.name,
+            self._convert_tool_calls(tool_calls) if tool_calls else None,
+        )
 
-    async def get_embedding(self, embedding_input: str) -> List[float]:
-        """Get an embedding vector"""
+    async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
+        """Get an embedding vector
+        Args:
+            embedding_input (str): the text to embed
+        Returns:
+            (Tuple[List[float], str]): (embedding vector, name of the model used)
+        """
         # No message body to build; the input text is used directly
         model_info, api_provider, client = self._select_model()
 
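Note: since every generation method now also reports which model actually served the request, call sites unpack a nested tuple rather than a flat one. A minimal sketch of the new calling convention, assuming an LLMRequest instance named llm (the variable names and prompts are illustrative):

    # Previously: content, reasoning, tool_calls = await llm.generate_response_async(prompt)
    # Now the second element is itself a (reasoning, model_name, tool_calls) tuple:
    content, (reasoning, model_name, tool_calls) = await llm.generate_response_async("hello")

    # get_embedding likewise returns the model name alongside the vector:
    embedding, model_name = await llm.get_embedding("some text to embed")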
@@ -227,14 +188,10 @@ class LLMRequest:
 
         embedding = response.embedding
 
-        if response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=response.usage.prompt_tokens or 0,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens or 0,
+        if usage := response.usage:
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/embeddings",
@@ -243,7 +200,7 @@ class LLMRequest:
         if not embedding:
             raise RuntimeError("Failed to get embedding")
 
-        return embedding
+        return embedding, model_info.name
 
     def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
         """
@@ -305,12 +262,13 @@ class LLMRequest:
             # Handle the exception
             total_tokens, penalty = self.model_usage[model_info.name]
             self.model_usage[model_info.name] = (total_tokens, penalty + 1)
 
             wait_interval, compressed_messages = self._default_exception_handler(
                 e,
                 self.task_name,
                 model_name=model_info.name,
                 remain_try=retry_remain,
-                messages=(message_list, compressed_messages is not None),
+                messages=(message_list, compressed_messages is not None) if message_list else None,
             )
 
             if wait_interval == -1:
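Note: the messages argument handed to _default_exception_handler is now None whenever no message list exists, e.g. for embedding requests that never build one; an empty list is falsy, so it is covered by the same guard. A tiny illustration of the changed expression (the variable values are hypothetical):

    message_list = []  # e.g. an embedding request built no messages
    compressed_messages = None
    messages_arg = (message_list, compressed_messages is not None) if message_list else None
    assert messages_arg is None  # the handler sees None instead of ([], False)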
@@ -321,9 +279,7 @@ class LLMRequest:
         finally:
             # Placed in finally to avoid an infinite retry loop
             retry_remain -= 1
-        logger.error(
-            f"Task '{self.task_name}': model '{model_info.name}' request failed after the maximum of {api_provider.max_retry} retries"
-        )
+        logger.error(f"Model '{model_info.name}' request failed after the maximum of {api_provider.max_retry} retries")
         raise RuntimeError("Request failed: maximum retry count reached")
 
     def _default_exception_handler(
@@ -481,65 +437,3 @@ class LLMRequest:
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
         reasoning = match[1].strip() if match else ""
         return content, reasoning
-
-    def _record_usage(
-        self,
-        model_name: str,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        user_id: str = "system",
-        request_type: str | None = None,
-        endpoint: str = "/chat/completions",
-    ):
-        """Record model usage to the database
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-            total_tokens: total number of tokens
-            user_id: user ID, defaults to "system"
-            request_type: request type
-            endpoint: API endpoint
-        """
-        # If request_type is None, fall back to the instance attribute
-        if request_type is None:
-            request_type = self.request_type
-
-        try:
-            # Create the record via the Peewee model
-            LLMUsage.create(
-                model_name=model_name,
-                user_id=user_id,
-                request_type=request_type,
-                endpoint=endpoint,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                cost=self._calculate_cost(prompt_tokens, completion_tokens),
-                status="success",
-                timestamp=datetime.now(),  # Peewee handles the DateTimeField
-            )
-            logger.debug(
-                f"Token usage - model: {model_name}, "
-                f"user: {user_id}, type: {request_type}, "
-                f"prompt: {prompt_tokens}, completion: {completion_tokens}, "
-                f"total: {total_tokens}"
-            )
-        except Exception as e:
-            logger.error(f"Failed to record token usage: {str(e)}")
-
-    def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
-        """Calculate the cost of an API call
-        Uses the model's pri_in and pri_out prices to cost the input and output
-
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-
-        Returns:
-            float: total cost (CNY)
-        """
-        # Cost computed from the model's pri_in and pri_out (price per million tokens)
-        input_cost = (prompt_tokens / 1000000) * self.pri_in
-        output_cost = (completion_tokens / 1000000) * self.pri_out
-        return round(input_cost + output_cost, 6)
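Note: with _record_usage and _calculate_cost gone, accounting now flows through llm_usage_recorder.record_usage_to_database, which takes the full ModelInfo and usage objects instead of pre-computed token counts. The recorder's internals are outside this diff; the sketch below merely restates the retired pricing arithmetic (price per million tokens, rounded to six decimal places) that such a recorder would presumably reproduce. The function name and parameters are illustrative assumptions.

    def estimate_cost(prompt_tokens: int, completion_tokens: int, price_in: float, price_out: float) -> float:
        # Same formula as the removed _calculate_cost: prices are per million tokens.
        input_cost = (prompt_tokens / 1_000_000) * price_in
        output_cost = (completion_tokens / 1_000_000) * price_out
        return round(input_cost + output_cost, 6)

    # Worked example: 1,500 prompt tokens at 2.0 per million plus 500 completion tokens at 8.0 per million
    assert estimate_cost(1500, 500, price_in=2.0, price_out=8.0) == 0.007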