Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
@@ -1,6 +1,5 @@
from typing import Any

# 常见Error Code Mapping (以OpenAI API为例)
error_code_mapping = {
400: "参数不正确",
@@ -1,21 +1,24 @@
import asyncio
import orjson
import io
from typing import Callable, Any, Coroutine, Optional
import aiohttp
from collections.abc import Callable, Coroutine
from typing import Any

import aiohttp
import orjson

from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from src.config.api_ada_configs import APIProvider, ModelInfo

from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry

logger = get_logger("AioHTTP-Gemini客户端")
@@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser:
chunk_data = orjson.loads(chunk_text)

# 解析候选项
if "candidates" in chunk_data and chunk_data["candidates"]:
if chunk_data.get("candidates"):
candidate = chunk_data["candidates"][0]

# 解析内容
@@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser:
async def _default_stream_response_handler(
response: aiohttp.ClientResponse,
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""默认流式响应处理器"""
parser = AiohttpGeminiStreamParser()

@@ -290,13 +293,13 @@ async def _default_stream_response_handler(

def _default_normal_response_parser(
response_data: dict,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""默认普通响应解析器"""
api_response = APIResponse()

try:
# 解析候选项
if "candidates" in response_data and response_data["candidates"]:
if response_data.get("candidates"):
candidate = response_data["candidates"][0]

# 解析文本内容
@@ -418,13 +421,12 @@ class AiohttpGeminiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[aiohttp.ClientResponse, asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None,
stream_response_handler: Callable[
[aiohttp.ClientResponse, asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
]
| None = None,
async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -452,9 +454,7 @@ class AiohttpGeminiClient(BaseClient):
# 构建请求体
request_data = {
"contents": contents,
"generationConfig": _build_generation_config(
max_tokens, temperature, tb, response_format, extra_params
),
"generationConfig": _build_generation_config(max_tokens, temperature, tb, response_format, extra_params),
}

# 添加系统指令
@@ -1,12 +1,14 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from src.config.api_ada_configs import APIProvider, ModelInfo

from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption

@dataclass
@@ -75,9 +77,8 @@ class BaseClient(ABC):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
| None = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
@@ -1,17 +1,17 @@
import asyncio
import io
import orjson
import re
import base64
from collections.abc import Iterable
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
import io
import re
from collections.abc import Callable, Coroutine, Iterable
from typing import Any

import orjson
from json_repair import repair_json
from openai import (
AsyncOpenAI,
NOT_GIVEN,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncOpenAI,
AsyncStream,
)
from openai.types.chat import (
@@ -22,18 +22,19 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta

from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from src.config.api_ada_configs import APIProvider, ModelInfo

from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry

logger = get_logger("OpenAI客户端")
@@ -241,7 +242,7 @@ def _build_stream_api_resp(
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
@@ -315,7 +316,7 @@ pattern = re.compile(

def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
@@ -394,15 +395,13 @@ class OpenaiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]]
| None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -519,17 +518,17 @@ class OpenaiClient(BaseClient):
)
except APIConnectionError as e:
# 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}")
logger.error(f"OpenAI API连接错误(嵌入模型): {e!s}")
logger.error(f"错误类型: {type(e)}")
if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}")
logger.error(f"底层错误: {e.__cause__!s}")
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code) from e
except Exception as e:
# 添加通用异常处理和日志记录
logger.error(f"获取嵌入时发生未知错误: {str(e)}")
logger.error(f"获取嵌入时发生未知错误: {e!s}")
logger.error(f"错误类型: {type(e)}")
raise
@@ -1,6 +1,5 @@
from enum import Enum

# 设计这系列类的目的是为未来可能的扩展做准备

@@ -58,7 +57,7 @@ class MessageBuilder:
self,
image_format: str,
image_base64: str,
support_formats=None, # 默认支持格式
support_formats=None, # 默认支持格式
) -> "MessageBuilder":
"""
添加图片内容
@@ -1,8 +1,8 @@
from enum import Enum
from typing import Optional, Any
from typing import Any

from pydantic import BaseModel
from typing_extensions import TypedDict, Required
from typing_extensions import Required, TypedDict

class RespFormatType(Enum):
@@ -20,7 +20,7 @@ class JsonSchema(TypedDict, total=False):
of 64.
"""

description: Optional[str]
description: str | None
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
@@ -32,7 +32,7 @@ class JsonSchema(TypedDict, total=False):
to build JSON schemas [here](https://json-schema.org/).
"""

strict: Optional[bool]
strict: bool | None
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
@@ -100,7 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
# 如果当前Schema是列表,则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
sub_schema[i] = link_definitions_recursive(f"{path}/{i!s}", sub_schema[i], defs)
else:
# 否则为字典
if "$defs" in sub_schema:
@@ -140,8 +140,7 @@ def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典,移除title字段,并对所有dict/list子元素递归调用
if "$defs" in schema:
del schema["$defs"]
schema.pop("$defs", None)
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
@@ -1,14 +1,15 @@
import base64
import io

from PIL import Image
from datetime import datetime

from src.common.logger import get_logger
from PIL import Image

from src.common.database.sqlalchemy_models import LLMUsage, get_db_session
from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder

from .model_client.base_client import UsageRecord
from .payload_content.message import Message, MessageBuilder

logger = get_logger("消息压缩工具")

@@ -38,7 +39,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *

return image_data
except Exception as e:
logger.error(f"图片转换格式失败: {str(e)}")
logger.error(f"图片转换格式失败: {e!s}")
return image_data

def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
@@ -87,7 +88,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
return output_buffer.getvalue(), original_size, new_size

except Exception as e:
logger.error(f"图片缩放失败: {str(e)}")
logger.error(f"图片缩放失败: {e!s}")
import traceback

logger.error(traceback.format_exc())
@@ -188,7 +189,7 @@ class LLMUsageRecorder:
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
logger.error(f"记录token使用情况失败: {e!s}")

llm_usage_recorder = LLMUsageRecorder()
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
@@ -18,25 +17,27 @@
- **LLMRequest (主接口)**:
作为模块的统一入口(Facade),为上层业务逻辑提供了简洁的接口来发起文本、图像、语音等不同类型的LLM请求。
"""
import re
import asyncio
import time
import random
import string

import asyncio
import random
import re
import string
import time
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any

from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine, Generator

from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
from .utils import compress_messages, llm_usage_recorder
from src.config.config import model_config

from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry
from .payload_content.message import Message, MessageBuilder
from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder
from .utils import compress_messages, llm_usage_recorder

install(extra_lines=3)
@@ -46,6 +47,7 @@ logger = get_logger("model_utils")
# Standalone Utility Functions
# ==============================================================================

def _normalize_image_format(image_format: str) -> str:
"""
标准化图片格式名称,确保与各种API的兼容性
@@ -57,17 +59,26 @@ def _normalize_image_format(image_format: str) -> str:
str: 标准化后的图片格式
"""
format_mapping = {
"jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "jpeg": "jpeg",
"png": "png", "PNG": "png",
"webp": "webp", "WEBP": "webp",
"gif": "gif", "GIF": "gif",
"heic": "heic", "HEIC": "heic",
"heif": "heif", "HEIF": "heif",
"jpg": "jpeg",
"JPG": "jpeg",
"JPEG": "jpeg",
"jpeg": "jpeg",
"png": "png",
"PNG": "png",
"webp": "webp",
"WEBP": "webp",
"gif": "gif",
"GIF": "gif",
"heic": "heic",
"HEIC": "heic",
"heif": "heif",
"HEIF": "heif",
}
normalized = format_mapping.get(image_format, image_format.lower())
logger.debug(f"图片格式标准化: {image_format} -> {normalized}")
return normalized

async def execute_concurrently(
coro_callable: Callable[..., Coroutine[Any, Any, Any]],
concurrency_count: int,
@@ -103,29 +114,33 @@ async def execute_concurrently(
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"并发任务 {i + 1}/{concurrency_count} 失败: {res}")

first_exception = next((res for res in results if isinstance(res, Exception)), None)
if first_exception:
raise first_exception
raise RuntimeError(f"所有 {concurrency_count} 个并发请求都失败了,但没有具体的异常信息")

class RequestType(Enum):
"""请求类型枚举"""

RESPONSE = "response"
EMBEDDING = "embedding"
AUDIO = "audio"

# ==============================================================================
# Helper Classes for LLMRequest Refactoring
# ==============================================================================

class _ModelSelector:
"""负责模型选择、负载均衡和动态故障切换的策略。"""

CRITICAL_PENALTY_MULTIPLIER = 5 # 严重错误惩罚乘数
DEFAULT_PENALTY_INCREMENT = 1 # 默认惩罚增量

def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]):
def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
"""
初始化模型选择器。
@@ -139,7 +154,7 @@ class _ModelSelector:

def select_best_available_model(
self, failed_models_in_this_request: set, request_type: str
) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
"""
从可用模型中选择负载均衡评分最低的模型,并排除当前请求中已失败的模型。
@@ -168,16 +183,18 @@
# - `usage_penalty * 1000`: 短期使用惩罚项。每次被选中后会增加,完成后会减少。高权重确保在多个模型都健康的情况下,请求会均匀分布(轮询)。
least_used_model_name = min(
candidate_models_usage,
key=lambda k: candidate_models_usage[k][0] + candidate_models_usage[k][1] * 300 + candidate_models_usage[k][2] * 1000,
key=lambda k: candidate_models_usage[k][0]
+ candidate_models_usage[k][1] * 300
+ candidate_models_usage[k][2] * 1000,
)

model_info = model_config.get_model_info(least_used_model_name)
api_provider = model_config.get_provider(model_info.api_provider)
# 特殊处理:对于 embedding 任务,强制创建新的 aiohttp.ClientSession。
# 这是为了避免在某些高并发场景下,共享的ClientSession可能引发的事件循环相关问题。
force_new_client = request_type == "embedding"
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)

logger.debug(f"为当前请求选择了最佳可用模型: {model_info.name}")
# 增加所选模型的请求使用惩罚值,以实现动态负载均衡。
self.update_usage_penalty(model_info.name, increase=True)
@@ -214,26 +231,32 @@
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
# 网络连接错误或请求被中断,通常是基础设施问题,应重罚
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
logger.warning(f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生严重错误 ({type(e).__name__}),增加高额惩罚值: {penalty_increment}"
)
elif isinstance(e, RespNotOkException):
# 对于HTTP响应错误,重点关注服务器端错误
if e.status_code >= 500:
# 5xx 错误表明服务器端出现问题,应重罚
penalty_increment = self.CRITICAL_PENALTY_MULTIPLIER
logger.warning(f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生服务器错误 (状态码: {e.status_code}),增加高额惩罚值: {penalty_increment}"
)
else:
# 4xx 客户端错误通常不代表模型本身不可用,给予基础惩罚
logger.warning(f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}")
logger.warning(
f"模型 '{model_name}' 发生客户端响应错误 (状态码: {e.status_code}),增加基础惩罚值: {penalty_increment}"
)
else:
# 其他未知异常,给予基础惩罚
logger.warning(f"模型 '{model_name}' 发生未知异常: {type(e).__name__},增加基础惩罚值: {penalty_increment}")

self.model_usage[model_name] = (total_tokens, penalty + penalty_increment, usage_penalty)

class _PromptProcessor:
"""封装所有与提示词和响应内容的预处理和后处理逻辑。"""

def __init__(self):
"""
初始化提示处理器。
@@ -276,18 +299,18 @@
"""
# 步骤1: 根据API提供商的配置应用内容混淆
processed_prompt = self._apply_content_obfuscation(prompt, api_provider)

# 步骤2: 检查模型是否需要注入反截断指令
if getattr(model_info, "use_anti_truncation", False):
processed_prompt += self.anti_truncation_instruction
logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")

return processed_prompt

def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
"""
处理响应内容,提取思维链并检查截断。

Returns:
Tuple[str, str, bool]: (处理后的内容, 思维链内容, 是否被截断)
"""
@@ -317,14 +340,14 @@
# 检查当前API提供商是否启用了内容混淆功能
if not getattr(api_provider, "enable_content_obfuscation", False):
return text

# 获取混淆强度,默认为1
intensity = getattr(api_provider, "obfuscation_intensity", 1)
logger.info(f"为API提供商 '{api_provider.name}' 启用内容混淆,强度级别: {intensity}")

# 将抗审查指令和原始文本拼接
processed_text = self.noise_instruction + "\n\n" + text

# 在拼接后的文本中注入随机噪音
return self._inject_random_noise(processed_text, intensity)
@@ -346,12 +369,12 @@
# 定义不同强度级别的噪音参数:概率和长度范围
params = {
1: {"probability": 15, "length": (3, 6)}, # 低强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
2: {"probability": 25, "length": (5, 10)}, # 中强度
3: {"probability": 35, "length": (8, 15)}, # 高强度
}
# 根据传入的强度选择配置,如果强度无效则使用默认值
config = params.get(intensity, params[1])

words = text.split()
result = []
# 遍历每个单词
@@ -366,12 +389,12 @@
# 生成噪音字符串
noise = "".join(random.choice(chars) for _ in range(noise_length))
result.append(noise)

# 将处理后的单词列表重新组合成字符串
return " ".join(result)

@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
def _extract_reasoning(content: str) -> tuple[str, str]:
"""
从模型返回的完整内容中提取被<think>...</think>标签包裹的思考过程,
并返回清理后的内容和思考过程。
@@ -396,7 +419,7 @@
else:
reasoning = ""
clean_content = content.strip()

return clean_content, reasoning

@@ -440,8 +463,8 @@ class _RequestExecutor:
RuntimeError: 如果达到最大重试次数。
"""
retry_remain = api_provider.max_retry
compressed_messages: Optional[List[Message]] = None

compressed_messages: list[Message] | None = None

while retry_remain > 0:
try:
# 优先使用压缩后的消息列表
@@ -451,11 +474,11 @@
# 根据请求类型调用不同的客户端方法
if request_type == RequestType.RESPONSE:
assert current_messages is not None, "message_list cannot be None for response requests"

# 修复: 防止 'message_list' 在 kwargs 中重复传递
request_params = kwargs.copy()
request_params.pop("message_list", None)

return await client.get_response(
model_info=model_info, message_list=current_messages, **request_params
)
@@ -463,15 +486,19 @@
return await client.get_embedding(model_info=model_info, **kwargs)
elif request_type == RequestType.AUDIO:
return await client.get_audio_transcriptions(model_info=model_info, **kwargs)

except Exception as e:
logger.debug(f"请求失败: {str(e)}")
logger.debug(f"请求失败: {e!s}")
# 记录失败并更新模型的惩罚值
self.model_selector.update_failure_penalty(model_info.name, e)

# 处理异常,决定是否重试以及等待多久
wait_interval, new_compressed_messages = self._handle_exception(
e, model_info, api_provider, retry_remain, (kwargs.get("message_list"), compressed_messages is not None)
e,
model_info,
api_provider,
retry_remain,
(kwargs.get("message_list"), compressed_messages is not None),
)
if new_compressed_messages:
compressed_messages = new_compressed_messages # 更新为压缩后的消息
@@ -482,16 +509,16 @@
await asyncio.sleep(wait_interval) # 等待指定时间后重试
finally:
retry_remain -= 1

logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
raise RuntimeError("请求失败,已达到最大重试次数")

def _handle_exception(
self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> Tuple[int, Optional[List[Message]]]:
) -> tuple[int, list[Message] | None]:
"""
默认异常处理函数,决定是否重试。

Returns:
(等待间隔(-1表示不再重试), 新的消息列表(适用于压缩消息))
"""
@@ -506,12 +533,12 @@
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}")
return -1, None
else:
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}")
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}")
return -1, None

def _handle_resp_not_ok(
self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> Tuple[int, Optional[List[Message]]]:
) -> tuple[int, list[Message] | None]:
"""
处理非200的HTTP响应异常。
@@ -534,7 +561,9 @@
model_name = model_info.name
# 处理客户端错误 (400-404),这些错误通常是请求本身有问题,不应重试
if e.status_code in [400, 401, 402, 403, 404]:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。")
logger.warning(
f"任务-'{self.task_name}' 模型-'{model_name}': 客户端错误 {e.status_code} - {e.message},不再重试。"
)
return -1, None
# 处理请求体过大的情况
elif e.status_code == 413:
@@ -555,7 +584,7 @@
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': 未知响应错误 {e.status_code} - {e.message}")
return -1, None

def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]:
def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
"""
辅助函数,根据剩余次数决定是否进行下一次重试。
@@ -570,9 +599,11 @@
"""
# 只有在剩余重试次数大于1时才进行下一次重试(因为当前这次失败已经消耗掉一次)
if remain_try > 1:
logger.warning(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。")
logger.warning(
f"任务-'{self.task_name}' 模型-'{model_name}': {reason},将于{interval}秒后重试 ({remain_try - 1}次剩余)。"
)
return interval, None

# 如果已无剩余重试次数,则记录错误并返回-1表示放弃
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': {reason},已达最大重试次数,放弃。")
return -1, None
@@ -585,7 +616,14 @@
即使在单个模型或API端点失败的情况下也能正常工作。
"""

def __init__(self, model_selector: _ModelSelector, prompt_processor: _PromptProcessor, executor: _RequestExecutor, model_list: List[str], task_name: str):
def __init__(
self,
model_selector: _ModelSelector,
prompt_processor: _PromptProcessor,
executor: _RequestExecutor,
model_list: list[str],
task_name: str,
):
"""
初始化请求策略。
@@ -607,20 +645,22 @@
request_type: RequestType,
raise_when_empty: bool = True,
**kwargs,
) -> Tuple[APIResponse, ModelInfo]:
) -> tuple[APIResponse, ModelInfo]:
"""
执行请求,动态选择最佳可用模型,并在模型失败时进行故障转移。
"""
failed_models_in_this_request = set()
max_attempts = len(self.model_list)
last_exception: Optional[Exception] = None
last_exception: Exception | None = None

for attempt in range(max_attempts):
selection_result = self.model_selector.select_best_available_model(failed_models_in_this_request, str(request_type.value))
selection_result = self.model_selector.select_best_available_model(
failed_models_in_this_request, str(request_type.value)
)
if selection_result is None:
logger.error(f"尝试 {attempt + 1}/{max_attempts}: 没有可用的模型了。")
break

model_info, api_provider, client = selection_result
logger.debug(f"尝试 {attempt + 1}/{max_attempts}: 正在使用模型 '{model_info.name}'...")
@@ -637,32 +677,36 @@

# 合并模型特定的额外参数
if model_info.extra_params:
request_kwargs["extra_params"] = {**model_info.extra_params, **request_kwargs.get("extra_params", {})}
request_kwargs["extra_params"] = {
**model_info.extra_params,
**request_kwargs.get("extra_params", {}),
}

response = await self._try_model_request(
model_info, api_provider, client, request_type, **request_kwargs
)

response = await self._try_model_request(model_info, api_provider, client, request_type, **request_kwargs)

# 成功,立即返回
logger.debug(f"模型 '{model_info.name}' 成功生成了回复。")
self.model_selector.update_usage_penalty(model_info.name, increase=False)
return response, model_info

except Exception as e:
logger.error(f"模型 '{model_info.name}' 失败,异常: {e}。将其添加到当前请求的失败模型列表中。")
failed_models_in_this_request.add(model_info.name)
last_exception = e
# 使用惩罚值已在 select 时增加,失败后不减少,以降低其后续被选中的概率

logger.error(f"当前请求已尝试 {max_attempts} 个模型,所有模型均已失败。")
if raise_when_empty:
if last_exception:
raise RuntimeError("所有模型均未能生成响应。") from last_exception
raise RuntimeError("所有模型均未能生成响应,且无具体异常信息。")

# 如果不抛出异常,返回一个备用响应
fallback_model_info = model_config.get_model_info(self.model_list[0])
return APIResponse(content="所有模型都请求失败"), fallback_model_info

async def _try_model_request(
self, model_info: ModelInfo, api_provider: APIProvider, client: BaseClient, request_type: RequestType, **kwargs
) -> APIResponse:
@@ -684,46 +728,49 @@
RuntimeError: 如果在达到最大重试次数后仍然收到空回复或截断的响应。
"""
max_empty_retry = api_provider.max_retry

for i in range(max_empty_retry + 1):
response = await self.executor.execute_request(
api_provider, client, request_type, model_info, **kwargs
)
response = await self.executor.execute_request(api_provider, client, request_type, model_info, **kwargs)

if request_type != RequestType.RESPONSE:
return response # 对于非响应类型,直接返回
return response # 对于非响应类型,直接返回

# --- 响应内容处理和空回复/截断检查 ---
content = response.content or ""
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
processed_content, reasoning, is_truncated = self.prompt_processor.process_response(content, use_anti_truncation)

processed_content, reasoning, is_truncated = self.prompt_processor.process_response(
content, use_anti_truncation
)

# 更新响应对象
response.content = processed_content
response.reasoning_content = response.reasoning_content or reasoning

is_empty_reply = not response.tool_calls and not (response.content and response.content.strip())

if not is_empty_reply and not is_truncated:
return response # 成功获取有效响应
return response # 成功获取有效响应

if i < max_empty_retry:
reason = "空回复" if is_empty_reply else "截断"
logger.warning(f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})...")
logger.warning(
f"模型 '{model_info.name}' 检测到{reason},正在进行内部重试 ({i + 1}/{max_empty_retry})..."
)
if api_provider.retry_interval > 0:
await asyncio.sleep(api_provider.retry_interval)
else:
reason = "空回复" if is_empty_reply else "截断"
logger.error(f"模型 '{model_info.name}' 经过 {max_empty_retry} 次内部重试后仍然生成{reason}的回复。")
raise RuntimeError(f"模型 '{model_info.name}' 已达到空回复/截断的最大内部重试次数。")

raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里

raise RuntimeError("内部重试逻辑错误") # 理论上不应到达这里

# ==============================================================================
# Main Facade Class
# ==============================================================================

class LLMRequest:
"""
LLM请求协调器。
@@ -741,11 +788,9 @@ class LLMRequest:
"""
self.task_name = request_type
self.model_for_task = model_set
self.model_usage: Dict[str, Tuple[int, int, int]] = {
model: (0, 0, 0) for model in self.model_for_task.model_list
}
self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
"""模型使用量记录,(total_tokens, penalty, usage_penalty)"""

# 初始化辅助类
self._model_selector = _ModelSelector(self.model_for_task.model_list, self.model_usage)
self._prompt_processor = _PromptProcessor()
@@ -759,9 +804,9 @@
prompt: str,
image_base64: str,
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
temperature: float | None = None,
max_tokens: int | None = None,
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
"""
为图像生成响应。
@@ -769,39 +814,47 @@
prompt (str): 提示词
image_base64 (str): 图像的Base64编码字符串
image_format (str): 图像格式(如 'png', 'jpeg' 等)

Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
start_time = time.time()

# 图像请求目前不使用复杂的故障转移策略,直接选择模型并执行
selection_result = self._model_selector.select_best_available_model(set(), "response")
if not selection_result:
raise RuntimeError("无法为图像响应选择可用模型。")
model_info, api_provider, client = selection_result

normalized_format = _normalize_image_format(image_format)
message = MessageBuilder().add_text_content(prompt).add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
).build()
message = (
MessageBuilder()
.add_text_content(prompt)
.add_image_content(
image_base64=image_base64,
image_format=normalized_format,
support_formats=client.get_support_image_formats(),
)
.build()
)

response = await self._executor.execute_request(
api_provider, client, RequestType.RESPONSE, model_info,
api_provider,
client,
RequestType.RESPONSE,
model_info,
message_list=[message],
temperature=temperature,
max_tokens=max_tokens,
)

await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
reasoning = response.reasoning_content or reasoning

return content, (reasoning, model_info.name, response.tool_calls)

async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
async def generate_response_for_voice(self, voice_base64: str) -> str | None:
"""
为语音生成响应(语音转文字)。
使用故障转移策略来确保即使主模型失败也能获得结果。
@@ -812,19 +865,17 @@
Returns:
Optional[str]: 语音转换后的文本内容,如果所有模型都失败则返回None。
"""
response, _ = await self._strategy.execute_with_failover(
RequestType.AUDIO, audio_base64=voice_base64
)
response, _ = await self._strategy.execute_with_failover(RequestType.AUDIO, audio_base64=voice_base64)
return response.content or None

async def generate_response_async(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: float | None = None,
max_tokens: int | None = None,
tools: list[dict[str, Any]] | None = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
"""
异步生成响应,支持并发请求。
@@ -834,7 +885,7 @@
max_tokens (int, optional): 最大token数
tools: 工具配置
raise_when_empty (bool): 是否在空回复时抛出异常

Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
@@ -842,12 +893,16 @@

if concurrency_count <= 1:
return await self._execute_single_text_request(prompt, temperature, max_tokens, tools, raise_when_empty)

try:
return await execute_concurrently(
self._execute_single_text_request,
concurrency_count,
prompt, temperature, max_tokens, tools, raise_when_empty=False
prompt,
temperature,
max_tokens,
tools,
raise_when_empty=False,
)
except Exception as e:
logger.error(f"所有 {concurrency_count} 个并发请求都失败了: {e}")
@@ -858,11 +913,11 @@
async def _execute_single_text_request(
self,
prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: float | None = None,
max_tokens: int | None = None,
tools: list[dict[str, Any]] | None = None,
raise_when_empty: bool = True,
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
"""
执行单次文本生成请求的内部方法。
这是 `generate_response_async` 的核心实现,处理单个请求的完整生命周期,
@@ -885,7 +940,7 @@
response, model_info = await self._strategy.execute_with_failover(
RequestType.RESPONSE,
raise_when_empty=raise_when_empty,
prompt=prompt, # 传递原始prompt,由strategy处理
prompt=prompt, # 传递原始prompt,由strategy处理
tool_options=tool_options,
temperature=self.model_for_task.temperature if temperature is None else temperature,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
@@ -900,30 +955,29 @@

return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)

async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
"""
获取嵌入向量。

Args:
embedding_input (str): 获取嵌入的目标

Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
start_time = time.time()
response, model_info = await self._strategy.execute_with_failover(
RequestType.EMBEDDING,
embedding_input=embedding_input
RequestType.EMBEDDING, embedding_input=embedding_input
)

await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")

if not response.embedding:
raise RuntimeError("获取embedding失败")

return response.embedding, model_info.name

async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
"""
记录模型使用情况。
@@ -940,19 +994,21 @@
# 步骤1: 更新内存中的token计数,用于负载均衡
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens + usage.total_tokens, penalty, usage_penalty)

# 步骤2: 创建一个后台任务,将用量数据异步写入数据库
asyncio.create_task(llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system", # 此处可根据业务需求修改
time_cost=time_cost,
request_type=self.task_name,
endpoint=endpoint,
))
asyncio.create_task(
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system", # 此处可根据业务需求修改
time_cost=time_cost,
request_type=self.task_name,
endpoint=endpoint,
)
)

@staticmethod
def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
"""
根据输入的字典列表构建并验证 `ToolOption` 对象列表。
@@ -970,14 +1026,14 @@
# 如果没有提供工具,直接返回 None
if not tools:
return None

tool_options: List[ToolOption] = []

tool_options: list[ToolOption] = []
# 遍历每个工具定义
for tool in tools:
try:
# 使用建造者模式创建 ToolOption
builder = ToolOptionBuilder().set_name(tool["name"]).set_description(tool.get("description", ""))

# 遍历工具的参数
for param in tool.get("parameters", []):
# 严格验证参数格式是否为包含5个元素的元组
@@ -994,6 +1050,6 @@ class LLMRequest:
except (KeyError, IndexError, TypeError, AssertionError) as e:
# 如果构建过程中出现任何错误,记录日志并跳过该工具
logger.error(f"构建工具 '{tool.get('name', 'N/A')}' 失败: {e}")

# 如果列表非空则返回列表,否则返回 None
return tool_options or None