re-style: format code

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -1,6 +1,5 @@
from typing import Any
# Common error code mapping (using the OpenAI API as an example)
error_code_mapping = {
    400: "Invalid request parameters",

View File

@@ -1,21 +1,24 @@
import asyncio
-import orjson
import io
-from typing import Callable, Any, Coroutine, Optional
-import aiohttp
+from collections.abc import Callable, Coroutine
+from typing import Any
+import aiohttp
+import orjson
-from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
-from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
+from src.config.api_ada_configs import APIProvider, ModelInfo
from ..exceptions import (
-    RespParseException,
    NetworkConnectionError,
-    RespNotOkException,
    ReqAbortException,
+    RespNotOkException,
+    RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
-from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
+from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
+from .base_client import APIResponse, BaseClient, UsageRecord, client_registry

logger = get_logger("AioHTTP-Gemini Client")
@@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser:
        chunk_data = orjson.loads(chunk_text)
        # Parse candidates
-        if "candidates" in chunk_data and chunk_data["candidates"]:
+        if chunk_data.get("candidates"):
            candidate = chunk_data["candidates"][0]
            # Parse content
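The rewrite above replaces a membership test plus a truthiness check with a single dict.get call; get returns None for a missing key, so one truthiness check covers both an absent and an empty candidates list. A minimal sketch with hypothetical data:

chunk_data = {"candidates": []}  # hypothetical parsed chunk

# Before: two checks.
if "candidates" in chunk_data and chunk_data["candidates"]:
    print("has candidates")

# After: one check, same behavior for missing and empty values.
if chunk_data.get("candidates"):
    print("has candidates")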
@@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser:
async def _default_stream_response_handler(
    response: aiohttp.ClientResponse,
    interrupt_flag: asyncio.Event | None,
-) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+) -> tuple[APIResponse, tuple[int, int, int] | None]:
    """Default streaming response handler"""
    parser = AiohttpGeminiStreamParser()
@@ -290,13 +293,13 @@ async def _default_stream_response_handler(
def _default_normal_response_parser(
    response_data: dict,
-) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+) -> tuple[APIResponse, tuple[int, int, int] | None]:
    """Default non-streaming response parser"""
    api_response = APIResponse()
    try:
        # Parse candidates
-        if "candidates" in response_data and response_data["candidates"]:
+        if response_data.get("candidates"):
            candidate = response_data["candidates"][0]
            # Parse text content
@@ -419,13 +422,12 @@ class AiohttpGeminiClient(BaseClient):
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
-        stream_response_handler: Optional[
-            Callable[
-                [aiohttp.ClientResponse, asyncio.Event | None],
-                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
-            ]
-        ] = None,
-        async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None,
+        stream_response_handler: Callable[
+            [aiohttp.ClientResponse, asyncio.Event | None],
+            Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
+        ]
+        | None = None,
+        async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
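The annotation changes in this file are the PEP 604 spelling (X | None in place of typing.Optional[X], available since Python 3.10) plus Callable/Coroutine moving from typing to collections.abc. A sketch of the pattern; the StreamHandler alias is hypothetical and only shows how the repeated callable type could be shortened further:

import asyncio
from collections.abc import Callable, Coroutine
from typing import Any

import aiohttp

# Hypothetical alias; the diff keeps the inline form.
StreamHandler = Callable[
    [aiohttp.ClientResponse, asyncio.Event | None],
    Coroutine[Any, Any, tuple[Any, tuple[int, int, int] | None]],
]

async def request_stub(stream_response_handler: StreamHandler | None = None) -> None:
    ...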

View File

@@ -1,12 +1,14 @@
import asyncio
-from dataclasses import dataclass
from abc import ABC, abstractmethod
-from typing import Callable, Any, Optional
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
-from src.config.api_ada_configs import ModelInfo, APIProvider
+from src.config.api_ada_configs import APIProvider, ModelInfo
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
-from ..payload_content.tool_option import ToolOption, ToolCall
+from ..payload_content.tool_option import ToolCall, ToolOption
@dataclass
@@ -75,9 +77,8 @@ class BaseClient(ABC):
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
-        stream_response_handler: Optional[
-            Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
-        ] = None,
+        stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
+        | None = None,
        async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,

View File

@@ -1,17 +1,17 @@
import asyncio
-import io
-import orjson
-import re
import base64
-from collections.abc import Iterable
-from typing import Callable, Any, Coroutine, Optional
-from json_repair import repair_json
+import io
+import re
+from collections.abc import Callable, Coroutine, Iterable
+from typing import Any
+import orjson
+from json_repair import repair_json
from openai import (
-    AsyncOpenAI,
-    NOT_GIVEN,
    APIConnectionError,
    APIStatusError,
+    NOT_GIVEN,
+    AsyncOpenAI,
    AsyncStream,
)
from openai.types.chat import (
@@ -22,18 +22,19 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
-from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
-from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
+from src.config.api_ada_configs import APIProvider, ModelInfo
from ..exceptions import (
-    RespParseException,
    NetworkConnectionError,
-    RespNotOkException,
    ReqAbortException,
+    RespNotOkException,
+    RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
-from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
+from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
+from .base_client import APIResponse, BaseClient, UsageRecord, client_registry

logger = get_logger("OpenAI Client")
@@ -241,7 +242,7 @@ def _build_stream_api_resp(
async def _default_stream_response_handler(
    resp_stream: AsyncStream[ChatCompletionChunk],
    interrupt_flag: asyncio.Event | None,
-) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+) -> tuple[APIResponse, tuple[int, int, int] | None]:
    """
    Streaming response handler - processes streaming responses from the OpenAI API
    :param resp_stream: the streaming response object
@@ -315,7 +316,7 @@ pattern = re.compile(
def _default_normal_response_parser(
    resp: ChatCompletion,
-) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+) -> tuple[APIResponse, tuple[int, int, int] | None]:
    """
    Parse a chat-completion response - converts an OpenAI API response into an APIResponse object
    :param resp: the response object
@@ -391,15 +392,13 @@ class OpenaiClient(BaseClient):
        max_tokens: int = 1024,
        temperature: float = 0.7,
        response_format: RespFormat | None = None,
-        stream_response_handler: Optional[
-            Callable[
-                [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
-                Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
-            ]
-        ] = None,
-        async_response_parser: Optional[
-            Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
-        ] = None,
+        stream_response_handler: Callable[
+            [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
+            Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
+        ]
+        | None = None,
+        async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]]
+        | None = None,
        interrupt_flag: asyncio.Event | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> APIResponse:
@@ -514,17 +513,17 @@ class OpenaiClient(BaseClient):
            )
        except APIConnectionError as e:
            # Add detailed error information for debugging
-            logger.error(f"OpenAI API connection error (embedding model): {str(e)}")
+            logger.error(f"OpenAI API connection error (embedding model): {e!s}")
            logger.error(f"Error type: {type(e)}")
            if hasattr(e, "__cause__") and e.__cause__:
-                logger.error(f"Underlying error: {str(e.__cause__)}")
+                logger.error(f"Underlying error: {e.__cause__!s}")
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Re-wrap the APIError as RespNotOkException
            raise RespNotOkException(e.status_code) from e
        except Exception as e:
            # Generic exception handling and logging
-            logger.error(f"Unknown error while fetching embeddings: {str(e)}")
+            logger.error(f"Unknown error while fetching embeddings: {e!s}")
            logger.error(f"Error type: {type(e)}")
            raise
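The {str(e)} to {e!s} changes throughout this commit are behavior-preserving: !s is the f-string conversion flag that applies str() to the value before formatting, so the explicit call is redundant. For example:

try:
    raise ValueError("connection refused")
except ValueError as e:
    assert f"error: {str(e)}" == f"error: {e!s}"  # identical output
    print(f"error: {e!s}")  # error: connection refused
    print(f"error: {e!r}")  # !r uses repr(): error: ValueError('connection refused')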

View File

@@ -1,6 +1,5 @@
from enum import Enum
# These classes are designed to leave room for possible future extensions

View File

@@ -1,8 +1,8 @@
from enum import Enum
-from typing import Optional, Any
+from typing import Any
from pydantic import BaseModel
-from typing_extensions import TypedDict, Required
+from typing_extensions import Required, TypedDict
class RespFormatType(Enum):
@@ -20,7 +20,7 @@ class JsonSchema(TypedDict, total=False):
of 64.
"""
-    description: Optional[str]
+    description: str | None
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
@@ -32,7 +32,7 @@ class JsonSchema(TypedDict, total=False):
to build JSON schemas [here](https://json-schema.org/).
"""
-    strict: Optional[bool]
+    strict: bool | None
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
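Worth noting for these two fields: in a TypedDict declared with total=False, keys may already be omitted, and str | None additionally permits an explicit None value. A minimal sketch of the distinction (field subset only; name as a Required key is an assumption about the surrounding class):

from typing_extensions import Required, TypedDict

class JsonSchema(TypedDict, total=False):
    name: Required[str]      # must be present
    description: str | None  # may be omitted, or set to None explicitly
    strict: bool | None

ok: JsonSchema = {"name": "answer"}                       # omitted fields are fine
also_ok: JsonSchema = {"name": "answer", "strict": None}  # explicit None is fine too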
@@ -100,7 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
            # If the current schema node is a list, iterate over its elements
            for i in range(len(sub_schema)):
                if isinstance(sub_schema[i], dict):
-                    sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
+                    sub_schema[i] = link_definitions_recursive(f"{path}/{i!s}", sub_schema[i], defs)
        else:
            # Otherwise the node is a dict
            if "$defs" in sub_schema:
@@ -140,8 +140,7 @@ def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
                schema[idx] = _remove_title(item)
    elif isinstance(schema, dict):
        # For a dict: remove the title field and recurse into all dict/list children
-        if "$defs" in schema:
-            del schema["$defs"]
+        schema.pop("$defs", None)
        for key, value in schema.items():
            if isinstance(value, (dict, list)):
                schema[key] = _remove_title(value)
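dict.pop with a default collapses the guarded del into one call and never raises KeyError, which is why the membership test could be dropped. A sketch:

schema = {"type": "object", "$defs": {"Inner": {}}}

# Before: guard needed, since a bare `del` raises KeyError when the key is absent.
if "$defs" in schema:
    del schema["$defs"]

# After: the None default makes the call safe whether or not the key exists.
schema.pop("$defs", None)
schema.pop("$defs", None)  # second call is a harmless no-op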

View File

@@ -1,14 +1,15 @@
import base64
import io
-from PIL import Image
from datetime import datetime
-from src.common.logger import get_logger
+from PIL import Image
from src.common.database.sqlalchemy_models import LLMUsage, get_db_session
+from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo
-from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
+from .payload_content.message import Message, MessageBuilder

logger = get_logger("Message Compression Utility")
@@ -38,7 +39,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
            return image_data
        except Exception as e:
-            logger.error(f"Image format conversion failed: {str(e)}")
+            logger.error(f"Image format conversion failed: {e!s}")
            return image_data

    def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
@@ -87,7 +88,7 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
            return output_buffer.getvalue(), original_size, new_size
        except Exception as e:
-            logger.error(f"Image rescaling failed: {str(e)}")
+            logger.error(f"Image rescaling failed: {e!s}")
            import traceback

            logger.error(traceback.format_exc())
@@ -188,7 +189,7 @@ class LLMUsageRecorder:
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
logger.error(f"记录token使用情况失败: {e!s}")
llm_usage_recorder = LLMUsageRecorder()

View File

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""
@desc: This module encapsulates all of the core logic for interacting with large language models (LLMs).
It is designed as a highly fault-tolerant and extensible system, made up of the following main components:
@@ -19,24 +18,26 @@
As the module's unified entry point (Facade), it gives upper-level business logic a concise interface for issuing text, image, voice, and other types of LLM requests.
"""

-import re
import asyncio
-import time
import random
+import re
import string
+import time
+from collections.abc import Callable, Coroutine
from enum import Enum
+from typing import Any
from rich.traceback import install
-from typing import Tuple, List, Dict, Optional, Callable, Any, Coroutine
from src.common.logger import get_logger
-from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
-from .payload_content.message import MessageBuilder, Message
-from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder
-from .model_client.base_client import BaseClient, APIResponse, client_registry, UsageRecord
-from .utils import compress_messages, llm_usage_recorder
+from src.config.config import model_config
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
+from .model_client.base_client import APIResponse, BaseClient, UsageRecord, client_registry
+from .payload_content.message import Message, MessageBuilder
+from .payload_content.tool_option import ToolCall, ToolOption, ToolOptionBuilder
+from .utils import compress_messages, llm_usage_recorder

install(extra_lines=3)
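The reshuffled imports above follow the conventional isort-style layout (the exact formatter is not named in the commit): standard library first, then third-party packages, then first-party and relative imports, each group alphabetized. Schematically:

# 1. Standard library, alphabetized; `import x` and `from x import ...` sort together
import asyncio
import re
from collections.abc import Callable, Coroutine
from typing import Any

# 2. Third-party packages, e.g.:
# from rich.traceback import install

# 3. First-party and relative imports, e.g.:
# from src.common.logger import get_logger
# from .utils import compress_messages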
@@ -139,7 +140,7 @@ class _ModelSelector:
    CRITICAL_PENALTY_MULTIPLIER = 5  # penalty multiplier for critical errors
    DEFAULT_PENALTY_INCREMENT = 1  # default penalty increment

-    def __init__(self, model_list: List[str], model_usage: Dict[str, Tuple[int, int, int]]):
+    def __init__(self, model_list: list[str], model_usage: dict[str, tuple[int, int, int]]):
        """
        Initialize the model selector.
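From here on, the hunks swap typing.List/Dict/Tuple for the builtin generics standardized by PEP 585 (usable in annotations since Python 3.9) and Optional[X] for X | None. The runtime behavior is identical; only the spelling changes. A small sketch mirroring the updated signature (the function body is hypothetical):

def pick_first(model_list: list[str],
               model_usage: dict[str, tuple[int, int, int]]) -> str | None:
    # Equivalent older spelling:
    #   def pick_first(model_list: List[str],
    #                  model_usage: Dict[str, Tuple[int, int, int]]) -> Optional[str]:
    return model_list[0] if model_list else None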
@@ -153,7 +154,7 @@ class _ModelSelector:
    def select_best_available_model(
        self, failed_models_in_this_request: set, request_type: str
-    ) -> Optional[Tuple[ModelInfo, APIProvider, BaseClient]]:
+    ) -> tuple[ModelInfo, APIProvider, BaseClient] | None:
        """
        Select the available model with the lowest load-balancing score, excluding models that have already failed during the current request.
@@ -306,7 +307,7 @@ class _PromptProcessor:
        return processed_prompt

-    def process_response(self, content: str, use_anti_truncation: bool) -> Tuple[str, str, bool]:
+    def process_response(self, content: str, use_anti_truncation: bool) -> tuple[str, str, bool]:
        """
        Process the response content: extract the chain of thought and check for truncation.
@@ -393,7 +394,7 @@ class _PromptProcessor:
return " ".join(result)
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
def _extract_reasoning(content: str) -> tuple[str, str]:
"""
从模型返回的完整内容中提取被<think>...</think>标签包裹的思考过程,
并返回清理后的内容和思考过程。
@@ -462,7 +463,7 @@ class _RequestExecutor:
            RuntimeError: if the maximum number of retries is reached.
        """
        retry_remain = api_provider.max_retry
-        compressed_messages: Optional[List[Message]] = None
+        compressed_messages: list[Message] | None = None
while retry_remain > 0:
try:
@@ -487,7 +488,7 @@ class _RequestExecutor:
                return await client.get_audio_transcriptions(model_info=model_info, **kwargs)
            except Exception as e:
-                logger.debug(f"Request failed: {str(e)}")
+                logger.debug(f"Request failed: {e!s}")
                # Record the failure and update the model's penalty value
                self.model_selector.update_failure_penalty(model_info.name, e)
@@ -514,7 +515,7 @@ class _RequestExecutor:
    def _handle_exception(
        self, e: Exception, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
-    ) -> Tuple[int, Optional[List[Message]]]:
+    ) -> tuple[int, list[Message] | None]:
        """
        Default exception handler; decides whether to retry.
@@ -532,12 +533,12 @@ class _RequestExecutor:
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 响应解析错误 - {e.message}")
return -1, None
else:
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {str(e)}")
logger.error(f"任务-'{self.task_name}' 模型-'{model_name}': 未知异常 - {e!s}")
return -1, None
def _handle_resp_not_ok(
self, e: RespNotOkException, model_info: ModelInfo, api_provider: APIProvider, remain_try: int, messages_info
) -> Tuple[int, Optional[List[Message]]]:
) -> tuple[int, list[Message] | None]:
"""
处理非200的HTTP响应异常。
@@ -583,7 +584,7 @@ class _RequestExecutor:
        logger.warning(f"Task '{self.task_name}' model '{model_name}': unknown response error {e.status_code} - {e.message}")
        return -1, None

-    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> Tuple[int, None]:
+    def _check_retry(self, remain_try: int, interval: int, reason: str, model_name: str) -> tuple[int, None]:
        """
        Helper that decides, based on the remaining attempt count, whether to retry.
@@ -620,7 +621,7 @@ class _RequestStrategy:
model_selector: _ModelSelector,
prompt_processor: _PromptProcessor,
executor: _RequestExecutor,
-        model_list: List[str],
+        model_list: list[str],
        task_name: str,
task_name: str,
):
"""
@@ -644,13 +645,13 @@ class _RequestStrategy:
        request_type: RequestType,
        raise_when_empty: bool = True,
        **kwargs,
-    ) -> Tuple[APIResponse, ModelInfo]:
+    ) -> tuple[APIResponse, ModelInfo]:
        """
        Execute the request, dynamically selecting the best available model and failing over when a model fails.
        """
        failed_models_in_this_request = set()
        max_attempts = len(self.model_list)
-        last_exception: Optional[Exception] = None
+        last_exception: Exception | None = None
for attempt in range(max_attempts):
selection_result = self.model_selector.select_best_available_model(
@@ -787,9 +788,7 @@ class LLMRequest:
"""
self.task_name = request_type
self.model_for_task = model_set
self.model_usage: Dict[str, Tuple[int, int, int]] = {
model: (0, 0, 0) for model in self.model_for_task.model_list
}
self.model_usage: dict[str, tuple[int, int, int]] = dict.fromkeys(self.model_for_task.model_list, (0, 0, 0))
"""模型使用量记录,(total_tokens, penalty, usage_penalty)"""
# 初始化辅助类
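dict.fromkeys produces the same mapping as the comprehension it replaces, in one call. One caveat: fromkeys stores the same value object under every key, which is safe here only because (0, 0, 0) is an immutable tuple; with a mutable default such as a list, the comprehension would still be the correct choice. A sketch with hypothetical model names:

model_list = ["model-a", "model-b"]  # hypothetical

assert {m: (0, 0, 0) for m in model_list} == dict.fromkeys(model_list, (0, 0, 0))

# Caveat: a mutable value would be shared across all keys.
shared = dict.fromkeys(model_list, [])
shared["model-a"].append(1)
assert shared["model-b"] == [1]  # both keys reference the same list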
@@ -805,9 +804,9 @@ class LLMRequest:
        prompt: str,
        image_base64: str,
        image_format: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+    ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
        """
        Generate a response for an image.
@@ -855,7 +854,7 @@ class LLMRequest:
        return content, (reasoning, model_info.name, response.tool_calls)

-    async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
+    async def generate_response_for_voice(self, voice_base64: str) -> str | None:
        """
        Generate a response for voice input (speech-to-text).
        Uses the failover strategy so a result can be obtained even if the primary model fails.
@@ -872,11 +871,11 @@ class LLMRequest:
    async def generate_response_async(
        self,
        prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: list[dict[str, Any]] | None = None,
        raise_when_empty: bool = True,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+    ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
        """
        Generate a response asynchronously, with support for concurrent requests.
@@ -914,11 +913,11 @@ class LLMRequest:
    async def _execute_single_text_request(
        self,
        prompt: str,
-        temperature: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: list[dict[str, Any]] | None = None,
        raise_when_empty: bool = True,
-    ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+    ) -> tuple[str, tuple[str, str, list[ToolCall] | None]]:
        """
        Execute a single text-generation request.
        This is the core implementation behind `generate_response_async`; it handles the full life cycle of one request,
@@ -956,7 +955,7 @@ class LLMRequest:
        return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)

-    async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
+    async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
        """
        Get an embedding vector.
@@ -978,7 +977,7 @@ class LLMRequest:
        return response.embedding, model_info.name

-    async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
+    async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
        """
        Record model usage.
@@ -1009,7 +1008,7 @@ class LLMRequest:
        )

    @staticmethod
-    def _build_tool_options(tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
+    def _build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
        """
        Build and validate a list of `ToolOption` objects from the given list of dicts.
@@ -1028,7 +1027,7 @@ class LLMRequest:
        if not tools:
            return None

-        tool_options: List[ToolOption] = []
+        tool_options: list[ToolOption] = []

        # Iterate over each tool definition
        for tool in tools:
            try:
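For context, the loop that begins here (per the surrounding lines) validates each tool dict and accumulates ToolOption objects. A hypothetical, self-contained version of that build-and-validate pattern; the ToolOption stand-in and its field names are assumptions, not the project's real class:

from dataclasses import dataclass
from typing import Any

@dataclass
class ToolOption:  # stand-in for the project's real ToolOption
    name: str
    description: str

def build_tool_options(tools: list[dict[str, Any]] | None) -> list[ToolOption] | None:
    if not tools:
        return None
    tool_options: list[ToolOption] = []
    for tool in tools:
        try:
            tool_options.append(ToolOption(tool["name"], tool["description"]))
        except KeyError:
            continue  # skip malformed tool definitions
    return tool_options or None

# Usage:
print(build_tool_options([{"name": "search", "description": "web search"}, {"bad": 1}]))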