re-style: format code
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """
 @desc: This module encapsulates all core logic for interacting with large language models (LLMs).
 It is designed as a highly fault-tolerant and extensible system, containing the following main components:
@@ -19,24 +18,26 @@
 Acts as the module's unified entry point (Facade), providing upper-level business logic with a concise interface for initiating text, image, voice, and other types of LLM requests.
 """

-import re
 import asyncio
-import time
 import random
+import re
 import string
-
+import time
+from collections.abc import Callable, Coroutine
 from enum import Enum
+from typing import Any
+
 from rich.traceback import install
-from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine

 from src.common.logger import get_logger
-from src.config.config import model_config
 from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
-from .payload_content.message import MessageBuilder, Message
-from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder
-from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
-from .utils import compress_messages, llm_usage_recorder
 from src.config.config import model_config
+
 from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
+from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry
+from .payload_content.message import Message, MessageBuilder
+from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder
+from .utils import compress_messages, llm_usage_recorder

 install(extra_lines=3)
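Note: the typing changes throughout this commit follow PEP 585 (builtin generics) and PEP 604 (`X | None` unions), with `Callable`/`Coroutine` now imported from `collections.abc` instead of `typing`. A minimal before/after sketch with illustrative names, not code from this repo:

# Before (requires: from typing import Dict, List, Optional)
#   def pick(scores: Dict[str, int], names: List[str]) -> Optional[str]: ...

# After (builtin generics need Python 3.9+, unions 3.10+): no typing imports
def pick(scores: dict[str, int], names: list[str]) -> str | None:
    # Return the lowest-scoring name, or None when the list is empty.
    return min(names, key=lambda n: scores.get(n, 0), default=None)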
@@ -139,7 +140,7 @@ class _ModelSelector:
     CRITICAL_PENALTY_MULTIPLIER = 5  # Penalty multiplier for critical errors
     DEFAULT_PENALTY_INCREMENT = 1  # Default penalty increment

-    def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]):
+    def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
         """
         Initialize the model selector.

@@ -153,7 +154,7 @@ class _ModelSelector:

     def select_best_available_model(
         self, failed_models_in_this_request: set, request_type: str
-    ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
+    ) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
         """
         Select the model with the lowest load-balancing score among the available models, excluding models that have already failed in the current request.

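Note: callers of the new `tuple[...] | None` signature must guard before unpacking. A self-contained sketch of the pattern (names are illustrative, not the repo's):

def pick_model(candidates: list[str]) -> tuple[str, int] | None:
    # Return the first candidate and its index, or None when nothing is available.
    return (candidates[0], 0) if candidates else None

if (picked := pick_model(["gpt-x"])) is not None:
    name, index = picked  # safe: unpacking only happens on the non-None branch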
@@ -306,7 +307,7 @@ class _PromptProcessor:

         return processed_prompt

-    def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
+    def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
         """
         Process the response content: extract the chain of thought and check for truncation.

@@ -393,7 +394,7 @@ class _PromptProcessor:
         return " ".join(result)

     @staticmethod
-    def _extract_reasoning(content: str) -> Tuple[str, str]:
+    def _extract_reasoning(content: str) -> tuple[str, str]:
         """
         Extract the reasoning wrapped in <think>...</think> tags from the model's full output,
         and return the cleaned content together with the reasoning.
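Note: a minimal sketch of the `<think>`-tag extraction described by the docstring above, assuming at most one reasoning block per response; the repo's actual implementation may differ:

import re

def extract_reasoning(content: str) -> tuple[str, str]:
    # Split the response into (cleaned_content, reasoning); reasoning may be empty.
    match = re.search(r"<think>(.*?)</think>", content, flags=re.DOTALL)
    if match is None:
        return content, ""
    cleaned = (content[: match.start()] + content[match.end():]).strip()
    return cleaned, match.group(1).strip()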
@@ -462,7 +463,7 @@ class _RequestExecutor:
             RuntimeError: If the maximum number of retries is reached.
         """
         retry_remain = api_provider.max_retry
-        compressed_messages: Optional[List[Message]] = None
+        compressed_messages: list[Message] | None = None

         while retry_remain > 0:
             try:
@@ -487,7 +488,7 @@ class _RequestExecutor:
                     return await client.get_audio_transcriptions(model_info=model_info, **kwargs)

             except Exception as e:
-                logger.debug(f"Request failed: {str(e)}")
+                logger.debug(f"Request failed: {e!s}")
                 # Record the failure and update the model's penalty
                 self.model_selector.update_failure_penalty(model_info.name, e)

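Note: `{e!s}` is the explicit str() conversion inside f-strings (the spelling preferred by ruff's RUF010 rule); both forms produce identical output:

e = ValueError("bad payload")
assert f"Request failed: {str(e)}" == f"Request failed: {e!s}" == "Request failed: bad payload"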
@@ -514,7 +515,7 @@ class _RequestExecutor:

     def _handle_exception(
         self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
-    ) -> Tuple[int, Optional[List[Message]]]:
+    ) -> tuple[int, list[Message] | None]:
         """
         Default exception handler; decides whether to retry.

@@ -532,12 +533,12 @@ class _RequestExecutor:
             logger.error(f"Task-'{self.task_name}' Model-'{model_name}': response parse error - {e.message}")
             return -1, None
         else:
-            logger.error(f"Task-'{self.task_name}' Model-'{model_name}': unknown exception - {str(e)}")
+            logger.error(f"Task-'{self.task_name}' Model-'{model_name}': unknown exception - {e!s}")
             return -1, None

     def _handle_resp_not_ok(
         self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
-    ) -> Tuple[int, Optional[List[Message]]]:
+    ) -> tuple[int, list[Message] | None]:
         """
         Handle non-200 HTTP response exceptions.

@@ -583,7 +584,7 @@ class _RequestExecutor:
             logger.warning(f"Task-'{self.task_name}' Model-'{model_name}': unknown response error {e.status_code} - {e.message}")
             return -1, None

-    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]:
+    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
         """
         Helper that decides, based on the remaining retry count, whether to attempt the next retry.

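Note: a self-contained sketch of the retry-decision convention suggested by `_check_retry`'s signature and the surrounding `return -1, None` branches, where the first tuple element is a wait interval, or -1 to give up (helper names here are hypothetical):

import logging

logger = logging.getLogger("llm")

def check_retry(remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
    # Retry while attempts remain; -1 signals the caller to stop retrying.
    if remain_try > 0:
        logger.warning(f"Model-'{model_name}': {reason}, retrying in {interval}s ({remain_try} left)")
        return interval, None
    logger.error(f"Model-'{model_name}': {reason}, retries exhausted")
    return -1, None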
@@ -620,7 +621,7 @@ class _RequestStrategy:
         model_selector: _ModelSelector,
         prompt_processor: _PromptProcessor,
         executor: _RequestExecutor,
-        model_list: List[str],
+        model_list: list[str],
         task_name: str,
     ):
         """
@@ -644,13 +645,13 @@ class _RequestStrategy:
         request_type: RequestType,
         raise_when_empty: bool = True,
         **kwargs,
-    ) -> Tuple[APIResponse, ModelInfo]:
+    ) -> tuple[APIResponse, ModelInfo]:
         """
         Execute the request, dynamically selecting the best available model and failing over when a model fails.
         """
         failed_models_in_this_request = set()
         max_attempts = len(self.model_list)
-        last_exception: Optional[Exception] = None
+        last_exception: Exception | None = None

         for attempt in range(max_attempts):
             selection_result = self.model_selector.select_best_available_model(
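Note: a condensed, dependency-injected sketch of the failover loop above (helper signatures are hypothetical; the repo wires these through `_ModelSelector` and `_RequestExecutor`):

async def request_with_failover(model_list, select_model, execute):
    # Try each model at most once, skipping models that already failed this request.
    failed: set[str] = set()
    last_exc: Exception | None = None
    for _ in range(len(model_list)):
        picked = select_model(failed)  # -> (model_name, client) or None
        if picked is None:
            break  # every model has failed or is unavailable
        model_name, client = picked
        try:
            return await execute(client, model_name)
        except Exception as e:
            failed.add(model_name)
            last_exc = e  # remember the most recent failure for the final raise
    raise last_exc or RuntimeError("no available model")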
@@ -787,9 +788,7 @@ class LLMRequest:
         """
         self.task_name = request_type
         self.model_for_task = model_set
-        self.model_usage: Dict[str, Tuple[int, int, int]] = {
-            model: (0, 0, 0) for model in self.model_for_task.model_list
-        }
+        self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
         """Model usage record: (total_tokens, penalty, usage_penalty)"""

         # Initialize helper classes
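Note: `dict.fromkeys` binds every key to the *same* value object. That is safe here because the default is an immutable tuple; with a mutable default it would alias:

usage = dict.fromkeys(["model-a", "model-b"], (0, 0, 0))
usage["model-a"] = (120, 1, 0)       # rebinds one key only
assert usage["model-b"] == (0, 0, 0)

shared = dict.fromkeys(["a", "b"], [])  # counter-example: one shared list
shared["a"].append(1)
assert shared["b"] == [1]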
@@ -805,9 +804,9 @@ class LLMRequest:
         prompt: str,
         image_base64: str,
         image_format: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+    ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
         """
         Generate a response for an image.

@@ -855,7 +854,7 @@ class LLMRequest:

         return content, (reasoning, model_info.name, response.tool_calls)

-    async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
+    async def generate_response_for_voice(self, voice_base64: str) -> str | None:
         """
         Generate a response for voice input (speech-to-text).
         Uses the failover strategy to ensure a result even if the primary model fails.
@@ -872,11 +871,11 @@ class LLMRequest:
     async def generate_response_async(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: list[dict[str, Any]] | None = None,
         raise_when_empty: bool = True,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+    ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
         """
         Asynchronously generate a response; supports concurrent requests.

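Note: the return annotation above reads as (content, (reasoning, model_name, tool_calls)). A self-contained illustration of the unpacking shape, with a stand-in for the repo's ToolCall type:

ToolCallLike = dict[str, str]  # illustrative stand-in for ToolCall

def fake_response() -> tuple[str, tuple[str, str, list[ToolCallLike] | None]]:
    return "Hello!", ("Greeting the user.", "gpt-x", None)

content, (reasoning, model_name, tool_calls) = fake_response()
assert tool_calls is None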
@@ -914,11 +913,11 @@ class LLMRequest:
     async def _execute_single_text_request(
         self,
         prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: list[dict[str, Any]] | None = None,
         raise_when_empty: bool = True,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+    ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
         """
         Internal method that executes a single text-generation request.
         This is the core implementation of `generate_response_async`, handling the full lifecycle of a single request,
@@ -956,7 +955,7 @@ class LLMRequest:

         return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)

-    async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
+    async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
         """
         Get an embedding vector.

@@ -978,7 +977,7 @@ class LLMRequest:

         return response.embedding, model_info.name

-    async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
+    async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
         """
         Record model usage.

@@ -1009,7 +1008,7 @@ class LLMRequest:
         )

     @staticmethod
-    def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
+    def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
         """
         Build and validate a list of `ToolOption` objects from the input list of dicts.

@@ -1028,7 +1027,7 @@ class LLMRequest:
         if not tools:
             return None

-        tool_options: List[ToolOption] = []
+        tool_options: list[ToolOption] = []
         # Iterate over each tool definition
         for tool in tools:
             try:
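Note: a minimal sketch of the build-and-validate loop that this hunk opens, using plain dicts instead of `ToolOptionBuilder` (whose exact API is not shown in this diff):

def build_tool_options(tools: list[dict[str, object]] | None) -> list[dict[str, object]] | None:
    if not tools:
        return None  # no tools requested
    tool_options: list[dict[str, object]] = []
    for tool in tools:
        try:
            if "name" not in tool:
                raise ValueError("tool definition missing 'name'")
            tool_options.append(tool)
        except ValueError as e:
            print(f"skipping malformed tool: {e!s}")  # skip rather than fail the request
    return tool_options or None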