style: 格式化代码
@@ -18,6 +18,7 @@
- **LLMRequest (主接口)**:
  作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。
"""

import re
import asyncio
import time
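
# Usage sketch of the facade described above (illustrative only; the constructor
# arguments here are assumptions -- see LLMRequest.__init__ and the per-method
# docstrings later in this file for the actual API and return shapes):
#
#     llm = LLMRequest(model_set=task_config, request_type="chat")           # hypothetical args
#     text = await llm.generate_response_for_voice(voice_base64)             # -> Optional[str]
#     vector, model_name = await llm.get_embedding(embedding_input)          # -> (List[float], model name)
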
@@ -26,14 +27,13 @@ import string

from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine

from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder
from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
@@ -46,6 +46,7 @@ logger = get_logger("model_utils")
# Standalone Utility Functions
# ==============================================================================


def _normalize_image_format(image_format: str) -> str:
    """
    标准化图片格式名称,确保与各种API的兼容性
@@ -57,17 +58,26 @@ def _normalize_image_format(image_format: str) -> str:
        str: 标准化后的图片格式
    """
    format_mapping = {
        "jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
        "png": "png", "PNG": "png",
        "webp": "webp", "WEBP": "webp",
        "gif": "gif", "GIF": "gif",
        "heic": "heic", "HEIC": "heic",
        "heif": "heif", "HEIF": "heif",
        "jpg": "jpeg",
        "JPG": "jpeg",
        "JPEG": "jpeg",
        "jpeg": "jpeg",
        "png": "png",
        "PNG": "png",
        "webp": "webp",
        "WEBP": "webp",
        "gif": "gif",
        "GIF": "gif",
        "heic": "heic",
        "HEIC": "heic",
        "heif": "heif",
        "HEIF": "heif",
    }
    normalized = format_mapping.get(image_format, image_format.lower())
    logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
    return normalized

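# Behaviour of the normalization above, for reference (illustrative inputs):
#     _normalize_image_format("JPG")   -> "jpeg"    # explicit mapping entry
#     _normalize_image_format("TIFF")  -> "tiff"    # unmapped formats fall back to str.lower()

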
async def execute_concurrently(
    coro_callable: Callable[..., Coroutine[Any, Any, Any]],
    concurrency_count: int,
@@ -103,25 +113,29 @@ async def execute_concurrently(
    for i, res in enumerate(results):
        if isinstance(res, Exception):
            logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")

    first_exception = next((res for res in results if isinstance(res, Exception)), None)
    if first_exception:
        raise first_exception
    raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")


class RequestType(Enum):
    """请求类型枚举"""

    RESPONSE = "response"
    EMBEDDING = "embedding"
    AUDIO = "audio"


# ==============================================================================
# Helper Classes for LLMRequest Refactoring
# ==============================================================================


class _ModelSelector:
    """负责模型选择、负载均衡和动态故障切换的策略。"""

    CRITICAL_PENALTY_MULTIPLIER = 5  # 严重错误惩罚乘数
    DEFAULT_PENALTY_INCREMENT = 1  # 默认惩罚增量

@@ -168,16 +182,18 @@
        # - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。
        least_used_model_name = min(
            candidate_models_usage,
            key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
            key=lambda k: candidate_models_usage[k][0]
            + candidate_models_usage[k][1] * 300
            + candidate_models_usage[k][2] * 1000,
        )

        model_info = model_config.get_model_info(least_used_model_name)
        api_provider = model_config.get_provider(model_info.api_provider)
        # 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。
        # 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。
        force_new_client = request_type == "embedding"
        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)

        logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
        # 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
        self.update_usage_penalty(model_info.name, increase=True)
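
# Worked example of the weighted selection above (numbers invented for illustration).
# Score = total_tokens + penalty * 300 + usage_penalty * 1000; the lowest score wins:
#   A: (10000, 0, 0) -> 10000 + 0 * 300 + 0 * 1000 = 10000   <- selected
#   B: ( 9000, 5, 0) ->  9000 + 5 * 300 + 0 * 1000 = 10500   (recent failures outweigh its token lead)
#   C: ( 9500, 0, 1) ->  9500 + 0 * 300 + 1 * 1000 = 10500   (just picked, so in-flight load pushes it back)
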
@@ -214,26 +230,32 @@
        if isinstance(e, (NetworkConnectionError, ReqAbortException)):
            # 网络连接错误或请求被中断,通常是基础设施问题,应重罚
            penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
            logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}")
            logger.warning(
                f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}"
            )
        elif isinstance(e, RespNotOkException):
            # 对于HTTP响应错误,重点关注服务器端错误
            if e.status_code >= 500:
                # 5xx 错误表明服务器端出现问题,应重罚
                penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
                logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}")
                logger.warning(
                    f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}"
                )
            else:
                # 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚
                logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
                logger.warning(
                    f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}"
                )
        else:
            # 其他未知异常,给予基础惩罚
            logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")

        self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)

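# How the branches above translate into penalty increments (illustrative; this assumes
# penalty_increment is initialised to DEFAULT_PENALTY_INCREMENT = 1 before the if-chain,
# which is implied but not shown in this hunk):
#   NetworkConnectionError / ReqAbortException   -> +5  (CRITICAL_PENALTY_MULTIPLIER)
#   RespNotOkException, status_code 503           -> +5  (server-side error)
#   RespNotOkException, status_code 404           -> +1  (client error, base penalty)
#   any other exception                           -> +1

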
class _PromptProcessor:
    """封装所有与提示词和响应内容的预处理和后处理逻辑。"""

    def __init__(self):
        """
        初始化提示处理器。
@@ -276,18 +298,18 @@
        """
        # 步骤1: 根据API提供商的配置应用内容混淆
        processed_prompt = self._apply_content_obfuscation(prompt, api_provider)

        # 步骤2: 检查模型是否需要注入反截断指令
        if getattr(model_info, "use_anti_truncation", False):
            processed_prompt += self.anti_truncation_instruction
            logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")

        return processed_prompt

    def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
        """
        处理响应内容,提取思维链并检查截断。

        Returns:
            Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断)
        """
@@ -317,14 +339,14 @@
        # 检查当前API提供商是否启用了内容混淆功能
        if not getattr(api_provider, "enable_content_obfuscation", False):
            return text

        # 获取混淆强度,默认为1
        intensity = getattr(api_provider, "obfuscation_intensity", 1)
        logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")

        # 将抗审查指令和原始文本拼接
        processed_text = self.noise_instruction + "\n\n" + text

        # 在拼接后的文本中注入随机噪音
        return self._inject_random_noise(processed_text, intensity)

@@ -346,12 +368,12 @@
        # 定义不同强度级别的噪音参数:概率和长度范围
        params = {
            1: {"probability": 15, "length": (3, 6)},  # 低强度
            2: {"probability": 25, "length": (5, 10)},  # 中强度
            3: {"probability": 35, "length": (8, 15)},  # 高强度
            2: {"probability": 25, "length": (5, 10)},  # 中强度
            3: {"probability": 35, "length": (8, 15)},  # 高强度
        }
        # 根据传入的强度选择配置,如果强度无效则使用默认值
        config = params.get(intensity, params[1])

        words = text.split()
        result = []
        # 遍历每个单词
@@ -366,7 +388,7 @@
            # 生成噪音字符串
            noise = "".join(random.choice(chars) for _ in range(noise_length))
            result.append(noise)

        # 将处理后的单词列表重新组合成字符串
        return " ".join(result)

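# Minimal sketch of the noise-injection idea shown in the two hunks above. The per-word
# probability check and the `chars` alphabet are not visible in these hunks, so the details
# below are assumptions, not the module's actual implementation:
#
#     import random
#     import string
#
#     def inject_noise(text: str, probability: int = 15, length: tuple = (3, 6)) -> str:
#         chars = string.ascii_letters + string.digits
#         out = []
#         for word in text.split():
#             out.append(word)
#             if random.randint(1, 100) <= probability:      # fire with the configured probability
#                 n = random.randint(*length)                 # noise length within the configured range
#                 out.append("".join(random.choice(chars) for _ in range(n)))
#         return " ".join(out)
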
@@ -396,7 +418,7 @@
        else:
            reasoning = ""
            clean_content = content.strip()

        return clean_content, reasoning

@@ -441,7 +463,7 @@ class _RequestExecutor:
        """
        retry_remain = api_provider.max_retry
        compressed_messages: Optional[List[Message]] = None

        while retry_remain > 0:
            try:
                # 优先使用压缩后的消息列表
@@ -451,11 +473,11 @@
                # 根据请求类型调用不同的客户端方法
                if request_type == RequestType.RESPONSE:
                    assert current_messages is not None, "message_list cannot be None for response requests"

                    # 修复: 防止 'message_list' 在 kwargs 中重复传递
                    request_params = kwargs.copy()
                    request_params.pop("message_list", None)

                    return await client.get_response(
                        model_info=model_info, message_list=current_messages, **request_params
                    )
@@ -463,15 +485,19 @@
                    return await client.get_embedding(model_info=model_info, **kwargs)
                elif request_type == RequestType.AUDIO:
                    return await client.get_audio_transcriptions(model_info=model_info, **kwargs)

            except Exception as e:
                logger.debug(f"请求失败: {str(e)}")
                # 记录失败并更新模型的惩罚值
                self.model_selector.update_failure_penalty(model_info.name, e)

                # 处理异常,决定是否重试以及等待多久
                wait_interval, new_compressed_messages = self._handle_exception(
                    e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None)
                    e,
                    model_info,
                    api_provider,
                    retry_remain,
                    (kwargs.get("message_list"), compressed_messages is not None),
                )
                if new_compressed_messages:
                    compressed_messages = new_compressed_messages  # 更新为压缩后的消息
@@ -482,7 +508,7 @@
                await asyncio.sleep(wait_interval)  # 等待指定时间后重试
            finally:
                retry_remain -= 1

        logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
        raise RuntimeError("请求失败,已达到最大重试次数")

@@ -491,7 +517,7 @@
    ) -> Tuple[int, Optional[List[Message]]]:
        """
        默认异常处理函数,决定是否重试。

        Returns:
            (等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息))
        """
@@ -534,7 +560,9 @@
        model_name = model_info.name
        # 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试
        if e.status_code in [400, 401, 402, 403, 404]:
            logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。")
            logger.warning(
                f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。"
            )
            return -1, None
        # 处理请求体过大的情况
        elif e.status_code == 413:
@@ -570,9 +598,11 @@
        """
        # 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次)
        if remain_try > 1:
            logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。")
            logger.warning(
                f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。"
            )
            return interval, None

        # 如果已无剩余重试次数,则记录错误并返回-1表示放弃
        logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。")
        return -1, None
@@ -585,7 +615,14 @@
    即使在单个模型或API端点失败的情况下也能正常工作。
    """

    def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str):
    def __init__(
        self,
        model_selector: _ModelSelector,
        prompt_processor: _PromptProcessor,
        executor: _RequestExecutor,
        model_list: List[str],
        task_name: str,
    ):
        """
        初始化请求策略。

@@ -616,11 +653,13 @@
        last_exception: Optional[Exception] = None

        for attempt in range(max_attempts):
            selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value))
            selection_result = self.model_selector.select_best_available_model(
                failed_models_in_this_request, str(request_type.value)
            )
            if selection_result is None:
                logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
                break

            model_info, api_provider, client = selection_result
            logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...")

@@ -637,32 +676,36 @@

                # 合并模型特定的额外参数
                if model_info.extra_params:
                    request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})}
                    request_kwargs["extra_params"] = {
                        **model_info.extra_params,
                        **request_kwargs.get("extra_params", {}),
                    }

                response = await self._try_model_request(
                    model_info, api_provider, client, request_type, **request_kwargs
                )

                response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs)

                # 成功,立即返回
                logger.debug(f"模型 '{model_info.name}' 成功生成了回复。")
                self.model_selector.update_usage_penalty(model_info.name, increase=False)
                return response, model_info

            except Exception as e:
                logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
                failed_models_in_this_request.add(model_info.name)
                last_exception = e
                # 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率

        logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
        if raise_when_empty:
            if last_exception:
                raise RuntimeError("所有模型均未能生成响应。") from last_exception
            raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")

        # 如果不抛出异常,返回一个备用响应
        fallback_model_info = model_config.get_model_info(self.model_list[0])
        return APIResponse(content="所有模型都请求失败"), fallback_model_info

    async def _try_model_request(
        self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs
    ) -> APIResponse:
@@ -684,46 +727,49 @@
            RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。
        """
        max_empty_retry = api_provider.max_retry

        for i in range(max_empty_retry + 1):
            response = await self.executor.execute_request(
                api_provider, client, request_type, model_info, **kwargs
            )
            response = await self.executor.execute_request(api_provider, client, request_type, model_info, **kwargs)

            if request_type != RequestType.RESPONSE:
                return response  # 对于非响应类型,直接返回
                return response  # 对于非响应类型,直接返回

            # --- 响应内容处理和空回复/截断检查 ---
            content = response.content or ""
            use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
            processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation)

            processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
                content, use_anti_truncation
            )

            # 更新响应对象
            response.content = processed_content
            response.reasoning_content = response.reasoning_content or reasoning

            is_empty_reply = not response.tool_calls and not (response.content and response.content.strip())

            if not is_empty_reply and not is_truncated:
                return response  # 成功获取有效响应
                return response  # 成功获取有效响应

            if i < max_empty_retry:
                reason = "空回复" if is_empty_reply else "截断"
                logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...")
                logger.warning(
                    f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})..."
                )
                if api_provider.retry_interval > 0:
                    await asyncio.sleep(api_provider.retry_interval)
            else:
                reason = "空回复" if is_empty_reply else "截断"
                logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
                raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。")

        raise RuntimeError("内部重试逻辑错误")  # 理论上不应到达这里

        raise RuntimeError("内部重试逻辑错误")  # 理论上不应到达这里


# ==============================================================================
# Main Facade Class
# ==============================================================================


class LLMRequest:
    """
    LLM请求协调器。
@@ -745,7 +791,7 @@
            model: (0, 0, 0) for model in self.model_for_task.model_list
        }
        """模型使用量记录,(total_tokens, penalty, usage_penalty)"""

        # 初始化辅助类
        self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
        self._prompt_processor = _PromptProcessor()
@@ -769,36 +815,44 @@
            prompt (str): 提示词
            image_base64 (str): 图像的Base64编码字符串
            image_format (str): 图像格式(如 'png', 'jpeg' 等)

        Returns:
            (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
        """
        start_time = time.time()

        # 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
        selection_result = self._model_selector.select_best_available_model(set(), "response")
        if not selection_result:
            raise RuntimeError("无法为图像响应选择可用模型。")
        model_info, api_provider, client = selection_result

        normalized_format = _normalize_image_format(image_format)
        message = MessageBuilder().add_text_content(prompt).add_image_content(
            image_base64=image_base64,
            image_format=normalized_format,
            support_formats=client.get_support_image_formats(),
        ).build()
        message = (
            MessageBuilder()
            .add_text_content(prompt)
            .add_image_content(
                image_base64=image_base64,
                image_format=normalized_format,
                support_formats=client.get_support_image_formats(),
            )
            .build()
        )

        response = await self._executor.execute_request(
            api_provider, client, RequestType.RESPONSE, model_info,
            api_provider,
            client,
            RequestType.RESPONSE,
            model_info,
            message_list=[message],
            temperature=temperature,
            max_tokens=max_tokens,
        )

        await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
        content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
        reasoning = response.reasoning_content or reasoning

        return content, (reasoning, model_info.name, response.tool_calls)

    async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
@@ -812,9 +866,7 @@
        Returns:
            Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。
        """
        response, _ = await self._strategy.execute_with_failover(
            RequestType.AUDIO, audio_base64=voice_base64
        )
        response, _ = await self._strategy.execute_with_failover(RequestType.AUDIO, audio_base64=voice_base64)
        return response.content or None

    async def generate_response_async(
@@ -834,7 +886,7 @@
            max_tokens (int, optional): 最大token数
            tools: 工具配置
            raise_when_empty (bool): 是否在空回复时抛出异常

        Returns:
            (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
        """
@@ -842,12 +894,16 @@

        if concurrency_count <= 1:
            return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty)

        try:
            return await execute_concurrently(
                self._execute_single_text_request,
                concurrency_count,
                prompt, temperature, max_tokens, tools, raise_when_empty=False
                prompt,
                temperature,
                max_tokens,
                tools,
                raise_when_empty=False,
            )
        except Exception as e:
            logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
@@ -885,7 +941,7 @@
        response, model_info = await self._strategy.execute_with_failover(
            RequestType.RESPONSE,
            raise_when_empty=raise_when_empty,
            prompt=prompt,  # 传递原始prompt,由strategy处理
            prompt=prompt,  # 传递原始prompt,由strategy处理
            tool_options=tool_options,
            temperature=self.model_for_task.temperature if temperature is None else temperature,
            max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
@@ -906,21 +962,20 @@

        Args:
            embedding_input (str): 获取嵌入的目标

        Returns:
            (Tuple[List[float], str]): (嵌入向量,使用的模型名称)
        """
        start_time = time.time()
        response, model_info = await self._strategy.execute_with_failover(
            RequestType.EMBEDDING,
            embedding_input=embedding_input
            RequestType.EMBEDDING, embedding_input=embedding_input
        )

        await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")

        if not response.embedding:
            raise RuntimeError("获取embedding失败")

        return response.embedding, model_info.name

    async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
@@ -940,16 +995,18 @@
        # 步骤1: 更新内存中的token计数,用于负载均衡
        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
        self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)

        # 步骤2: 创建一个后台任务,将用量数据异步写入数据库
        asyncio.create_task(llm_usage_recorder.record_usage_to_database(
            model_info=model_info,
            model_usage=usage,
            user_id="system",  # 此处可根据业务需求修改
            time_cost=time_cost,
            request_type=self.task_name,
            endpoint=endpoint,
        ))
        asyncio.create_task(
            llm_usage_recorder.record_usage_to_database(
                model_info=model_info,
                model_usage=usage,
                user_id="system",  # 此处可根据业务需求修改
                time_cost=time_cost,
                request_type=self.task_name,
                endpoint=endpoint,
            )
        )

    @staticmethod
    def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
@@ -970,14 +1027,14 @@
        # 如果没有提供工具,直接返回 None
        if not tools:
            return None

        tool_options: List[ToolOption] = []
        # 遍历每个工具定义
        for tool in tools:
            try:
                # 使用建造者模式创建 ToolOption
                builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", ""))

                # 遍历工具的参数
                for param in tool.get("parameters", []):
                    # 严格验证参数格式是否为包含5个元素的元组
@@ -994,6 +1051,6 @@
            except (KeyError, IndexError, TypeError, AssertionError) as e:
                # 如果构建过程中出现任何错误,记录日志并跳过该工具
                logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}")

        # 如果列表非空则返回列表,否则返回 None
        return tool_options or None

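# Illustrative input for _build_tool_options (a sketch only; each "parameters" entry must be a
# 5-element tuple whose exact field order is defined by ToolOptionBuilder elsewhere in this
# module, so it is left abstract here):
#
#     tools = [
#         {
#             "name": "get_weather",
#             "description": "Look up the weather for a city",
#             "parameters": [ ... ],  # list of 5-element tuples, validated strictly above
#         }
#     ]
#     tool_options = LLMRequest._build_tool_options(tools)  # -> Optional[List[ToolOption]]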