Preliminary refactor of llmrequest

This commit is contained in:
墨梓柒
2025-07-25 13:21:48 +08:00
parent 999ea4a7ce
commit 909e47bcee
21 changed files with 612 additions and 1237 deletions


@@ -0,0 +1,365 @@
import asyncio
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from .base_client import BaseClient, APIResponse
from src.config.api_ada_configs import (
ModelInfo,
ModelUsageArgConfigItem,
RequestConfig,
ModuleConfig,
)
from ..exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption
from ..utils import compress_messages
from src.common.logger import get_logger
logger = get_logger("model_client")
def _check_retry(
remain_try: int,
retry_interval: int,
can_retry_msg: str,
cannot_retry_msg: str,
can_retry_callable: Callable | None = None,
**kwargs,
) -> tuple[int, Any | None]:
"""
辅助函数:检查是否可以重试
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param can_retry_msg: 可以重试时的提示信息
:param cannot_retry_msg: 不可以重试时的提示信息
:return: (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if remain_try > 0:
# 还有重试机会
logger.warning(f"{can_retry_msg}")
if can_retry_callable is not None:
return retry_interval, can_retry_callable(**kwargs)
else:
return retry_interval, None
else:
# 达到最大重试次数
logger.warning(f"{cannot_retry_msg}")
return -1, None # 不再重试请求该模型
def _handle_resp_not_ok(
e: RespNotOkException,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
):
"""
处理响应错误异常
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return: (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
# 响应错误
if e.status_code in [400, 401, 402, 403, 404]:
# 客户端错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求失败,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None # 不再重试请求该模型
elif e.status_code == 413:
if messages and not messages[1]:
# 消息列表不为空且未压缩,尝试压缩消息
return _check_retry(
remain_try,
0,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,尝试压缩消息后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,压缩消息后仍然过大,放弃请求"
),
can_retry_callable=compress_messages,
messages=messages[0],
)
# 没有消息可压缩
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求体过大,无法压缩消息,放弃请求。"
)
return -1, None
elif e.status_code == 429:
# 请求过于频繁
return _check_retry(
remain_try,
retry_interval,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"请求过于频繁,将于{retry_interval}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"请求过于频繁,超过最大重试次数,放弃请求"
),
)
elif e.status_code >= 500:
# 服务器错误
return _check_retry(
remain_try,
retry_interval,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"服务器错误,将于{retry_interval}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
"服务器错误,超过最大重试次数,请稍后再试"
),
)
else:
# 未知错误
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"未知错误,错误代码-{e.status_code},错误信息-{e.message}"
)
return -1, None
def default_exception_handler(
e: Exception,
task_name: str,
model_name: str,
remain_try: int,
retry_interval: int = 10,
messages: tuple[list[Message], bool] | None = None,
) -> tuple[int, list[Message] | None]:
"""
默认异常处理函数
:param e: 异常对象
:param task_name: 任务名称
:param model_name: 模型名称
:param remain_try: 剩余尝试次数
:param retry_interval: 重试间隔
:param messages: (消息列表, 是否已压缩过)
:return (等待间隔如果为0则不等待为-1则不再请求该模型, 新的消息列表(适用于压缩消息))
"""
if isinstance(e, NetworkConnectionError): # 网络连接错误
return _check_retry(
remain_try,
retry_interval,
can_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常,将于{retry_interval}秒后重试"
),
cannot_retry_msg=(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"连接异常超过最大重试次数请检查网络连接状态或URL是否正确"
),
)
elif isinstance(e, ReqAbortException):
logger.warning(
f"任务-'{task_name}' 模型-'{model_name}'\n请求被中断,详细信息-{str(e.message)}"
)
return -1, None # 不再重试请求该模型
elif isinstance(e, RespNotOkException):
return _handle_resp_not_ok(
e,
task_name,
model_name,
remain_try,
retry_interval,
messages,
)
elif isinstance(e, RespParseException):
# 响应解析错误
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n"
f"响应解析错误,错误信息-{e.message}\n"
)
logger.debug(f"附加内容:\n{str(e.ext_info)}")
return -1, None # 不再重试请求该模型
else:
logger.error(
f"任务-'{task_name}' 模型-'{model_name}'\n未知异常,错误信息-{str(e)}"
)
return -1, None # 不再重试请求该模型
class ModelRequestHandler:
"""
模型请求处理器
"""
def __init__(
self,
task_name: str,
config: ModuleConfig,
api_client_map: dict[str, BaseClient],
):
        self.task_name: str = task_name
        """Task name"""
        self.client_map: dict[str, BaseClient] = {}
        """API client cache"""
        self.configs: list[tuple[ModelInfo, ModelUsageArgConfigItem]] = []
        """Model usage configurations"""
        self.req_conf: RequestConfig = config.req_conf
        """Request configuration"""
        # Resolve the models and usage configurations for this task
        for model_usage in config.task_model_arg_map[task_name].usage:
            if model_usage.name not in config.models:
                logger.error(f"Model '{model_usage.name}' not found in ModelManager")
                raise KeyError(f"Model '{model_usage.name}' not found in ModelManager")
            model_info = config.models[model_usage.name]
            if model_info.api_provider not in self.client_map:
                # Cache the API client
                self.client_map[model_info.api_provider] = api_client_map[
                    model_info.api_provider
                ]
            self.configs.append((model_info, model_usage))  # register model and usage config
async def get_response(
self,
messages: list[Message],
tool_options: list[ToolOption] | None = None,
        response_format: RespFormat | None = None,  # not yet enabled
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[ChatCompletion], APIResponse] | None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param messages: 消息列表
:param tool_options: 工具选项列表
:param response_format: 响应格式
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: APIResponse
"""
# 遍历可用模型,若获取响应失败,则使用下一个模型继续请求
for config_item in self.configs:
client = self.client_map[config_item[0].api_provider]
model_info: ModelInfo = config_item[0]
model_usage_config: ModelUsageArgConfigItem = config_item[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
compressed_messages = None
retry_interval = self.req_conf.retry_interval
while remain_try > 0:
try:
return await client.get_response(
model_info,
message_list=(compressed_messages or messages),
tool_options=tool_options,
max_tokens=model_usage_config.max_tokens
or self.req_conf.default_max_tokens,
temperature=model_usage_config.temperature
or self.req_conf.default_temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
interrupt_flag=interrupt_flag,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
messages=(messages, compressed_messages is not None),
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
retry_interval *= 2
if handle_res[1] is not None:
# 压缩消息
compressed_messages = handle_res[1]
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败
async def get_embedding(
self,
embedding_input: str,
) -> APIResponse:
"""
获取嵌入向量
:param embedding_input: 嵌入输入
:return: APIResponse
"""
for config in self.configs:
client = self.client_map[config[0].api_provider]
model_info: ModelInfo = config[0]
model_usage_config: ModelUsageArgConfigItem = config[1]
remain_try = (
model_usage_config.max_retry or self.req_conf.max_retry
) + 1 # 初始化:剩余尝试次数 = 最大重试次数 + 1
while remain_try:
try:
return await client.get_embedding(
model_info=model_info,
embedding_input=embedding_input,
)
except Exception as e:
logger.debug(e)
remain_try -= 1 # 剩余尝试次数减1
# 处理异常
handle_res = default_exception_handler(
e,
self.task_name,
model_info.name,
remain_try,
retry_interval=self.req_conf.retry_interval,
)
if handle_res[0] == -1:
# 等待间隔为-1表示不再请求该模型
remain_try = 0
elif handle_res[0] != 0:
# 等待间隔不为0表示需要等待
await asyncio.sleep(handle_res[0])
logger.error(f"任务-'{self.task_name}' 请求执行失败,所有模型均不可用")
raise RuntimeError("请求失败,所有模型均不可用") # 所有请求尝试均失败
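
A minimal usage sketch of the handler above, assuming a ModuleConfig and a provider-keyed client map built elsewhere; the task name "chat" and the demo wiring are illustrative, not part of the commit:

async def demo(config: ModuleConfig, api_client_map: dict[str, BaseClient], messages: list[Message]):
    # One handler per task; it walks the task's configured models in order and
    # falls through to the next model once retries for the current one are exhausted.
    handler = ModelRequestHandler("chat", config, api_client_map)
    resp = await handler.get_response(messages)      # APIResponse
    print(resp.content, resp.tool_calls, resp.usage)
    emb = await handler.get_embedding("some text")   # same fallback loop for embeddings
    print(len(emb.embedding or []))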


@@ -0,0 +1,116 @@
import asyncio
from dataclasses import dataclass
from typing import Callable, Any
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
@dataclass
class UsageRecord:
"""
使用记录类
"""
model_name: str
"""模型名称"""
provider_name: str
"""提供商名称"""
prompt_tokens: int
"""提示token数"""
completion_tokens: int
"""完成token数"""
total_tokens: int
"""总token数"""
@dataclass
class APIResponse:
"""
API响应类
"""
content: str | None = None
"""响应内容"""
reasoning_content: str | None = None
"""推理内容"""
tool_calls: list[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
"""嵌入向量"""
usage: UsageRecord | None = None
"""使用情况 (prompt_tokens, completion_tokens, total_tokens)"""
raw_data: Any = None
"""响应原始数据"""
class BaseClient:
"""
基础客户端
"""
api_provider: APIProvider
def __init__(self, api_provider: APIProvider):
self.api_provider = api_provider
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param response_format: 响应格式(可选,默认为 NotGiven
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
raise RuntimeError("This method should be overridden in subclasses")
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
raise RuntimeError("This method should be overridden in subclasses")
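
A minimal subclass sketch showing what a concrete client must provide; the echo behaviour is invented for illustration, while the real subclasses in this commit wrap the OpenAI and Gemini SDKs:

class EchoClient(BaseClient):
    """Toy client that satisfies the BaseClient contract without network I/O."""

    async def get_response(self, model_info: ModelInfo, message_list: list[Message], **kwargs) -> APIResponse:
        resp = APIResponse()
        resp.content = f"echo: {len(message_list)} message(s) for {model_info.name}"
        resp.usage = UsageRecord(
            model_name=model_info.name,
            provider_name=model_info.api_provider,
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
        )
        return resp

    async def get_embedding(self, model_info: ModelInfo, embedding_input: str) -> APIResponse:
        resp = APIResponse()
        resp.embedding = [0.0] * 8  # fixed-size dummy vector
        return resp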


@@ -0,0 +1,481 @@
import asyncio
import io
from collections.abc import Iterable
from typing import Callable, Iterator, TypeVar, AsyncIterator
from google import genai
from google.genai import types
from google.genai.types import FunctionDeclaration, GenerateContentResponse
from google.genai.errors import (
ClientError,
ServerError,
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
)
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
T = TypeVar("T")
def _convert_messages(
messages: list[Message],
) -> tuple[list[types.Content], list[str] | None]:
"""
转换消息格式 - 将消息转换为Gemini API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表(和可能存在的system消息)
"""
def _convert_message_item(message: Message) -> types.Content:
"""
转换单个消息格式除了system和tool类型的消息
:param message: 消息对象
:return: 转换后的消息字典
"""
# 将openai格式的角色重命名为gemini格式的角色
if message.role == RoleType.Assistant:
role = "model"
elif message.role == RoleType.User:
role = "user"
# 添加Content
content: types.Part | list
if isinstance(message.content, str):
content = types.Part.from_text(message.content)
elif isinstance(message.content, list):
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
types.Part.from_bytes(
data=item[1], mime_type=f"image/{item[0].lower()}"
)
)
elif isinstance(item, str):
content.append(types.Part.from_text(item))
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return types.Content(role=role, content=content)
temp_list: list[types.Content] = []
system_instructions: list[str] = []
for message in messages:
if message.role == RoleType.System:
if isinstance(message.content, str):
system_instructions.append(message.content)
else:
raise RuntimeError("你tm怎么往system里面塞图片base64")
elif message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
else:
temp_list.append(_convert_message_item(message))
if system_instructions:
# 如果有system消息就把它加上去
ret: tuple = (temp_list, system_instructions)
else:
# 如果没有system消息就直接返回
ret: tuple = (temp_list, None)
return ret
def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]:
"""
转换工具选项格式 - 将工具选项转换为Gemini API所需的格式
:param tool_options: 工具选项列表
:return: 转换后的工具对象列表
"""
def _convert_tool_param(tool_option_param: ToolParam) -> dict:
"""
转换单个工具参数格式
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
return {
"type": tool_option_param.param_type.value,
"description": tool_option_param.description,
}
def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration:
"""
转换单个工具项格式
:param tool_option: 工具选项对象
:return: 转换后的Gemini工具选项对象
"""
ret = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
}
ret1 = types.FunctionDeclaration(**ret)
return ret1
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
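
# Example (a sketch; the names are invented): a ToolOption named "get_weather" with
# one required string param "city" converts to
#   FunctionDeclaration(
#       name="get_weather",
#       description="...",
#       parameters={
#           "type": "object",
#           "properties": {"city": {"type": "string", "description": "..."}},
#           "required": ["city"],
#       },
#   )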
def _process_delta(
delta: GenerateContentResponse,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, dict]],
):
    if not delta.candidates:
        raise RespParseException(delta, "Parse failure: missing candidates field")
    if delta.text:
        fc_delta_buffer.write(delta.text)
    if delta.function_calls:  # this attribute always exists (possibly empty), so no hasattr check
        for call in delta.function_calls:
            try:
                if not isinstance(call.args, dict):
                    # Gemini already returns function-call arguments as a dict
                    raise RespParseException(
                        delta, "Parse failure: tool call arguments are not a dict"
                    )
                tool_calls_buffer.append(
                    (
                        call.id,
                        call.name,
                        call.args,
                    )
                )
            except Exception as e:
                raise RespParseException(delta, "Parse failure: cannot parse tool call arguments") from e
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
    resp = APIResponse()
    if _fc_delta_buffer.tell() > 0:
        # If the content buffer is non-empty, flush it into the APIResponse
        resp.content = _fc_delta_buffer.getvalue()
        _fc_delta_buffer.close()
    if len(_tool_calls_buffer) > 0:
        # If the tool call buffer is non-empty, convert it into ToolCall objects
        resp.tool_calls = []
        for call_id, function_name, arguments in _tool_calls_buffer:
            if arguments is not None and not isinstance(arguments, dict):
                raise RespParseException(
                    None,
                    "Parse failure: tool call arguments are not a dict. Raw arguments:\n"
                    f"{arguments}",
                )
            resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
    return resp
async def _to_async_iterable(iterable: Iterable[T]) -> AsyncIterator[T]:
"""
将迭代器转换为异步迭代器
:param iterable: 迭代器对象
:return: 异步迭代器对象
"""
for item in iterable:
await asyncio.sleep(0)
yield item
async def _default_stream_response_handler(
resp_stream: Iterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
流式响应处理函数 - 处理Gemini API的流式响应
:param resp_stream: 流式响应对象,是一个神秘的iterator我完全不知道这个玩意能不能跑不过遍历一遍之后它就空了如果跑不了一点的话可以考虑改成别的东西
:return: APIResponse对象
"""
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[
tuple[str, str, dict]
] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
async for chunk in _to_async_iterable(resp_stream):
# 检查是否有中断量
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
raise ReqAbortException("请求被外部信号中断")
_process_delta(
chunk,
_fc_delta_buffer,
_tool_calls_buffer,
)
if chunk.usage_metadata:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
chunk.usage_metadata.prompt_token_count,
chunk.usage_metadata.candidates_token_count
+ chunk.usage_metadata.thoughts_token_count,
chunk.usage_metadata.total_token_count,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_tool_calls_buffer,
), _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
raise
def _default_normal_response_parser(
resp: GenerateContentResponse,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
解析对话补全响应 - 将Gemini API响应解析为APIResponse对象
:param resp: 响应对象
:return: APIResponse对象
"""
api_response = APIResponse()
if not hasattr(resp, "candidates") or len(resp.candidates) == 0:
raise RespParseException(resp, "响应解析失败缺失candidates字段")
if resp.text:
api_response.content = resp.text
if resp.function_calls:
api_response.tool_calls = []
for call in resp.function_calls:
try:
if not isinstance(call.args, dict):
raise RespParseException(
resp, "响应解析失败,工具调用参数无法解析为字典类型"
)
api_response.tool_calls.append(ToolCall(call.id, call.name, call.args))
except Exception as e:
raise RespParseException(
resp, "响应解析失败,无法解析工具调用参数"
) from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count,
resp.usage_metadata.candidates_token_count
+ resp.usage_metadata.thoughts_token_count,
resp.usage_metadata.total_token_count,
)
else:
_usage_record = None
api_response.raw_data = resp
return api_response, _usage_record
class GeminiClient(BaseClient):
client: genai.Client
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
        self.client = genai.Client(
            api_key=api_provider.api_key,
        )  # unlike the OpenAI client, the Gemini SDK decides on its own whether to retry
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
thinking_budget: int = 0,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[Iterator[GenerateContentResponse], asyncio.Event | None], APIResponse
]
| None = None,
async_response_parser: Callable[[GenerateContentResponse], APIResponse]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param thinking_budget: 思考预算可选默认为0
:param response_format: 响应格式默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的暂不支持其它相应格式输入
:param stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
:param async_response_parser: 响应解析函数可选默认为default_response_parser
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
        if stream_response_handler is None:
            stream_response_handler = _default_stream_response_handler
        if async_response_parser is None:
            async_response_parser = _default_normal_response_parser
        # Convert messages into the Gemini API format
        messages = _convert_messages(message_list)
        # Convert tool_options into the Gemini API format
        tools = _convert_tool_options(tool_options) if tool_options else None
        # Convert response_format into the Gemini API format
        generation_config_dict = {
            "max_output_tokens": max_tokens,
            "temperature": temperature,
            "response_modalities": ["TEXT"],  # only text output is supported for now
        }
        if "2.5" in model_info.model_identifier.lower():
            # Shortcut: only the 2.5 models support thinking right now. Testing shows they
            # do not return thought content, and to be safe any thoughts are stripped anyway.
            generation_config_dict["thinking_config"] = types.ThinkingConfig(
                thinking_budget=thinking_budget, include_thoughts=False
            )
        if tools:
            generation_config_dict["tools"] = [types.Tool(function_declarations=tools)]
        if messages[1]:
            # Attach the system messages to the config if present
            generation_config_dict["system_instruction"] = messages[1]
        if response_format and response_format.format_type == RespFormatType.TEXT:
            generation_config_dict["response_mime_type"] = "text/plain"
        elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
            generation_config_dict["response_mime_type"] = "application/json"
            generation_config_dict["response_schema"] = response_format.to_dict()
        generation_config = types.GenerateContentConfig(**generation_config_dict)
        try:
            if model_info.force_stream_mode:
                req_task = asyncio.create_task(
                    self.client.aio.models.generate_content_stream(
                        model=model_info.model_identifier,
                        contents=messages[0],
                        config=generation_config,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Cancel the task and raise if the interrupt event is set
                        req_task.cancel()
                        raise ReqAbortException("Request interrupted by external signal")
                    await asyncio.sleep(0.1)  # re-check task & interrupt state after 0.1s
                resp, usage_record = await stream_response_handler(
                    req_task.result(), interrupt_flag
                )
            else:
                req_task = asyncio.create_task(
                    self.client.aio.models.generate_content(
                        model=model_info.model_identifier,
                        contents=messages[0],
                        config=generation_config,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Cancel the task and raise if the interrupt event is set
                        req_task.cancel()
                        raise ReqAbortException("Request interrupted by external signal")
                    await asyncio.sleep(0.5)  # re-check task & interrupt state after 0.5s
                resp, usage_record = async_response_parser(req_task.result())
        except (ClientError, ServerError) as e:
            # Rewrap ClientError and ServerError as RespNotOkException
            raise RespNotOkException(e.code, e.message) from e
        except (
            UnknownFunctionCallArgumentError,
            UnsupportedFunctionError,
            FunctionInvocationError,
        ) as e:
            raise ValueError(f"Tool definition error, check the tool options and parameters: {e}") from e
        except Exception as e:
            raise NetworkConnectionError() from e
        if usage_record:
            resp.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=usage_record[0],
                completion_tokens=usage_record[1],
                total_tokens=usage_record[2],
            )
        return resp
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
try:
raw_response: types.EmbedContentResponse = (
await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.status_code)
except Exception as e:
raise NetworkConnectionError() from e
response = APIResponse()
# 解析嵌入响应和使用情况
if hasattr(raw_response, "embeddings"):
response.embedding = raw_response.embeddings[0].values
else:
raise RespParseException(raw_response, "响应解析失败缺失embeddings字段")
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=len(embedding_input),
completion_tokens=0,
total_tokens=len(embedding_input),
)
return response
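
A hedged sketch of calling GeminiClient directly; provider, model_info and messages are assumed to be built by the surrounding configuration code:

async def gemini_demo(provider: APIProvider, model_info: ModelInfo, messages: list[Message]) -> None:
    client = GeminiClient(provider)
    # thinking_budget only takes effect on 2.5-series models (see the "2.5" check above)
    resp = await client.get_response(
        model_info,
        message_list=messages,
        max_tokens=512,
        temperature=0.5,
        thinking_budget=0,
    )
    print(resp.content, resp.usage)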


@@ -0,0 +1,548 @@
import asyncio
import io
import json
import re
from collections.abc import Iterable
from typing import Callable, Any
from openai import (
AsyncOpenAI,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncStream,
)
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from .base_client import APIResponse, UsageRecord
from src.config.api_ada_configs import ModelInfo, APIProvider
from . import BaseClient
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
"""
转换消息格式 - 将消息转换为OpenAI API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表
"""
def _convert_message_item(message: Message) -> ChatCompletionMessageParam:
"""
转换单个消息格式
:param message: 消息对象
:return: 转换后的消息字典
"""
# 添加Content
content: str | list[dict[str, Any]]
if isinstance(message.content, str):
content = message.content
elif isinstance(message.content, list):
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/{item[0].lower()};base64,{item[1]}"
},
}
)
elif isinstance(item, str):
content.append({"type": "text", "text": item})
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret = {
"role": message.role.value,
"content": content,
}
# 添加工具调用ID
if message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret["tool_call_id"] = message.tool_call_id
return ret
return [_convert_message_item(message) for message in messages]
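
# Example shape of the conversion (a sketch; values invented):
#   Message(role=User, content="hi")             -> {"role": "user", "content": "hi"}
#   Message(role=User, content=[("PNG", b64)])   -> {"role": "user", "content": [
#       {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}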
def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
"""
转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式
:param tool_options: 工具选项列表
:return: 转换后的工具选项列表
"""
def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, str]:
"""
转换单个工具参数格式
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
return {
"type": tool_option_param.param_type.value,
"description": tool_option_param.description,
}
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
"""
转换单个工具项格式
:param tool_option: 工具选项对象
:return: 转换后的工具选项字典
"""
ret: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {
param.name: _convert_tool_param(param)
for param in tool_option.params
},
"required": [
param.name for param in tool_option.params if param.required
],
}
return ret
return [
{
"type": "function",
"function": _convert_tool_option_item(tool_option),
}
for tool_option in tool_options
]
def _process_delta(
delta: ChoiceDelta,
has_rc_attr_flag: bool,
in_rc_flag: bool,
rc_delta_buffer: io.StringIO,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> bool:
    # Receive content
    if has_rc_attr_flag:
        # A dedicated reasoning_content field exists, so no <think>-tag bookkeeping is needed
        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
            # Write reasoning content into the reasoning buffer
            assert isinstance(delta.reasoning_content, str)
            rc_delta_buffer.write(delta.reasoning_content)
        elif delta.content:
            # Write final content into the content buffer
            fc_delta_buffer.write(delta.content)
    elif hasattr(delta, "content") and delta.content is not None:
        # No dedicated reasoning field, but there is content
        if in_rc_flag:
            # Currently inside a reasoning block
            if delta.content == "</think>":
                # "</think>" marks the end of the reasoning block; leave it
                in_rc_flag = False
            else:
                # Anything else counts as reasoning content
                rc_delta_buffer.write(delta.content)
        elif delta.content == "<think>" and not fc_delta_buffer.getvalue():
            # "<think>" with an empty content buffer means it is the first output token,
            # so treat it as the start of a reasoning block and enter it
            in_rc_flag = True
        else:
            # Anything else counts as final content
            fc_delta_buffer.write(delta.content)
    # Receive tool_calls
    if hasattr(delta, "tool_calls") and delta.tool_calls:
        tool_call_delta = delta.tool_calls[0]
        if tool_call_delta.index >= len(tool_calls_buffer):
            # An index at or beyond the buffer length indicates a new tool call
            tool_calls_buffer.append(
                (
                    tool_call_delta.id,
                    tool_call_delta.function.name,
                    io.StringIO(),
                )
            )
        if tool_call_delta.function.arguments:
            # Append argument fragments to the buffer of the corresponding tool call
            tool_calls_buffer[tool_call_delta.index][2].write(
                tool_call_delta.function.arguments
            )
    return in_rc_flag
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_rc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> APIResponse:
    resp = APIResponse()
    if _rc_delta_buffer.tell() > 0:
        # If the reasoning buffer is non-empty, flush it into the APIResponse
        resp.reasoning_content = _rc_delta_buffer.getvalue()
        _rc_delta_buffer.close()
    if _fc_delta_buffer.tell() > 0:
        # If the content buffer is non-empty, flush it into the APIResponse
        resp.content = _fc_delta_buffer.getvalue()
        _fc_delta_buffer.close()
    if _tool_calls_buffer:
        # If the tool call buffer is non-empty, parse it into ToolCall objects
        resp.tool_calls = []
        for call_id, function_name, arguments_buffer in _tool_calls_buffer:
            if arguments_buffer.tell() > 0:
                # Parse a non-empty argument buffer as JSON
                raw_arg_data = arguments_buffer.getvalue()
                arguments_buffer.close()
                try:
                    arguments = json.loads(raw_arg_data)
                    if not isinstance(arguments, dict):
                        raise RespParseException(
                            None,
                            "Parse failure: tool call arguments are not a dict. Raw arguments:\n"
                            f"{raw_arg_data}",
                        )
                except json.JSONDecodeError as e:
                    raise RespParseException(
                        None,
                        "Parse failure: cannot parse tool call arguments. Raw arguments:\n"
                        f"{raw_arg_data}",
                    ) from e
            else:
                arguments_buffer.close()
                arguments = None
            resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
    return resp
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
:return: APIResponse对象
"""
_has_rc_attr_flag = False # 标记是否有独立的推理内容块
_in_rc_flag = False # 标记是否在推理内容块中
_rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[
tuple[str, str, io.StringIO]
] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
# 确保缓冲区被关闭
if _rc_delta_buffer and not _rc_delta_buffer.closed:
_rc_delta_buffer.close()
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
for _, _, buffer in _tool_calls_buffer:
if buffer and not buffer.closed:
buffer.close()
async for event in resp_stream:
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
_insure_buffer_closed()
raise ReqAbortException("请求被外部信号中断")
delta = event.choices[0].delta # 获取当前块的delta内容
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
# 标记:有独立的推理内容块
_has_rc_attr_flag = True
_in_rc_flag = _process_delta(
delta,
_has_rc_attr_flag,
_in_rc_flag,
_rc_delta_buffer,
_fc_delta_buffer,
_tool_calls_buffer,
)
if event.usage:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
event.usage.prompt_tokens,
event.usage.completion_tokens,
event.usage.total_tokens,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_rc_delta_buffer,
_tool_calls_buffer,
), _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
raise
pattern = re.compile(
r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
re.DOTALL,
)
"""用于解析推理内容的正则表达式"""
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, tuple[int, int, int]]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
:return: APIResponse对象
"""
api_response = APIResponse()
if not hasattr(resp, "choices") or len(resp.choices) == 0:
raise RespParseException(resp, "响应解析失败缺失choices字段")
message_part = resp.choices[0].message
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content:
# 有有效的推理字段
api_response.content = message_part.content
api_response.reasoning_content = message_part.reasoning_content
elif message_part.content:
# 提取推理和内容
match = pattern.match(message_part.content)
if not match:
raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容")
if match.group("think") is not None:
result = match.group("think").strip(), match.group("content").strip()
elif match.group("think_unclosed") is not None:
result = match.group("think_unclosed").strip(), None
else:
result = None, match.group("content_only").strip()
api_response.reasoning_content, api_response.content = result
# 提取工具调用
if message_part.tool_calls:
api_response.tool_calls = []
for call in message_part.tool_calls:
try:
arguments = json.loads(call.function.arguments)
if not isinstance(arguments, dict):
raise RespParseException(
resp, "响应解析失败,工具调用参数无法解析为字典类型"
)
api_response.tool_calls.append(
ToolCall(call.id, call.function.name, arguments)
)
except json.JSONDecodeError as e:
raise RespParseException(
resp, "响应解析失败,无法解析工具调用参数"
) from e
# 提取Usage信息
if resp.usage:
_usage_record = (
resp.usage.prompt_tokens,
resp.usage.completion_tokens,
resp.usage.total_tokens,
)
else:
_usage_record = None
# 将原始响应存储在原始数据中
api_response.raw_data = resp
return api_response, _usage_record
class OpenaiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
self.client: AsyncOpenAI = AsyncOpenAI(
base_url=api_provider.base_url,
api_key=api_provider.api_key,
max_retries=0,
)
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
tuple[APIResponse, tuple[int, int, int]],
]
| None = None,
async_response_parser: Callable[
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
]
| None = None,
interrupt_flag: asyncio.Event | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param response_format: 响应格式(可选,默认为 NotGiven
:param stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
:param async_response_parser: 响应解析函数可选默认为default_response_parser
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
if async_response_parser is None:
async_response_parser = _default_normal_response_parser
        # Convert messages into the OpenAI API format
        messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
        # Convert tool_options into the OpenAI API format
        tools: Iterable[ChatCompletionToolParam] = (
            _convert_tool_options(tool_options) if tool_options else NOT_GIVEN
        )
        try:
            if model_info.force_stream_mode:
                req_task = asyncio.create_task(
                    self.client.chat.completions.create(
                        model=model_info.model_identifier,
                        messages=messages,
                        tools=tools,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        stream=True,
                        response_format=response_format.to_dict()
                        if response_format
                        else NOT_GIVEN,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Cancel the task and raise if the interrupt event is set
                        req_task.cancel()
                        raise ReqAbortException("Request interrupted by external signal")
                    await asyncio.sleep(0.1)  # re-check task & interrupt state after 0.1s
                resp, usage_record = await stream_response_handler(
                    req_task.result(), interrupt_flag
                )
            else:
                # Send the request and await the response
                req_task = asyncio.create_task(
                    self.client.chat.completions.create(
                        model=model_info.model_identifier,
                        messages=messages,
                        tools=tools,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        stream=False,
                        response_format=response_format.to_dict()
                        if response_format
                        else NOT_GIVEN,
                    )
                )
                while not req_task.done():
                    if interrupt_flag and interrupt_flag.is_set():
                        # Cancel the task and raise if the interrupt event is set
                        req_task.cancel()
                        raise ReqAbortException("Request interrupted by external signal")
                    await asyncio.sleep(0.5)  # re-check task & interrupt state after 0.5s
                resp, usage_record = async_response_parser(req_task.result())
        except APIConnectionError as e:
            # Rewrap APIConnectionError as NetworkConnectionError
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Rewrap APIStatusError as RespNotOkException
            raise RespNotOkException(e.status_code, e.message) from e
if usage_record:
resp.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
)
return resp
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
try:
raw_response = await self.client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
)
        except APIConnectionError as e:
            raise NetworkConnectionError() from e
        except APIStatusError as e:
            # Rewrap APIStatusError as RespNotOkException
            raise RespNotOkException(e.status_code) from e
        response = APIResponse()
        # Parse the embedding
        if len(raw_response.data) > 0:
            response.embedding = raw_response.data[0].embedding
        else:
            raise RespParseException(
                raw_response,
                "Parse failure: missing embedding data.",
            )
        # Parse the usage
        if hasattr(raw_response, "usage"):
            response.usage = UsageRecord(
                model_name=model_info.name,
                provider_name=model_info.api_provider,
                prompt_tokens=raw_response.usage.prompt_tokens,
                completion_tokens=raw_response.usage.completion_tokens,
                total_tokens=raw_response.usage.total_tokens,
            )
return response
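
A hedged end-to-end sketch for OpenaiClient; provider, model_info and messages are assumed to come from the surrounding configuration code:

async def openai_demo(provider: APIProvider, model_info: ModelInfo, messages: list[Message]) -> None:
    client = OpenaiClient(provider)  # AsyncOpenAI with SDK retries disabled (max_retries=0)
    resp = await client.get_response(
        model_info,
        message_list=messages,
        max_tokens=256,
        temperature=0.3,
    )
    # reasoning_content comes either from a dedicated field or from <think> tags
    print(resp.reasoning_content, resp.content, resp.usage)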