minecraft1024a
2025-10-02 21:58:10 +08:00
269 changed files with 5299 additions and 5319 deletions

View File

@@ -1,6 +1,5 @@
from typing import Any
# 常见Error Code Mapping (以OpenAI API为例)
error_code_mapping = {
400: "参数不正确",

View File

@@ -1,21 +1,24 @@
import asyncio
import orjson
import io
from typing import Callable, Any, Coroutine, Optional
import aiohttp
from collections.abc import Callable, Coroutine
from typing import Any
import aiohttp
import orjson
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from src.config.api_ada_configs import APIProvider, ModelInfo
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
logger = get_logger("AioHTTP-Gemini客户端")
@@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser:
chunk_data = orjson.loads(chunk_text)
# 解析候选项
if "candidates" in chunk_data and chunk_data["candidates"]:
if chunk_data.get("candidates"):
candidate = chunk_data["candidates"][0]
# 解析内容
@@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser:
async def _default_stream_response_handler(
response: aiohttp.ClientResponse,
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""默认流式响应处理器"""
parser = AiohttpGeminiStreamParser()
@@ -290,13 +293,13 @@ async def _default_stream_response_handler(
def _default_normal_response_parser(
response_data: dict,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""默认普通响应解析器"""
api_response = APIResponse()
try:
# 解析候选项
if "candidates" in response_data and response_data["candidates"]:
if response_data.get("candidates"):
candidate = response_data["candidates"][0]
# 解析文本内容
@@ -418,13 +421,12 @@ class AiohttpGeminiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[aiohttp.ClientResponse, asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None,
stream_response_handler: Callable[
[aiohttp.ClientResponse, asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
]
| None = None,
async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -452,9 +454,7 @@ class AiohttpGeminiClient(BaseClient):
# 构建请求体
request_data = {
"contents": contents,
"generationConfig": _build_generation_config(
max_tokens, temperature, tb, response_format, extra_params
),
"generationConfig": _build_generation_config(max_tokens, temperature, tb, response_format, extra_params),
}
# 添加系统指令
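
Editor's note: the recurring change in this file from `if "candidates" in chunk_data and chunk_data["candidates"]:` to `if chunk_data.get("candidates"):` is behaviour-preserving; both branches fire only when the key exists and its value is truthy. A minimal standalone sketch (not part of the commit) illustrating the equivalence:

# Both spellings are true only when the key is present AND non-empty.
for payload in ({}, {"candidates": []}, {"candidates": [{"content": "hi"}]}):
    old_style = "candidates" in payload and bool(payload["candidates"])
    new_style = bool(payload.get("candidates"))
    assert old_style == new_style
    if payload.get("candidates"):
        print("first candidate:", payload["candidates"][0])
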

View File

@@ -1,12 +1,14 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from src.config.api_ada_configs import APIProvider, ModelInfo
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption
@dataclass
@@ -75,9 +77,8 @@ class BaseClient(ABC):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
| None = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
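
Editor's note: the signature change in this hunk follows the commit-wide pattern of replacing `typing.Optional[...]` and `typing.Callable` with PEP 604 unions and `collections.abc.Callable`. A minimal sketch (assumptions: Python 3.10+, simplified stand-in names) showing that both spellings describe the same optional-handler parameter:

import asyncio
from collections.abc import Callable

# New-style spelling of the handler type; equivalent to the old Optional[Callable[...]] form.
StreamHandler = Callable[[object, asyncio.Event | None], tuple[str, tuple[int, int, int]]]

def run(handler: StreamHandler | None = None) -> str:
    # Fall back to a no-op when no handler is supplied, mirroring the optional parameter.
    if handler is None:
        return "no handler"
    content, usage = handler(object(), None)
    return f"{content} {usage}"

print(run())                                        # no handler
print(run(lambda resp, flag: ("ok", (1, 2, 3))))    # ok (1, 2, 3)
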

View File

@@ -1,17 +1,17 @@
import asyncio
import io
import orjson
import re
import base64
from collections.abc import Iterable
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
import io
import re
from collections.abc import Callable, Coroutine, Iterable
from typing import Any
import orjson
from json_repair import repair_json
from openai import (
AsyncOpenAI,
NOT_GIVEN,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncOpenAI,
AsyncStream,
)
from openai.types.chat import (
@@ -22,18 +22,19 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from src.config.api_ada_configs import APIProvider, ModelInfo
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
logger = get_logger("OpenAI客户端")
@@ -241,7 +242,7 @@ def _build_stream_api_resp(
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
@@ -315,7 +316,7 @@ pattern = re.compile(
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
@@ -394,15 +395,13 @@ class OpenaiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]]
| None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -519,17 +518,17 @@ class OpenaiClient(BaseClient):
)
except APIConnectionError as e:
# 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"OpenAI API连接错误嵌入模型: {e!s}")
logger.error(f"错误类型: {type(e)}")
if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}")
logger.error(f"底层错误: {e.__cause__!s}")
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code) from e
except Exception as e:
# 添加通用异常处理和日志记录
logger.error(f"获取嵌入时发生未知错误: {str(e)}")
logger.error(f"获取嵌入时发生未知错误: {e!s}")
logger.error(f"错误类型: {type(e)}")
raise
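
Editor's note: the logging changes in this file swap `f"...{str(e)}"` for `f"...{e!s}"`. The two are equivalent; `!s` is the f-string conversion flag that applies `str()`. This is the style Ruff's RUF010 rule enforces, and given the scale of the commit a linter pass seems likely, though the tool is not named here. A quick sketch:

class FakeAPIError(Exception):
    pass

try:
    raise FakeAPIError("connection reset")
except FakeAPIError as e:
    assert f"request failed: {str(e)}" == f"request failed: {e!s}"
    # !r and !a are the sibling conversions for repr() and ascii().
    print(f"request failed: {e!s}")   # request failed: connection reset
    print(f"request failed: {e!r}")   # request failed: FakeAPIError('connection reset')
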

View File

@@ -1,6 +1,5 @@
from enum import Enum
# 设计这系列类的目的是为未来可能的扩展做准备
@@ -58,7 +57,7 @@ class MessageBuilder:
self,
image_format: str,
image_base64: str,
support_formats=None, # 默认支持格式
support_formats=None, # 默认支持格式
) -> "MessageBuilder":
"""
添加图片内容

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import Optional, Any
from typing import Any
from pydantic import BaseModel
from typing_extensions import TypedDict, Required
from typing_extensions import Required, TypedDict
class RespFormatType(Enum):
@@ -20,7 +20,7 @@ class JsonSchema(TypedDict, total=False):
of 64.
"""
description: Optional[str]
description: str | None
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
@@ -32,7 +32,7 @@ class JsonSchema(TypedDict, total=False):
to build JSON schemas [here](https://json-schema.org/).
"""
strict: Optional[bool]
strict: bool | None
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
@@ -100,7 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
sub_schema[i] = link_definitions_recursive(f"{path}/{i!s}", sub_schema[i], defs)
else:
# 否则为字典
if "$defs" in sub_schema:
@@ -140,8 +140,7 @@ def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "$defs" in schema:
del schema["$defs"]
schema.pop("$defs", None)
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
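
Editor's note: besides moving the TypedDict fields to `str | None` / `bool | None`, this file collapses the two-line `$defs` removal into `dict.pop`. `pop(key, None)` deletes the key if present and is a no-op otherwise, which is exactly what the old `in`/`del` pair did. A minimal sketch with a made-up schema:

def strip_defs(schema: dict) -> dict:
    # Equivalent to: if "$defs" in schema: del schema["$defs"]
    schema.pop("$defs", None)
    return schema

print(strip_defs({"type": "object", "$defs": {"Item": {"type": "string"}}}))
# {'type': 'object'}
print(strip_defs({"type": "object"}))  # no KeyError when the key is absent
# {'type': 'object'}
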

View File

@@ -1,14 +1,15 @@
import base64
import io
from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from PIL import Image
from src.common.database.sqlalchemy_models import LLMUsage, get_db_session
from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
from .payload_content.message import Message, MessageBuilder
logger = get_logger("消息压缩工具")
@@ -38,7 +39,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
return image_data
except Exception as e:
logger.error(f"图片转换格式失败: {str(e)}")
logger.error(f"图片转换格式失败: {e!s}")
return image_data
def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
@@ -87,7 +88,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
return output_buffer.getvalue(), original_size, new_size
except Exception as e:
logger.error(f"图片缩放失败: {str(e)}")
logger.error(f"图片缩放失败: {e!s}")
import traceback
logger.error(traceback.format_exc())
@@ -188,7 +189,7 @@ class LLMUsageRecorder:
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
logger.error(f"记录token使用情况失败: {e!s}")
llm_usage_recorder = LLMUsageRecorder()
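
Editor's note: only the error-logging lines of `compress_messages` / `rescale_image` appear in these hunks, so the bodies are largely elided. For context, a speculative standalone sketch of what a rescale helper with the shown signature could look like (assumptions: Pillow is installed, JPEG output, default resampling, fall back to the original bytes on failure; the real implementation may differ):

import io
from PIL import Image

def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
    """Shrink an image by `scale` and return (new_bytes, original_size, new_size)."""
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            original_size = img.size
            new_size = (max(1, int(img.width * scale)), max(1, int(img.height * scale)))
            resized = img.convert("RGB").resize(new_size)
            output_buffer = io.BytesIO()
            resized.save(output_buffer, format="JPEG", quality=85)
            return output_buffer.getvalue(), original_size, new_size
    except Exception:
        # On failure, hand back the original bytes unchanged, as the logged fallback suggests.
        return image_data, None, None
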

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
@desc: 该模块封装了与大语言模型LLM交互的所有核心逻辑。
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
@@ -18,25 +17,27 @@
- **LLMRequest (主接口)**:
作为模块的统一入口Facade为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。
"""
import re
import asyncio
import time
import random
import string
import asyncio
import random
import re
import string
import time
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
from .utils import compress_messages, llm_usage_recorder
from src.config.config import model_config
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry
from .payload_content.message import Message, MessageBuilder
from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder
from .utils import compress_messages, llm_usage_recorder
install(extra_lines=3)
@@ -46,6 +47,7 @@ logger = get_logger("model_utils")
# Standalone Utility Functions
# ==============================================================================
def _normalize_image_format(image_format: str) -> str:
"""
标准化图片格式名称确保与各种API的兼容性
@@ -57,17 +59,26 @@ def _normalize_image_format(image_format: str) -> str:
str: 标准化后的图片格式
"""
format_mapping = {
"jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
"png": "png", "PNG": "png",
"webp": "webp", "WEBP": "webp",
"gif": "gif", "GIF": "gif",
"heic": "heic", "HEIC": "heic",
"heif": "heif", "HEIF": "heif",
"jpg": "jpeg",
"JPG": "jpeg",
"JPEG": "jpeg",
"jpeg": "jpeg",
"png": "png",
"PNG": "png",
"webp": "webp",
"WEBP": "webp",
"gif": "gif",
"GIF": "gif",
"heic": "heic",
"HEIC": "heic",
"heif": "heif",
"HEIF": "heif",
}
normalized = format_mapping.get(image_format, image_format.lower())
logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
return normalized
async def execute_concurrently(
coro_callable: Callable[..., Coroutine[Any, Any, Any]],
concurrency_count: int,
@@ -103,29 +114,33 @@ async def execute_concurrently(
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")
first_exception = next((res for res in results if isinstance(res, Exception)), None)
if first_exception:
raise first_exception
raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")
class RequestType(Enum):
"""请求类型枚举"""
RESPONSE = "response"
EMBEDDING = "embedding"
AUDIO = "audio"
# ==============================================================================
# Helper Classes for LLMRequest Refactoring
# ==============================================================================
class _ModelSelector:
"""负责模型选择、负载均衡和动态故障切换的策略。"""
CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数
DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量
def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]):
def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
"""
初始化模型选择器。
@@ -139,7 +154,7 @@ class _ModelSelector:
def select_best_available_model(
self, failed_models_in_this_request: set, request_type: str
) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
"""
从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。
@@ -168,16 +183,18 @@ class _ModelSelector:
# - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。
least_used_model_name = min(
candidate_models_usage,
key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
key=lambda k: candidate_models_usage[k][0]
+ candidate_models_usage[k][1] * 300
+ candidate_models_usage[k][2] * 1000,
)
model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
# 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。
# 这是为了避免在某些高并发场景下共享的ClientSession可能引发的事件循环相关问题。
force_new_client = request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
# 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
self.update_usage_penalty(model_info.name, increase=True)
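
Editor's note: the reformatted `min(...)` above is the load-balancing core. Each model's score is `total_tokens + penalty * 300 + usage_penalty * 1000`, and the lowest score wins. A tiny worked example with made-up numbers showing how the weights interact:

# (total_tokens, penalty, usage_penalty) per model, as kept in model_usage
candidate_models_usage = {
    "model-a": (50_000, 0, 1),   # many tokens, healthy, picked once recently -> 51_000
    "model-b": (1_000, 3, 0),    # few tokens, but three failure penalties    -> 1_900
    "model-c": (2_000, 0, 2),    # healthy, but just picked twice             -> 4_000
}

least_used_model_name = min(
    candidate_models_usage,
    key=lambda k: candidate_models_usage[k][0]
    + candidate_models_usage[k][1] * 300
    + candidate_models_usage[k][2] * 1000,
)
print(least_used_model_name)  # model-b
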
@@ -214,26 +231,32 @@ class _ModelSelector:
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
# 网络连接错误或请求被中断,通常是基础设施问题,应重罚
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}"
)
elif isinstance(e, RespNotOkException):
# 对于HTTP响应错误重点关注服务器端错误
if e.status_code >= 500:
# 5xx 错误表明服务器端出现问题,应重罚
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}"
)
else:
# 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚
logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}"
)
else:
# 其他未知异常,给予基础惩罚
logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")
self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)
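
Editor's note: the reflowed warnings above implement a simple severity tiering. Connection errors, aborted requests and 5xx responses add `CRITICAL_PENALTY_MULTIPLIER` (5) to the model's penalty counter; every other failure adds the base increment of 1. A compressed sketch of just that decision, with the exception classes stubbed out:

CRITICAL_PENALTY_MULTIPLIER = 5
DEFAULT_PENALTY_INCREMENT = 1

class NetworkConnectionError(Exception): ...
class ReqAbortException(Exception): ...
class RespNotOkException(Exception):
    def __init__(self, status_code: int):
        self.status_code = status_code

def penalty_for(e: Exception) -> int:
    if isinstance(e, (NetworkConnectionError, ReqAbortException)):
        return CRITICAL_PENALTY_MULTIPLIER   # infrastructure problem: punish hard
    if isinstance(e, RespNotOkException) and e.status_code >= 500:
        return CRITICAL_PENALTY_MULTIPLIER   # server-side error: punish hard
    return DEFAULT_PENALTY_INCREMENT         # 4xx / unknown: base penalty

print(penalty_for(NetworkConnectionError()))   # 5
print(penalty_for(RespNotOkException(503)))    # 5
print(penalty_for(RespNotOkException(404)))    # 1
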
class _PromptProcessor:
"""封装所有与提示词和响应内容的预处理和后处理逻辑。"""
def __init__(self):
"""
初始化提示处理器。
@@ -276,18 +299,18 @@ class _PromptProcessor:
"""
# 步骤1: 根据API提供商的配置应用内容混淆
processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
# 步骤2: 检查模型是否需要注入反截断指令
if getattr(model_info, "use_anti_truncation", False):
processed_prompt += self.anti_truncation_instruction
logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")
return processed_prompt
def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
"""
处理响应内容,提取思维链并检查截断。
Returns:
Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断)
"""
@@ -317,14 +340,14 @@ class _PromptProcessor:
# 检查当前API提供商是否启用了内容混淆功能
if not getattr(api_provider, "enable_content_obfuscation", False):
return text
# 获取混淆强度默认为1
intensity = getattr(api_provider, "obfuscation_intensity", 1)
logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")
# 将抗审查指令和原始文本拼接
processed_text = self.noise_instruction + "\n\n" + text
# 在拼接后的文本中注入随机噪音
return self._inject_random_noise(processed_text, intensity)
@@ -346,12 +369,12 @@ class _PromptProcessor:
# 定义不同强度级别的噪音参数:概率和长度范围
params = {
1: {"probability": 15, "length": (3, 6)}, # 低强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
}
# 根据传入的强度选择配置,如果强度无效则使用默认值
config = params.get(intensity, params[1])
words = text.split()
result = []
# 遍历每个单词
@@ -366,12 +389,12 @@ class _PromptProcessor:
# 生成噪音字符串
noise = "".join(random.choice(chars) for _ in range(noise_length))
result.append(noise)
# 将处理后的单词列表重新组合成字符串
return " ".join(result)
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
def _extract_reasoning(content: str) -> tuple[str, str]:
"""
从模型返回的完整内容中提取被<think>...</think>标签包裹的思考过程,
并返回清理后的内容和思考过程。
@@ -396,7 +419,7 @@ class _PromptProcessor:
else:
reasoning = ""
clean_content = content.strip()
return clean_content, reasoning
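
Editor's note: `_extract_reasoning` splits the model output into the visible answer and the `<think>...</think>` chain of thought; only the no-match branch survives in the hunk. A minimal sketch of the idea (the exact regex is an assumption, not taken from the repository):

import re

_THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def extract_reasoning(content: str) -> tuple[str, str]:
    match = _THINK_PATTERN.search(content)
    if match:
        reasoning = match.group(1).strip()
        clean_content = _THINK_PATTERN.sub("", content, count=1).strip()
    else:
        reasoning = ""
        clean_content = content.strip()
    return clean_content, reasoning

print(extract_reasoning("<think>user wants a haiku</think>Here is your haiku."))
# ('Here is your haiku.', 'user wants a haiku')
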
@@ -440,8 +463,8 @@ class _RequestExecutor:
RuntimeError: 如果达到最大重试次数。
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None
compressed_messages: list[Message] | None = None
while retry_remain > 0:
try:
# 优先使用压缩后的消息列表
@@ -451,11 +474,11 @@ class _RequestExecutor:
# 根据请求类型调用不同的客户端方法
if request_type == RequestType.RESPONSE:
assert current_messages is not None, "message_list cannot be None for response requests"
# 修复: 防止 'message_list' 在 kwargs 中重复传递
request_params = kwargs.copy()
request_params.pop("message_list", None)
return await client.get_response(
model_info=model_info, message_list=current_messages, **request_params
)
@@ -463,15 +486,19 @@ class _RequestExecutor:
return await client.get_embedding(model_info=model_info, **kwargs)
elif request_type == RequestType.AUDIO:
return await client.get_audio_transcriptions(model_info=model_info, **kwargs)
except Exception as e:
logger.debug(f"请求失败: {str(e)}")
logger.debug(f"请求失败: {e!s}")
# 记录失败并更新模型的惩罚值
self.model_selector.update_failure_penalty(model_info.name, e)
# 处理异常,决定是否重试以及等待多久
wait_interval, new_compressed_messages = self._handle_exception(
e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None)
e,
model_info,
api_provider,
retry_remain,
(kwargs.get("message_list"), compressed_messages is not None),
)
if new_compressed_messages:
compressed_messages = new_compressed_messages # 更新为压缩后的消息
@@ -482,16 +509,16 @@ class _RequestExecutor:
await asyncio.sleep(wait_interval) # 等待指定时间后重试
finally:
retry_remain -= 1
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
def _handle_exception(
self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> Tuple[int, Optional[List[Message]]]:
) -> tuple[int, list[Message] | None]:
"""
默认异常处理函数,决定是否重试。
Returns:
(等待间隔(-1表示不再重试, 新的消息列表(适用于压缩消息))
"""
@@ -506,12 +533,12 @@ class _RequestExecutor:
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}")
return -1, None
else:
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}")
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}")
return -1, None
def _handle_resp_not_ok(
self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> Tuple[int, Optional[List[Message]]]:
) -> tuple[int, list[Message] | None]:
"""
处理非200的HTTP响应异常。
@@ -534,7 +561,9 @@ class _RequestExecutor:
model_name = model_info.name
# 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试
if e.status_code in [400, 401, 402, 403, 404]:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。")
logger.warning(
f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。"
)
return -1, None
# 处理请求体过大的情况
elif e.status_code == 413:
@@ -555,7 +584,7 @@ class _RequestExecutor:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}")
return -1, None
def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]:
def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
"""
辅助函数,根据剩余次数决定是否进行下一次重试。
@@ -570,9 +599,11 @@ class _RequestExecutor:
"""
# 只有在剩余重试次数大于1时才进行下一次重试因为当前这次失败已经消耗掉一次
if remain_try > 1:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。")
logger.warning(
f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。"
)
return interval, None
# 如果已无剩余重试次数,则记录错误并返回-1表示放弃
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。")
return -1, None
@@ -585,7 +616,14 @@ class _RequestStrategy:
即使在单个模型或API端点失败的情况下也能正常工作。
"""
def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str):
def __init__(
self,
model_selector: _ModelSelector,
prompt_processor: _PromptProcessor,
executor: _RequestExecutor,
model_list: list[str],
task_name: str,
):
"""
初始化请求策略。
@@ -607,20 +645,22 @@ class _RequestStrategy:
request_type: RequestType,
raise_when_empty: bool = True,
**kwargs,
) -> Tuple[APIResponse, ModelInfo]:
) -> tuple[APIResponse, ModelInfo]:
"""
执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
"""
failed_models_in_this_request = set()
max_attempts = len(self.model_list)
last_exception: Optional[Exception] = None
last_exception: Exception | None = None
for attempt in range(max_attempts):
selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value))
selection_result = self.model_selector.select_best_available_model(
failed_models_in_this_request, str(request_type.value)
)
if selection_result is None:
logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
break
model_info, api_provider, client = selection_result
logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...")
@@ -637,32 +677,36 @@ class _RequestStrategy:
# 合并模型特定的额外参数
if model_info.extra_params:
request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})}
request_kwargs["extra_params"] = {
**model_info.extra_params,
**request_kwargs.get("extra_params", {}),
}
response = await self._try_model_request(
model_info, api_provider, client, request_type, **request_kwargs
)
response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs)
# 成功,立即返回
logger.debug(f"模型 '{model_info.name}' 成功生成了回复。")
self.model_selector.update_usage_penalty(model_info.name, increase=False)
return response, model_info
except Exception as e:
logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
failed_models_in_this_request.add(model_info.name)
last_exception = e
# 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率
logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
if raise_when_empty:
if last_exception:
raise RuntimeError("所有模型均未能生成响应。") from last_exception
raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")
# 如果不抛出异常,返回一个备用响应
fallback_model_info = model_config.get_model_info(self.model_list[0])
return APIResponse(content="所有模型都请求失败"), fallback_model_info
async def _try_model_request(
self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs
) -> APIResponse:
@@ -684,46 +728,49 @@ class _RequestStrategy:
RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。
"""
max_empty_retry = api_provider.max_retry
for i in range(max_empty_retry + 1):
response = await self.executor.execute_request(
api_provider, client, request_type, model_info, **kwargs
)
response = await self.executor.execute_request(api_provider, client, request_type, model_info, **kwargs)
if request_type != RequestType.RESPONSE:
return response # 对于非响应类型,直接返回
return response # 对于非响应类型,直接返回
# --- 响应内容处理和空回复/截断检查 ---
content = response.content or ""
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation)
processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
content, use_anti_truncation
)
# 更新响应对象
response.content = processed_content
response.reasoning_content = response.reasoning_content or reasoning
is_empty_reply = not response.tool_calls and not (response.content and response.content.strip())
if not is_empty_reply and not is_truncated:
return response # 成功获取有效响应
return response # 成功获取有效响应
if i < max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...")
logger.warning(
f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})..."
)
if api_provider.retry_interval > 0:
await asyncio.sleep(api_provider.retry_interval)
else:
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。")
raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里
raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里
# ==============================================================================
# Main Facade Class
# ==============================================================================
class LLMRequest:
"""
LLM请求协调器。
@@ -741,11 +788,9 @@ class LLMRequest:
"""
self.task_name = request_type
self.model_for_task = model_set
self.model_usage: Dict[str, Tuple[int, int, int]] = {
model: (0, 0, 0) for model in self.model_for_task.model_list
}
self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
"""模型使用量记录,(total_tokens, penalty, usage_penalty)"""
# 初始化辅助类
self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
self._prompt_processor = _PromptProcessor()
@@ -759,9 +804,9 @@ class LLMRequest:
prompt: str,
image_base64: str,
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
"""
为图像生成响应。
@@ -769,39 +814,47 @@ class LLMRequest:
prompt (str): 提示词
image_base64 (str): 图像的Base64编码字符串
image_format (str): 图像格式(如 'png', 'jpeg' 等)
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
start_time = time.time()
# 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
selection_result = self._model_selector.select_best_available_model(set(), "response")
if not selection_result:
raise RuntimeError("无法为图像响应选择可用模型。")
model_info, api_provider, client = selection_result
normalized_format = _normalize_image_format(image_format)
message = MessageBuilder().add_text_content(prompt).add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
).build()
message = (
MessageBuilder()
.add_text_content(prompt)
.add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
)
.build()
)
response = await self._executor.execute_request(
api_provider, client, RequestType.RESPONSE, model_info,
api_provider,
client,
RequestType.RESPONSE,
model_info,
message_list=[message],
temperature=temperature,
max_tokens=max_tokens,
)
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
reasoning = response.reasoning_content or reasoning
return content, (reasoning, model_info.name, response.tool_calls)
async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
async def generate_response_for_voice(self, voice_base64: str) -> str | None:
"""
为语音生成响应(语音转文字)。
使用故障转移策略来确保即使主模型失败也能获得结果。
@@ -812,19 +865,17 @@ class LLMRequest:
Returns:
Optional[str]: 语音转换后的文本内容如果所有模型都失败则返回None。
"""
response, _ = await self._strategy.execute_with_failover(
RequestType.AUDIO, audio_base64=voice_base64
)
response, _ = await self._strategy.execute_with_failover(RequestType.AUDIO, audio_base64=voice_base64)
return response.content or None
async def generate_response_async(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: float | None = None,
max_tokens: int | None = None,
tools: list[dict[str, Any]] | None = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
"""
异步生成响应,支持并发请求。
@@ -834,7 +885,7 @@ class LLMRequest:
max_tokens (int, optional): 最大token数
tools: 工具配置
raise_when_empty (bool): 是否在空回复时抛出异常
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
@@ -842,12 +893,16 @@ class LLMRequest:
if concurrency_count <= 1:
return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty)
try:
return await execute_concurrently(
self._execute_single_text_request,
concurrency_count,
prompt, temperature, max_tokens, tools, raise_when_empty=False
prompt,
temperature,
max_tokens,
tools,
raise_when_empty=False,
)
except Exception as e:
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
@@ -858,11 +913,11 @@ class LLMRequest:
async def _execute_single_text_request(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: float | None = None,
max_tokens: int | None = None,
tools: list[dict[str, Any]] | None = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
"""
执行单次文本生成请求的内部方法。
这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期,
@@ -885,7 +940,7 @@ class LLMRequest:
response, model_info = await self._strategy.execute_with_failover(
RequestType.RESPONSE,
raise_when_empty=raise_when_empty,
prompt=prompt, # 传递原始prompt由strategy处理
prompt=prompt, # 传递原始prompt由strategy处理
tool_options=tool_options,
temperature=self.model_for_task.temperature if temperature is None else temperature,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
@@ -900,30 +955,29 @@ class LLMRequest:
return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
"""
获取嵌入向量。
Args:
embedding_input (str): 获取嵌入的目标
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
start_time = time.time()
response, model_info = await self._strategy.execute_with_failover(
RequestType.EMBEDDING,
embedding_input=embedding_input
RequestType.EMBEDDING, embedding_input=embedding_input
)
await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
if not response.embedding:
raise RuntimeError("获取embedding失败")
return response.embedding, model_info.name
async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
"""
记录模型使用情况。
@@ -940,19 +994,21 @@ class LLMRequest:
# 步骤1: 更新内存中的token计数用于负载均衡
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)
# 步骤2: 创建一个后台任务,将用量数据异步写入数据库
asyncio.create_task(llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system", # 此处可根据业务需求修改
time_cost=time_cost,
request_type=self.task_name,
endpoint=endpoint,
))
asyncio.create_task(
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system", # 此处可根据业务需求修改
time_cost=time_cost,
request_type=self.task_name,
endpoint=endpoint,
)
)
@staticmethod
def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
"""
根据输入的字典列表构建并验证 `ToolOption` 对象列表。
@@ -970,14 +1026,14 @@ class LLMRequest:
# 如果没有提供工具,直接返回 None
if not tools:
return None
tool_options: List[ToolOption] = []
tool_options: list[ToolOption] = []
# 遍历每个工具定义
for tool in tools:
try:
# 使用建造者模式创建 ToolOption
builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", ""))
# 遍历工具的参数
for param in tool.get("parameters", []):
# 严格验证参数格式是否为包含5个元素的元组
@@ -994,6 +1050,6 @@ class LLMRequest:
except (KeyError, IndexError, TypeError, AssertionError) as e:
# 如果构建过程中出现任何错误,记录日志并跳过该工具
logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}")
# 如果列表非空则返回列表,否则返回 None
return tool_options or None