# Mofox-Core/src/llm_models/utils_model.py
import re
import copy
import asyncio
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
from .model_client.base_client import BaseClient, APIResponse, client_registry
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
install(extra_lines=3)  # Render uncaught exceptions with rich tracebacks
logger = get_logger("model_utils")
# Common error code mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
class RequestType(Enum):
"""请求类型枚举"""
RESPONSE = "response"
EMBEDDING = "embedding"
AUDIO = "audio"
class LLMRequest:
"""LLM请求类"""
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整"""
self.pri_in = 0
self.pri_out = 0
async def generate_response_for_image(
self,
prompt: str,
image_base64: str,
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
为图像生成响应
Args:
prompt (str): 提示词
image_base64 (str): 图像的Base64编码字符串
image_format (str): 图像格式(如 'png', 'jpeg' 等)
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
        # Model selection
model_info, api_provider, client = self._select_model()
        # Build the request payload
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
message_builder.add_image_content(image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats())
messages = [message_builder.build()]
        # Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
temperature=temperature,
max_tokens=max_tokens,
)
content = response.content or ""
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
        # Extract <think>-tag reasoning from the content (backward compatibility)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
)
return content, (reasoning_content, model_info.name, tool_calls)
async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
"""
为语音生成响应
Args:
voice_base64 (str): 语音的Base64编码字符串
Returns:
(Optional[str]): 生成的文本描述或None
"""
        # Model selection
model_info, api_provider, client = self._select_model()
        # Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.AUDIO,
model_info=model_info,
audio_base64=voice_base64,
)
return response.content or None
async def generate_response_async(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
异步生成响应
Args:
prompt (str): 提示词
temperature (float, optional): 温度参数
max_tokens (int, optional): 最大token数
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
        # Build the request payload
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
messages = [message_builder.build()]
tool_built = self._build_tool_options(tools)
        # Model selection
model_info, api_provider, client = self._select_model()
        # Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=messages,
temperature=temperature,
max_tokens=max_tokens,
tool_options=tool_built,
)
content = response.content
reasoning_content = response.reasoning_content or ""
tool_calls = response.tool_calls
        # Extract <think>-tag reasoning from the content (backward compatibility)
if not reasoning_content and content:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
)
if not content:
raise RuntimeError("获取LLM生成内容失败")
return content, (reasoning_content, model_info.name, tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
Args:
embedding_input (str): 获取嵌入的目标
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
# 无需构建消息体,直接使用输入文本
model_info, api_provider, client = self._select_model()
        # Send the request and process the result
response = await self._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.EMBEDDING,
model_info=model_info,
embedding_input=embedding_input,
)
embedding = response.embedding
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings",
)
if not embedding:
raise RuntimeError("获取embedding失败")
return embedding, model_info.name
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
根据总tokens和惩罚值选择的模型
"""
least_used_model_name = min(
self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
logger.debug(f"选择请求模型: {model_info.name}")
return model_info, api_provider, client
async def _execute_request(
self,
api_provider: APIProvider,
client: BaseClient,
request_type: RequestType,
model_info: ModelInfo,
message_list: List[Message] | None = None,
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[Callable] = None,
async_response_parser: Optional[Callable] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
embedding_input: str = "",
audio_base64: str = "",
) -> APIResponse:
"""
实际执行请求的方法
包含了重试和异常处理逻辑
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None
while retry_remain > 0:
try:
if request_type == RequestType.RESPONSE:
assert message_list is not None, "message_list cannot be None for response requests"
return await client.get_response(
model_info=model_info,
message_list=(compressed_messages or message_list),
tool_options=tool_options,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
temperature=self.model_for_task.temperature if temperature is None else temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.EMBEDDING:
assert embedding_input, "embedding_input cannot be empty for embedding requests"
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
extra_params=model_info.extra_params,
)
elif request_type == RequestType.AUDIO:
                    assert audio_base64, "audio_base64 cannot be empty for audio requests"
return await client.get_audio_transcriptions(
model_info=model_info,
audio_base64=audio_base64,
extra_params=model_info.extra_params,
)
except Exception as e:
logger.debug(f"请求失败: {str(e)}")
# 处理异常
total_tokens, penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1)
wait_interval, compressed_messages = self._default_exception_handler(
e,
self.task_name,
model_name=model_info.name,
remain_try=retry_remain,
retry_interval=api_provider.retry_interval,
messages=(message_list, compressed_messages is not None) if message_list else None,
)
                if wait_interval == -1:
                    retry_remain = 0  # Stop retrying
                elif wait_interval > 0:
                    logger.info(f"Waiting {wait_interval} seconds before retrying...")
                    await asyncio.sleep(wait_interval)
            finally:
                # Decrement in finally so the loop cannot spin forever
                retry_remain -= 1
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
def _default_exception_handler(
self,
e: Exception,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: Tuple[List[Message], bool] | None = None,
) -> Tuple[int, List[Message] | None]:
"""
默认异常处理函数
Args:
e (Exception): 异常对象
task_name (str): 任务名称
model_name (str): 模型名称
remain_try (int): 剩余尝试次数
retry_interval (int): 重试间隔
messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
Returns:
(等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
        if isinstance(e, NetworkConnectionError):  # Network connection error
            return self._check_retry(
                remain_try,
                retry_interval,
                can_retry_msg=f"Task '{task_name}', model '{model_name}': connection error, retrying in {retry_interval} seconds",
                cannot_retry_msg=f"Task '{task_name}', model '{model_name}': connection errors exceeded the maximum retry count; check the network connection and that the URL is correct",
            )
        elif isinstance(e, ReqAbortException):
            logger.warning(f"Task '{task_name}', model '{model_name}': request aborted, details: {str(e.message)}")
            return -1, None  # Stop retrying this model
elif isinstance(e, RespNotOkException):
return self._handle_resp_not_ok(
e,
task_name,
model_name,
remain_try,
retry_interval,
messages,
)
        elif isinstance(e, RespParseException):
            # Response parsing error
            logger.error(f"Task '{task_name}', model '{model_name}': failed to parse the response, error: {e.message}")
            logger.debug(f"Additional info: {str(e.ext_info)}")
            return -1, None  # Stop retrying this model
        else:
            logger.error(f"Task '{task_name}', model '{model_name}': unknown exception, error: {str(e)}")
            return -1, None  # Stop retrying this model
def _check_retry(
self,
remain_try: int,
retry_interval: int,
can_retry_msg: str,
cannot_retry_msg: str,
can_retry_callable: Callable | None = None,
**kwargs,
) -> Tuple[int, List[Message] | None]:
"""辅助函数:检查是否可以重试
Args:
remain_try (int): 剩余尝试次数
retry_interval (int): 重试间隔
can_retry_msg (str): 可以重试时的提示信息
cannot_retry_msg (str): 不可以重试时的提示信息
can_retry_callable (Callable | None): 可以重试时调用的函数(如果有)
**kwargs: 其他参数
Returns:
(Tuple[int, List[Message] | None]): (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if remain_try > 0:
# 还有重试机会
logger.warning(f"{can_retry_msg}")
if can_retry_callable is not None:
return retry_interval, can_retry_callable(**kwargs)
else:
return retry_interval, None
else:
# 达到最大重试次数
logger.warning(f"{cannot_retry_msg}")
return -1, None # 不再重试请求该模型
def _handle_resp_not_ok(
self,
e: RespNotOkException,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
):
"""
处理响应错误异常
Args:
e (RespNotOkException): 响应错误异常对象
task_name (str): 任务名称
model_name (str): 模型名称
remain_try (int): 剩余尝试次数
retry_interval (int): 重试间隔
messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
Returns:
(等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
# 响应错误
if e.status_code in [400, 401, 402, 403, 404]:
# 客户端错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code == 413:
if messages and not messages[1]:
# 消息列表不为空且未压缩,尝试压缩消息
return self._check_retry(
remain_try,
0,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
can_retry_callable=compress_messages,
messages=messages[0],
)
# 没有消息可压缩
logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
return -1, None
elif e.status_code == 429:
# 请求过于频繁
return self._check_retry(
remain_try,
retry_interval,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
)
elif e.status_code >= 500:
# 服务器错误
return self._check_retry(
remain_try,
retry_interval,
can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
)
else:
# 未知错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None
def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
"""构建工具选项列表"""
if not tools:
return None
tool_options: List[ToolOption] = []
for tool in tools:
tool_legal = True
tool_options_builder = ToolOptionBuilder()
tool_options_builder.set_name(tool.get("name", ""))
tool_options_builder.set_description(tool.get("description", ""))
parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", [])
for param in parameters:
try:
tool_options_builder.add_param(
name=param[0],
param_type=ToolParamType(param[1]),
description=param[2],
required=param[3],
)
except ValueError as ve:
tool_legal = False
logger.error(f"{param[1]} 参数类型错误: {str(ve)}")
except Exception as e:
tool_legal = False
logger.error(f"构建工具参数失败: {str(e)}")
if tool_legal:
tool_options.append(tool_options_builder.build())
return tool_options or None
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取向后兼容"""
match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
reasoning = match[1].strip() if match else ""
return content, reasoning
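
# Illustrative example of the <think>-tag extraction above:
#   LLMRequest._extract_reasoning("<think>Compare both options first.</think>Option B is better.")
#   returns ("Option B is better.", "Compare both options first.")
#
# Minimal usage sketch (illustrative; the task name "replyer" and the attribute
# path on model_config are assumptions about the project's config layout):
#   request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="replyer")
#   content, (reasoning, model_name, tool_calls) = await request.generate_response_async("Hello!")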