Adjust the corresponding call sites

UnCLAS-Prommer
2025-07-30 17:07:55 +08:00
parent 3c40ceda4c
commit 6c0edd0ad7
40 changed files with 580 additions and 1236 deletions

View File

@@ -1,12 +0,0 @@
import importlib
from typing import Dict
from src.config.config import model_config
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:

View File

@@ -1,92 +0,0 @@
import importlib
from typing import Dict
from src.config.config import model_config
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:
# TODO: Add a read-write lock to prevent data races when the config is refreshed asynchronously
def __init__(
self,
config: ModuleConfig,
):
self.config: ModuleConfig = config
"""配置信息"""
self.api_client_map: Dict[str, BaseClient] = {}
"""API客户端映射表"""
self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
"""ModelRequestHandler缓存避免重复创建"""
for provider_name, api_provider in self.config.api_providers.items():
# Initialize the API client
try:
# Dynamically load the client implementation specified in the config
client_module = importlib.import_module(
f".model_client.{api_provider.client_type}_client", __package__
)
client_class = getattr(
client_module, f"{api_provider.client_type.capitalize()}Client"
)
if not issubclass(client_class, BaseClient):
raise TypeError(
f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
)
self.api_client_map[api_provider.name] = client_class(
api_provider
)  # Instantiate and store in api_client_map
except ImportError as e:
logger.error(f"Failed to import client module: {e}")
raise ImportError(
f"Failed to import client module for '{provider_name}': {e}"
) from e
def __getitem__(self, task_name: str) -> ModelRequestHandler:
"""
Get the model client (wrapper) required by a task.
A cache avoids re-creating ModelRequestHandler instances.
:param task_name: task name
:return: model client
"""
if task_name not in self.config.task_model_arg_map:
raise KeyError(f"'{task_name}' not registered in ModelManager")
# Check whether it is already in the cache
if task_name in self._request_handler_cache:
logger.debug(f"🚀 [perf] Fetched ModelRequestHandler from cache: {task_name}")
return self._request_handler_cache[task_name]
# Create a new ModelRequestHandler and cache it
logger.debug(f"🔧 [perf] Created and cached ModelRequestHandler: {task_name}")
handler = ModelRequestHandler(
task_name=task_name,
config=self.config,
api_client_map=self.api_client_map,
)
self._request_handler_cache[task_name] = handler
return handler
def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
"""
Register a task's model usage configuration.
:param task_name: task name
:param value: model usage configuration
"""
self.config.task_model_arg_map[task_name] = value
def __contains__(self, task_name: str):
"""
Check whether a task is registered.
:param task_name: task name
:return: whether it is in the model list
"""
return task_name in self.config.task_model_arg_map
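Taken together, the deleted ModelManager behaved like a dict of request handlers, dynamically importing a client class named after each provider's client_type. A minimal usage sketch under that reading (the task name "replyer" is hypothetical, and model_config is assumed to be the ModuleConfig the module imports):

manager = ModelManager(model_config)   # instantiates one client per configured API provider
if "replyer" in manager:               # __contains__: is the task registered?
    handler = manager["replyer"]       # __getitem__: returns a cached ModelRequestHandler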

View File

@@ -1,169 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Tuple
from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo
from src.common.database.database_model import LLMUsage
logger = get_logger("模型使用统计")
class ReqType(Enum):
"""
Request type
"""
CHAT = "chat" # 对话请求
EMBEDDING = "embedding" # 嵌入请求
class UsageCallStatus(Enum):
"""
Call status of a task
"""
PROCESSING = "processing" # 处理中
SUCCESS = "success" # 成功
FAILURE = "failure" # 失败
CANCELED = "canceled" # 取消
class ModelUsageStatistic:
"""
Model usage statistics, backed by SQLite + Peewee
"""
def __init__(self):
"""
Initialize the statistics class.
Since the Peewee ORM is used, no database instance needs to be passed in.
"""
# Make sure the table has been created
try:
from src.common.database.database import db
db.create_tables([LLMUsage], safe=True)
except Exception as e:
logger.error(f"创建LLMUsage表失败: {e}")
@staticmethod
def _calculate_cost(prompt_tokens: int, completion_tokens: int, model_info: ModelInfo) -> float:
"""计算API调用成本
使用模型的pri_in和pri_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
model_info: 模型信息
Returns:
float: 总成本(元)
"""
# 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * model_info.price_in
output_cost = (completion_tokens / 1000000) * model_info.price_out
return round(input_cost + output_cost, 6)
def create_usage(
self,
model_name: str,
task_name: str = "N/A",
request_type: ReqType = ReqType.CHAT,
user_id: str = "system",
endpoint: str = "/chat/completions",
) -> int | None:
"""
Create a model usage record.
Args:
model_name: model name
task_name: task name
request_type: request type, defaults to CHAT
user_id: user ID, defaults to "system"
endpoint: API endpoint
Returns:
int | None: the record ID, or None on failure
"""
try:
usage_record = LLMUsage.create(
model_name=model_name,
user_id=user_id,
request_type=request_type.value,
endpoint=endpoint,
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost=0.0,
status=UsageCallStatus.PROCESSING.value,
timestamp=datetime.now(),
)
# logger.trace(
# f"Created a model usage record - model: {model_name}, "
# f"subtask: {task_name}, type: {request_type.value}, "
# f"user: {user_id}, record ID: {usage_record.id}"
# )
return usage_record.id
except Exception as e:
logger.error(f"创建模型使用情况记录失败: {str(e)}")
return None
def update_usage(
self,
record_id: int | None,
model_info: ModelInfo,
usage_data: Tuple[int, int, int] | None = None,
stat: UsageCallStatus = UsageCallStatus.SUCCESS,
ext_msg: str | None = None,
):
"""
Update a model usage record.
Args:
record_id: record ID
model_info: model information
usage_data: usage data (input tokens, output tokens, total tokens)
stat: call status of the task
ext_msg: extra information
"""
if not record_id:
logger.error("Failed to update model usage: record_id must not be empty")
return
if usage_data and len(usage_data) != 3:
logger.error("Failed to update model usage: usage_data has the wrong length, it should contain 3 elements")
return
# Extract the usage data
prompt_tokens = usage_data[0] if usage_data else 0
completion_tokens = usage_data[1] if usage_data else 0
total_tokens = usage_data[2] if usage_data else 0
try:
# Update the record with Peewee
update_query = LLMUsage.update(
status=stat.value,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens, model_info) if usage_data else 0.0,
).where(LLMUsage.id == record_id) # type: ignore
updated_count = update_query.execute()
if updated_count == 0:
logger.warning(f"记录ID {record_id} 不存在,无法更新")
return
logger.debug(
f"Token使用情况 - 模型: {model_info.name}, "
f"记录ID: {record_id}, "
f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")

View File

@@ -2,16 +2,19 @@ import base64
import io
from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db  # make sure db is imported for create_tables
from src.common.database.database_model import LLMUsage
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
logger = get_logger("消息压缩工具")
def compress_messages(
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
) -> list[Message]:
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
"""
Compress the images in a message list.
:param messages: message list
@@ -28,14 +31,10 @@ def compress_messages(
try:
image = Image.open(image_data)
if image.format and (
image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
):
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
# Convert static images to JPEG
reformated_image_data = io.BytesIO()
image.save(
reformated_image_data, format="JPEG", quality=95, optimize=True
)
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
image_data = reformated_image_data.getvalue()
return image_data
@@ -43,9 +42,7 @@ def compress_messages(
logger.error(f"图片转换格式失败: {str(e)}")
return image_data
def rescale_image(
image_data: bytes, scale: float
) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
"""
Rescale an image.
:param image_data: image data
@@ -86,9 +83,7 @@ def compress_messages(
else:
# Static image: resize and save directly
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
resized_image.save(
output_buffer, format="JPEG", quality=95, optimize=True
)
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
return output_buffer.getvalue(), original_size, new_size
@@ -99,9 +94,7 @@ def compress_messages(
logger.error(traceback.format_exc())
return image_data, None, None
def compress_base64_image(
base64_data: str, target_size: int = 1 * 1024 * 1024
) -> str:
def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
original_b64_data_size = len(base64_data)  # size of the original encoded data
image_data = base64.b64decode(base64_data)
@@ -111,9 +104,7 @@ def compress_messages(
base64_data = base64.b64encode(image_data).decode("utf-8")
if len(base64_data) <= target_size:
# If it is under the target size after conversion, return it directly
logger.info(
f"Converted the image to JPEG, encoded size: {len(base64_data) / 1024:.1f}KB"
)
logger.info(f"Converted the image to JPEG, encoded size: {len(base64_data) / 1024:.1f}KB")
return base64_data
# If it is still larger than the target size, compress the dimensions
@@ -139,9 +130,7 @@ def compress_messages(
# An image: compress it
message_builder.add_image_content(
content_item[0],
compress_base64_image(
content_item[1], target_size=img_target_size
),
compress_base64_image(content_item[1], target_size=img_target_size),
)
else:
message_builder.add_text_content(content_item)
@@ -150,3 +139,48 @@ def compress_messages(
compressed_messages.append(message)
return compressed_messages
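A hedged sketch of calling compress_messages on a built message; the MessageBuilder method names come from the calls above, while builder.build() and the Base64 payload are assumptions:

builder = MessageBuilder()
builder.add_text_content("What is in this picture?")
builder.add_image_content("jpeg", image_b64)  # image_b64: a placeholder Base64 string
compressed = compress_messages([builder.build()], img_target_size=1 * 1024 * 1024)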
class LLMUsageRecorder:
"""
LLM usage recorder
"""
def __init__(self):
try:
# Create the table with Peewee; safe=True means no error is raised if the table already exists
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage table initialized / confirmed to exist.")
except Exception as e:
logger.error(f"Failed to create the LLMUsage table: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
try:
# Create the record with the Peewee model
LLMUsage.create(
model_name=model_info.model_identifier,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
status="success",
timestamp=datetime.now(),  # Peewee handles the DateTimeField
)
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder()
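The module-level llm_usage_recorder singleton replaces the per-request bookkeeping the commit deletes elsewhere. A minimal sketch of a call site, assuming a UsageRecord constructor that carries the fields referenced in the logging above:

usage = UsageRecord(model_name="example-model", prompt_tokens=1200, completion_tokens=300, total_tokens=1500)  # assumed constructor
llm_usage_recorder.record_usage_to_database(
    model_info=model_info,  # a ModelInfo carrying price_in / price_out
    model_usage=usage,
    user_id="system",
    request_type="chat",
    endpoint="/chat/completions",
)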

View File

@@ -1,34 +1,20 @@
import re
import copy
import asyncio
from datetime import datetime
from typing import Tuple, Union, List, Dict, Optional, Callable, Any
from src.common.logger import get_logger
import base64
from PIL import Image
from enum import Enum
import io
from src.common.database.database import db  # make sure db is imported for create_tables
from src.common.database.database_model import LLMUsage  # import the LLMUsage model
from src.config.config import global_config, model_config
from src.config.api_ada_configs import APIProvider, ModelInfo
from rich.traceback import install
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall
from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry
from .utils import compress_messages
from .exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
PayLoadTooLargeError,
RequestAbortException,
PermissionDeniedException,
)
from .model_client.base_client import BaseClient, APIResponse, client_registry
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
install(extra_lines=3)
@@ -57,45 +43,15 @@ class RequestType(Enum):
class LLMRequest:
"""LLM请求类"""
# Models whose payloads need transformation, kept as a class variable to avoid duplication
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31",
"o4-mini",
"o4-mini-2025-04-16",
]
def __init__(self, task_name: str, request_type: str = "") -> None:
self.task_name = task_name
self.model_for_task = model_config.model_task_config.get_task(task_name)
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整"""
self.pri_in = 0
self.pri_out = 0
self._init_database()
@staticmethod
def _init_database():
"""初始化数据库集合"""
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
async def generate_response_for_image(
self,
@@ -104,7 +60,7 @@ class LLMRequest:
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
"""
Generate a response for an image.
Args:
@@ -112,7 +68,7 @@ class LLMRequest:
image_base64 (str): Base64-encoded image string
image_format (str): image format (e.g. 'png', 'jpeg')
Returns:
(Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]): response content, plus (reasoning content, model name, tool call list)
"""
# Build the request payload
message_builder = MessageBuilder()
@@ -141,25 +97,25 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=usage.prompt_tokens or 0,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens or 0,
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
)
return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
return content, (
reasoning_content,
model_info.name,
self._convert_tool_calls(tool_calls) if tool_calls else None,
)
async def generate_response_for_voice(self):
pass
async def generate_response_async(
self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
"""
Generate a response asynchronously.
Args:
@@ -167,7 +123,7 @@ class LLMRequest:
temperature (float, optional): temperature parameter
max_tokens (int, optional): maximum number of tokens
Returns:
Tuple[str, str, Optional[List[Dict[str, Any]]]]: response content, reasoning content, tool call list
(Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]): response content, plus (reasoning content, model name, tool call list)
"""
# Build the request payload
message_builder = MessageBuilder()
@@ -195,13 +151,9 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=usage.prompt_tokens or 0,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens or 0,
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
@@ -209,10 +161,19 @@ class LLMRequest:
if not content:
raise RuntimeError("获取LLM生成内容失败")
return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
return content, (
reasoning_content,
model_info.name,
self._convert_tool_calls(tool_calls) if tool_calls else None,
)
async def get_embedding(self, embedding_input: str) -> List[float]:
"""获取嵌入向量"""
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
Args:
embedding_input (str): 获取嵌入的目标
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
# No message payload needed; use the input text directly
model_info, api_provider, client = self._select_model()
@@ -227,14 +188,10 @@ class LLMRequest:
embedding = response.embedding
if response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens or 0,
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings",
@@ -243,7 +200,7 @@ class LLMRequest:
if not embedding:
raise RuntimeError("获取embedding失败")
return embedding
return embedding, model_info.name
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
@@ -305,12 +262,13 @@ class LLMRequest:
# Handle the exception
total_tokens, penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1)
wait_interval, compressed_messages = self._default_exception_handler(
e,
self.task_name,
model_name=model_info.name,
remain_try=retry_remain,
messages=(message_list, compressed_messages is not None),
messages=(message_list, compressed_messages is not None) if message_list else None,
)
if wait_interval == -1:
@@ -321,9 +279,7 @@ class LLMRequest:
finally:
# Kept in finally to avoid an infinite retry loop
retry_remain -= 1
logger.error(
f"Task '{self.task_name}' model '{model_info.name}' request failed after reaching the maximum retry count {api_provider.max_retry}"
)
logger.error(f"Model '{model_info.name}' request failed after reaching the maximum retry count {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
def _default_exception_handler(
@@ -481,65 +437,3 @@ class LLMRequest:
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
reasoning = match[1].strip() if match else ""
return content, reasoning
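The regex above strips a single leading think block from the content; a worked example under that reading:

import re
raw = "<think>chain of thought</think>The answer is 4."
content = re.sub(r"(?:<think>)?.*?</think>", "", raw, flags=re.DOTALL, count=1).strip()
# content == "The answer is 4."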
def _record_usage(
self,
model_name: str,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
user_id: str = "system",
request_type: str | None = None,
endpoint: str = "/chat/completions",
):
"""记录模型使用情况到数据库
Args:
prompt_tokens: 输入token数
completion_tokens: 输出token数
total_tokens: 总token数
user_id: 用户ID默认为system
request_type: 请求类型
endpoint: API端点
"""
# 如果 request_type 为 None则使用实例变量中的值
if request_type is None:
request_type = self.request_type
try:
# Create the record with the Peewee model
LLMUsage.create(
model_name=model_name,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens),
status="success",
timestamp=datetime.now(),  # Peewee handles the DateTimeField
)
logger.debug(
f"Token使用情况 - 模型: {model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
使用模型的pri_in和pri_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
Returns:
float: 总成本(元)
"""
# 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)
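Because the commit rewraps the return values into nested tuples, downstream callers now unpack them differently. A hedged sketch of the new call sites, inside an async function (the TaskConfig attribute shown is hypothetical):

llm = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="replyer")  # hypothetical task entry
content, (reasoning, model_name, tool_calls) = await llm.generate_response_async("hello")
embedding, embed_model = await llm.get_embedding("some text")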