re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions

View File

@@ -1,21 +1,24 @@
import asyncio
import orjson
import io
from typing import Callable, Any, Coroutine, Optional
import aiohttp
from collections.abc import Callable, Coroutine
from typing import Any
import aiohttp
import orjson
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from src.config.api_ada_configs import APIProvider, ModelInfo
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
logger = get_logger("AioHTTP-Gemini客户端")
@@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser:
chunk_data = orjson.loads(chunk_text)
# 解析候选项
if "candidates" in chunk_data and chunk_data["candidates"]:
if chunk_data.get("candidates"):
candidate = chunk_data["candidates"][0]
# 解析内容
@@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser:
async def _default_stream_response_handler(
response: aiohttp.ClientResponse,
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""默认流式响应处理器"""
parser = AiohttpGeminiStreamParser()
@@ -290,13 +293,13 @@ async def _default_stream_response_handler(
def _default_normal_response_parser(
response_data: dict,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""默认普通响应解析器"""
api_response = APIResponse()
try:
# 解析候选项
if "candidates" in response_data and response_data["candidates"]:
if response_data.get("candidates"):
candidate = response_data["candidates"][0]
# 解析文本内容
@@ -418,13 +421,12 @@ class AiohttpGeminiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[aiohttp.ClientResponse, asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None,
stream_response_handler: Callable[
[aiohttp.ClientResponse, asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
]
| None = None,
async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:

View File

@@ -1,12 +1,14 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from src.config.api_ada_configs import APIProvider, ModelInfo
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption
@dataclass
@@ -75,9 +77,8 @@ class BaseClient(ABC):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
| None = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,

View File

@@ -1,17 +1,17 @@
import asyncio
import io
import orjson
import re
import base64
from collections.abc import Iterable
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
import io
import re
from collections.abc import Callable, Coroutine, Iterable
from typing import Any
import orjson
from json_repair import repair_json
from openai import (
AsyncOpenAI,
NOT_GIVEN,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncOpenAI,
AsyncStream,
)
from openai.types.chat import (
@@ -22,18 +22,19 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from src.config.api_ada_configs import APIProvider, ModelInfo
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
logger = get_logger("OpenAI客户端")
@@ -241,7 +242,7 @@ def _build_stream_api_resp(
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
@@ -323,7 +324,7 @@ pattern = re.compile(
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
) -> tuple[APIResponse, tuple[int, int, int] | None]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
@@ -402,15 +403,13 @@ class OpenaiClient(BaseClient):
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
]
| None = None,
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]]
| None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
@@ -527,17 +526,17 @@ class OpenaiClient(BaseClient):
)
except APIConnectionError as e:
# 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"OpenAI API连接错误嵌入模型: {e!s}")
logger.error(f"错误类型: {type(e)}")
if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}")
logger.error(f"底层错误: {e.__cause__!s}")
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code) from e
except Exception as e:
# 添加通用异常处理和日志记录
logger.error(f"获取嵌入时发生未知错误: {str(e)}")
logger.error(f"获取嵌入时发生未知错误: {e!s}")
logger.error(f"错误类型: {type(e)}")
raise