Adjust the corresponding call sites

UnCLAS-Prommer
2025-07-30 17:07:55 +08:00
parent 3c40ceda4c
commit 6c0edd0ad7
40 changed files with 580 additions and 1236 deletions


@@ -1,34 +1,20 @@
 import re
-import copy
-import asyncio
-from datetime import datetime
-from typing import Tuple, Union, List, Dict, Optional, Callable, Any
-from src.common.logger import get_logger
-import base64
-from PIL import Image
-from enum import Enum
-import io
-from src.common.database.database import db  # make sure db is imported for create_tables
-from src.common.database.database_model import LLMUsage  # import the LLMUsage model
-from src.config.config import global_config, model_config
-from src.config.api_ada_configs import APIProvider, ModelInfo
-from rich.traceback import install
+from enum import Enum
+from rich.traceback import install
+from typing import Tuple, List, Dict, Optional, Callable, Any
+from src.common.logger import get_logger
+from src.config.config import model_config
+from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
 from .payload_content.message import MessageBuilder, Message
 from .payload_content.resp_format import RespFormat
 from .payload_content.tool_option import ToolOption, ToolCall
-from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry
-from .utils import compress_messages
-from .exceptions import (
-    NetworkConnectionError,
-    ReqAbortException,
-    RespNotOkException,
-    RespParseException,
-    PayLoadTooLargeError,
-    RequestAbortException,
-    PermissionDeniedException,
-)
+from .model_client.base_client import BaseClient, APIResponse, client_registry
+from .utils import compress_messages, llm_usage_recorder
+from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
 
 install(extra_lines=3)
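
The import hunk above drops PayLoadTooLargeError, RequestAbortException, and PermissionDeniedException from this module, so call sites that caught those types need adjusting along with it. A minimal sketch of a caller narrowed to the four surviving exception types; safe_request and client.get_response are hypothetical names used only for illustration, and the handler bodies are placeholders, not this project's retry logic:

    from src.llm_models.exceptions import (  # assumed absolute path to this package's exceptions
        NetworkConnectionError,
        ReqAbortException,
        RespNotOkException,
        RespParseException,
    )

    async def safe_request(client, payload):
        try:
            return await client.get_response(payload)
        except (NetworkConnectionError, ReqAbortException):
            return None  # transport-level failure: the caller may retry
        except (RespNotOkException, RespParseException) as e:
            raise RuntimeError(f"LLM API error: {e}") from e
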
@@ -57,45 +43,15 @@ class RequestType(Enum):
 class LLMRequest:
     """LLM request class"""
 
-    # List of models that need parameter transformation, kept as a class variable to avoid duplication
-    MODELS_NEEDING_TRANSFORMATION = [
-        "o1",
-        "o1-2024-12-17",
-        "o1-mini",
-        "o1-mini-2024-09-12",
-        "o1-preview",
-        "o1-preview-2024-09-12",
-        "o1-pro",
-        "o1-pro-2025-03-19",
-        "o3",
-        "o3-2025-04-16",
-        "o3-mini",
-        "o3-mini-2025-01-31",
-        "o4-mini",
-        "o4-mini-2025-04-16",
-    ]
-
-    def __init__(self, task_name: str, request_type: str = "") -> None:
-        self.task_name = task_name
-        self.model_for_task = model_config.model_task_config.get_task(task_name)
+    def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
+        self.task_name = request_type
+        self.model_for_task = model_set
         self.request_type = request_type
         self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
         """Per-model usage record used for load balancing, stored as (total_tokens, penalty); the penalty value lets a model be deprioritized when its requests underperform"""
-        self.pri_in = 0
-        self.pri_out = 0
-        self._init_database()
-
-    @staticmethod
-    def _init_database():
-        """Initialize the database tables"""
-        try:
-            # Create the tables with Peewee; safe=True means no error is raised if a table already exists
-            db.create_tables([LLMUsage], safe=True)
-            # logger.debug("LLMUsage table initialized / ensured to exist.")
-        except Exception as e:
-            logger.error(f"Failed to create LLMUsage table: {str(e)}")
 
     async def generate_response_for_image(
         self,
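
The constructor change above is the heart of the commit title: callers now resolve the TaskConfig themselves instead of passing a task name. A sketch of the call-site migration; the module path and the task name "replyer" are assumptions for illustration, while model_task_config.get_task() is taken from the removed __init__ body:

    from src.config.config import model_config
    from src.llm_models.utils_model import LLMRequest  # assumed module path

    # Before: the task name was resolved inside __init__.
    # llm = LLMRequest(task_name="replyer")

    # After: the caller looks up the TaskConfig and passes it in.
    task_config = model_config.model_task_config.get_task("replyer")
    llm = LLMRequest(model_set=task_config, request_type="replyer")
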
@@ -104,7 +60,7 @@ class LLMRequest:
         image_format: str,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-    ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
         """
         Generate a response for an image
 
         Args:
@@ -112,7 +68,7 @@ class LLMRequest:
             image_base64 (str): Base64-encoded string of the image
             image_format (str): image format (e.g. 'png', 'jpeg')
 
         Returns:
-            (Tuple[str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, tool-call list
+            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool-call list
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -141,25 +97,25 @@ class LLMRequest:
             content, extracted_reasoning = self._extract_reasoning(content)
             reasoning_content = extracted_reasoning
 
         if usage := response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=usage.prompt_tokens or 0,
-                completion_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens or 0,
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/chat/completions",
             )
 
-        return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
+        return content, (
+            reasoning_content,
+            model_info.name,
+            self._convert_tool_calls(tool_calls) if tool_calls else None,
+        )
 
     async def generate_response_for_voice(self):
         pass
 
     async def generate_response_async(
         self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
-    ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
+    ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
         """
         Generate a response asynchronously
 
         Args:
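
With the return statement above, generate_response_for_image now yields a nested tuple instead of a flat three-tuple. A caller-side sketch; llm and image_b64 are assumed to exist, and the prompt keyword is inferred from the surrounding signature rather than confirmed by this hunk:

    async def describe_image(llm, image_b64: str) -> str:
        # The second element is now a (reasoning, model_name, tool_calls) tuple.
        content, (reasoning, model_name, tool_calls) = await llm.generate_response_for_image(
            prompt="Describe this image",  # parameter name assumed
            image_base64=image_b64,
            image_format="png",
        )
        return content
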
@@ -167,7 +123,7 @@ class LLMRequest:
             temperature (float, optional): temperature parameter
             max_tokens (int, optional): maximum number of tokens
 
         Returns:
-            Tuple[str, str, Optional[List[Dict[str, Any]]]]: response content, reasoning content, tool-call list
+            (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool-call list
         """
         # Build the request payload
         message_builder = MessageBuilder()
@@ -195,13 +151,9 @@ class LLMRequest:
             content, extracted_reasoning = self._extract_reasoning(content)
             reasoning_content = extracted_reasoning
 
         if usage := response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=usage.prompt_tokens or 0,
-                completion_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens or 0,
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/chat/completions",
@@ -209,10 +161,19 @@ class LLMRequest:
         if not content:
             raise RuntimeError("Failed to get LLM-generated content")
 
-        return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
+        return content, (
+            reasoning_content,
+            model_info.name,
+            self._convert_tool_calls(tool_calls) if tool_calls else None,
+        )
 
-    async def get_embedding(self, embedding_input: str) -> List[float]:
-        """Get an embedding vector"""
+    async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
+        """Get an embedding vector
+
+        Args:
+            embedding_input (str): the text to embed
+        Returns:
+            (Tuple[List[float], str]): (embedding vector, name of the model used)
+        """
         # No message payload to build; the input text is used directly
         model_info, api_provider, client = self._select_model()
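
generate_response_async follows the same pattern: the old flat three-tuple becomes (content, (reasoning, model_name, tool_calls)), per the return statement above. A minimal caller sketch, with llm assumed to be an LLMRequest instance:

    async def ask(llm, prompt: str) -> str:
        content, (reasoning_content, model_name, tool_calls) = await llm.generate_response_async(prompt)
        if tool_calls:
            ...  # dispatch tool calls before using the content
        return content
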
@@ -227,14 +188,10 @@ class LLMRequest:
         embedding = response.embedding
 
-        if response.usage:
-            self.pri_in = model_info.price_in
-            self.pri_out = model_info.price_out
-            self._record_usage(
-                model_name=model_info.name,
-                prompt_tokens=response.usage.prompt_tokens or 0,
-                completion_tokens=response.usage.completion_tokens,
-                total_tokens=response.usage.total_tokens or 0,
+        if usage := response.usage:
+            llm_usage_recorder.record_usage_to_database(
+                model_info=model_info,
+                model_usage=usage,
                 user_id="system",
                 request_type=self.request_type,
                 endpoint="/embeddings",
@@ -243,7 +200,7 @@ class LLMRequest:
         if not embedding:
             raise RuntimeError("Failed to get embedding")
-        return embedding
+        return embedding, model_info.name
 
     def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
         """
@@ -305,12 +262,13 @@ class LLMRequest:
                 # Handle the exception
                 total_tokens, penalty = self.model_usage[model_info.name]
                 self.model_usage[model_info.name] = (total_tokens, penalty + 1)
 
                 wait_interval, compressed_messages = self._default_exception_handler(
                     e,
                     self.task_name,
                     model_name=model_info.name,
                     remain_try=retry_remain,
-                    messages=(message_list, compressed_messages is not None),
+                    messages=(message_list, compressed_messages is not None) if message_list else None,
                 )
 
                 if wait_interval == -1:
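
The penalty increment above feeds the (total_tokens, penalty) record documented in __init__. The actual _select_model body is not part of this diff, so the following penalty-aware selection is purely an illustration of how such a record could drive load balancing:

    def pick_model(model_usage: dict[str, tuple[int, int]]) -> str:
        # Prefer the least-used model, inflating its usage by the penalty count
        # so that repeatedly failing models are chosen less often.
        return min(model_usage, key=lambda name: model_usage[name][0] * (1 + model_usage[name][1]))
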
@@ -321,9 +279,7 @@ class LLMRequest:
             finally:
                 # Placed in finally to prevent an infinite loop
                 retry_remain -= 1
 
-        logger.error(
-            f"Task '{self.task_name}' model '{model_info.name}' request failed after reaching the maximum retry count {api_provider.max_retry}"
-        )
+        logger.error(f"Model '{model_info.name}' request failed after reaching the maximum retry count {api_provider.max_retry}")
         raise RuntimeError("Request failed; maximum retry count reached")
 
     def _default_exception_handler(
@@ -481,65 +437,3 @@ class LLMRequest:
         content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
         reasoning = match[1].strip() if match else ""
         return content, reasoning
-
-    def _record_usage(
-        self,
-        model_name: str,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        user_id: str = "system",
-        request_type: str | None = None,
-        endpoint: str = "/chat/completions",
-    ):
-        """Record model usage to the database
-
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-            total_tokens: total token count
-            user_id: user ID, defaults to "system"
-            request_type: request type
-            endpoint: API endpoint
-        """
-        # If request_type is None, fall back to the instance attribute
-        if request_type is None:
-            request_type = self.request_type
-        try:
-            # Create the record through the Peewee model
-            LLMUsage.create(
-                model_name=model_name,
-                user_id=user_id,
-                request_type=request_type,
-                endpoint=endpoint,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                cost=self._calculate_cost(prompt_tokens, completion_tokens),
-                status="success",
-                timestamp=datetime.now(),  # Peewee handles the DateTimeField
-            )
-            logger.debug(
-                f"Token usage - model: {model_name}, "
-                f"user: {user_id}, type: {request_type}, "
-                f"prompt: {prompt_tokens}, completion: {completion_tokens}, "
-                f"total: {total_tokens}"
-            )
-        except Exception as e:
-            logger.error(f"Failed to record token usage: {str(e)}")
-
-    def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
-        """Calculate the cost of an API call
-
-        Uses the model's pri_in and pri_out prices to compute the input and output cost
-
-        Args:
-            prompt_tokens: number of input tokens
-            completion_tokens: number of output tokens
-        Returns:
-            float: total cost (in yuan)
-        """
-        # Compute the cost from the model's pri_in and pri_out
-        input_cost = (prompt_tokens / 1000000) * self.pri_in
-        output_cost = (completion_tokens / 1000000) * self.pri_out
-        return round(input_cost + output_cost, 6)
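
The removed _record_usage/_calculate_cost pair is superseded by llm_usage_recorder.record_usage_to_database, imported from .utils, whose implementation is outside this hunk. If it preserves the removed pricing logic, the cost computation would look like this stand-alone sketch, with prices given per million tokens as in _calculate_cost:

    def calculate_cost(prompt_tokens: int, completion_tokens: int,
                       price_in: float, price_out: float) -> float:
        # Same formula as the removed _calculate_cost, but parameterized by the
        # model's prices instead of reading them from instance state.
        input_cost = (prompt_tokens / 1_000_000) * price_in
        output_cost = (completion_tokens / 1_000_000) * price_out
        return round(input_cost + output_cost, 6)
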