re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -1,21 +1,24 @@
|
||||
import asyncio
|
||||
import orjson
|
||||
import io
|
||||
from typing import Callable, Any, Coroutine, Optional
|
||||
import aiohttp
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from src.common.logger import get_logger
|
||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
|
||||
from ..exceptions import (
|
||||
RespParseException,
|
||||
NetworkConnectionError,
|
||||
RespNotOkException,
|
||||
ReqAbortException,
|
||||
RespNotOkException,
|
||||
RespParseException,
|
||||
)
|
||||
from ..payload_content.message import Message, RoleType
|
||||
from ..payload_content.resp_format import RespFormat, RespFormatType
|
||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
|
||||
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
|
||||
|
||||
logger = get_logger("AioHTTP-Gemini客户端")
|
||||
|
||||
@@ -210,7 +213,7 @@ class AiohttpGeminiStreamParser:
|
||||
chunk_data = orjson.loads(chunk_text)
|
||||
|
||||
# 解析候选项
|
||||
if "candidates" in chunk_data and chunk_data["candidates"]:
|
||||
if chunk_data.get("candidates"):
|
||||
candidate = chunk_data["candidates"][0]
|
||||
|
||||
# 解析内容
|
||||
@@ -266,7 +269,7 @@ class AiohttpGeminiStreamParser:
|
||||
async def _default_stream_response_handler(
|
||||
response: aiohttp.ClientResponse,
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
) -> tuple[APIResponse, tuple[int, int, int] | None]:
|
||||
"""默认流式响应处理器"""
|
||||
parser = AiohttpGeminiStreamParser()
|
||||
|
||||
@@ -290,13 +293,13 @@ async def _default_stream_response_handler(
|
||||
|
||||
def _default_normal_response_parser(
|
||||
response_data: dict,
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
) -> tuple[APIResponse, tuple[int, int, int] | None]:
|
||||
"""默认普通响应解析器"""
|
||||
api_response = APIResponse()
|
||||
|
||||
try:
|
||||
# 解析候选项
|
||||
if "candidates" in response_data and response_data["candidates"]:
|
||||
if response_data.get("candidates"):
|
||||
candidate = response_data["candidates"][0]
|
||||
|
||||
# 解析文本内容
|
||||
@@ -418,13 +421,12 @@ class AiohttpGeminiClient(BaseClient):
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
[aiohttp.ClientResponse, asyncio.Event | None],
|
||||
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
|
||||
]
|
||||
] = None,
|
||||
async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None,
|
||||
stream_response_handler: Callable[
|
||||
[aiohttp.ClientResponse, asyncio.Event | None],
|
||||
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[dict], tuple[APIResponse, tuple[int, int, int] | None]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||
from ..payload_content.tool_option import ToolCall, ToolOption
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -75,9 +77,8 @@ class BaseClient(ABC):
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
] = None,
|
||||
stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import asyncio
|
||||
import io
|
||||
import orjson
|
||||
import re
|
||||
import base64
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Any, Coroutine, Optional
|
||||
from json_repair import repair_json
|
||||
import io
|
||||
import re
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
NOT_GIVEN,
|
||||
APIConnectionError,
|
||||
APIStatusError,
|
||||
NOT_GIVEN,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
@@ -22,18 +22,19 @@ from openai.types.chat import (
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
|
||||
from src.config.api_ada_configs import ModelInfo, APIProvider
|
||||
from src.common.logger import get_logger
|
||||
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
|
||||
from ..exceptions import (
|
||||
RespParseException,
|
||||
NetworkConnectionError,
|
||||
RespNotOkException,
|
||||
ReqAbortException,
|
||||
RespNotOkException,
|
||||
RespParseException,
|
||||
)
|
||||
from ..payload_content.message import Message, RoleType
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||
from ..payload_content.tool_option import ToolCall, ToolOption, ToolParam
|
||||
from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
|
||||
|
||||
logger = get_logger("OpenAI客户端")
|
||||
|
||||
@@ -241,7 +242,7 @@ def _build_stream_api_resp(
|
||||
async def _default_stream_response_handler(
|
||||
resp_stream: AsyncStream[ChatCompletionChunk],
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
) -> tuple[APIResponse, tuple[int, int, int] | None]:
|
||||
"""
|
||||
流式响应处理函数 - 处理OpenAI API的流式响应
|
||||
:param resp_stream: 流式响应对象
|
||||
@@ -323,7 +324,7 @@ pattern = re.compile(
|
||||
|
||||
def _default_normal_response_parser(
|
||||
resp: ChatCompletion,
|
||||
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
|
||||
) -> tuple[APIResponse, tuple[int, int, int] | None]:
|
||||
"""
|
||||
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
|
||||
:param resp: 响应对象
|
||||
@@ -402,15 +403,13 @@ class OpenaiClient(BaseClient):
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[
|
||||
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
|
||||
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
|
||||
]
|
||||
] = None,
|
||||
async_response_parser: Optional[
|
||||
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
|
||||
] = None,
|
||||
stream_response_handler: Callable[
|
||||
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
|
||||
Coroutine[Any, Any, tuple[APIResponse, tuple[int, int, int] | None]],
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int] | None]]
|
||||
| None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
@@ -527,17 +526,17 @@ class OpenaiClient(BaseClient):
|
||||
)
|
||||
except APIConnectionError as e:
|
||||
# 添加详细的错误信息以便调试
|
||||
logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}")
|
||||
logger.error(f"OpenAI API连接错误(嵌入模型): {e!s}")
|
||||
logger.error(f"错误类型: {type(e)}")
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
logger.error(f"底层错误: {str(e.__cause__)}")
|
||||
logger.error(f"底层错误: {e.__cause__!s}")
|
||||
raise NetworkConnectionError() from e
|
||||
except APIStatusError as e:
|
||||
# 重封装APIError为RespNotOkException
|
||||
raise RespNotOkException(e.status_code) from e
|
||||
except Exception as e:
|
||||
# 添加通用异常处理和日志记录
|
||||
logger.error(f"获取嵌入时发生未知错误: {str(e)}")
|
||||
logger.error(f"获取嵌入时发生未知错误: {e!s}")
|
||||
logger.error(f"错误类型: {type(e)}")
|
||||
raise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user