fix(LLM_Client):修复了API 密钥的动态轮询机制
This commit is contained in:
@@ -120,9 +120,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
|
|||||||
转换工具选项格式 - 将工具选项转换为Gemini REST API所需的格式
|
转换工具选项格式 - 将工具选项转换为Gemini REST API所需的格式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _convert_tool_param(param: ToolParam) -> dict:
|
def _convert_tool_param(param: ToolParam) -> dict[str, Any]:
|
||||||
"""转换工具参数"""
|
"""转换工具参数"""
|
||||||
result = {
|
result: dict[str, Any] = {
|
||||||
"type": param.param_type.value,
|
"type": param.param_type.value,
|
||||||
"description": param.description,
|
"description": param.description,
|
||||||
}
|
}
|
||||||
@@ -130,9 +130,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
|
|||||||
result["enum"] = param.enum_values
|
result["enum"] = param.enum_values
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _convert_tool_option_item(tool_option: ToolOption) -> dict:
|
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
|
||||||
"""转换单个工具选项"""
|
"""转换单个工具选项"""
|
||||||
function_declaration = {
|
function_declaration: dict[str, Any] = {
|
||||||
"name": tool_option.name,
|
"name": tool_option.name,
|
||||||
"description": tool_option.description,
|
"description": tool_option.description,
|
||||||
}
|
}
|
||||||
@@ -341,7 +341,6 @@ class AiohttpGeminiClient(BaseClient):
|
|||||||
super().__init__(api_provider)
|
super().__init__(api_provider)
|
||||||
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
|
||||||
self.session: aiohttp.ClientSession | None = None
|
self.session: aiohttp.ClientSession | None = None
|
||||||
self.api_key = api_provider.api_key
|
|
||||||
|
|
||||||
# 如果提供了自定义base_url,使用它
|
# 如果提供了自定义base_url,使用它
|
||||||
if api_provider.base_url:
|
if api_provider.base_url:
|
||||||
@@ -388,11 +387,11 @@ class AiohttpGeminiClient(BaseClient):
|
|||||||
self, method: str, endpoint: str, data: dict | None = None, stream: bool = False
|
self, method: str, endpoint: str, data: dict | None = None, stream: bool = False
|
||||||
) -> aiohttp.ClientResponse:
|
) -> aiohttp.ClientResponse:
|
||||||
"""发起HTTP请求(每次都用 with aiohttp.ClientSession() as session)"""
|
"""发起HTTP请求(每次都用 with aiohttp.ClientSession() as session)"""
|
||||||
url = f"{self.base_url}/{endpoint}?key={self.api_key}"
|
api_key = self.api_provider.get_api_key()
|
||||||
timeout = aiohttp.ClientTimeout(total=300)
|
url = f"{self.base_url}/{endpoint}?key={api_key}"
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
timeout=timeout,
|
timeout=aiohttp.ClientTimeout(total=300),
|
||||||
headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"},
|
headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"},
|
||||||
) as session:
|
) as session:
|
||||||
if method.upper() == "POST":
|
if method.upper() == "POST":
|
||||||
@@ -500,7 +499,7 @@ class AiohttpGeminiClient(BaseClient):
|
|||||||
# 直接重抛项目定义的异常
|
# 直接重抛项目定义的异常
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(e)
|
logger.debug(str(e))
|
||||||
# 其他异常转换为网络连接错误
|
# 其他异常转换为网络连接错误
|
||||||
raise NetworkConnectionError() from e
|
raise NetworkConnectionError() from e
|
||||||
|
|
||||||
|
|||||||
@@ -376,11 +376,14 @@ def _default_normal_response_parser(
|
|||||||
class OpenaiClient(BaseClient):
|
class OpenaiClient(BaseClient):
|
||||||
def __init__(self, api_provider: APIProvider):
|
def __init__(self, api_provider: APIProvider):
|
||||||
super().__init__(api_provider)
|
super().__init__(api_provider)
|
||||||
self.client: AsyncOpenAI = AsyncOpenAI(
|
|
||||||
base_url=api_provider.base_url,
|
def _create_client(self) -> AsyncOpenAI:
|
||||||
api_key=api_provider.api_key,
|
"""动态创建OpenAI客户端"""
|
||||||
|
return AsyncOpenAI(
|
||||||
|
base_url=self.api_provider.base_url,
|
||||||
|
api_key=self.api_provider.get_api_key(),
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
timeout=api_provider.timeout,
|
timeout=self.api_provider.timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
@@ -429,10 +432,11 @@ class OpenaiClient(BaseClient):
|
|||||||
# 将tool_options转换为OpenAI API所需的格式
|
# 将tool_options转换为OpenAI API所需的格式
|
||||||
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
|
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
|
||||||
|
|
||||||
|
client = self._create_client()
|
||||||
try:
|
try:
|
||||||
if model_info.force_stream_mode:
|
if model_info.force_stream_mode:
|
||||||
req_task = asyncio.create_task(
|
req_task = asyncio.create_task(
|
||||||
self.client.chat.completions.create(
|
client.chat.completions.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -455,7 +459,7 @@ class OpenaiClient(BaseClient):
|
|||||||
# 发送请求并获取响应
|
# 发送请求并获取响应
|
||||||
# start_time = time.time()
|
# start_time = time.time()
|
||||||
req_task = asyncio.create_task(
|
req_task = asyncio.create_task(
|
||||||
self.client.chat.completions.create(
|
client.chat.completions.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -506,8 +510,9 @@ class OpenaiClient(BaseClient):
|
|||||||
:param embedding_input: 嵌入输入文本
|
:param embedding_input: 嵌入输入文本
|
||||||
:return: 嵌入响应
|
:return: 嵌入响应
|
||||||
"""
|
"""
|
||||||
|
client = self._create_client()
|
||||||
try:
|
try:
|
||||||
raw_response = await self.client.embeddings.create(
|
raw_response = await client.embeddings.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
input=embedding_input,
|
input=embedding_input,
|
||||||
extra_body=extra_params,
|
extra_body=extra_params,
|
||||||
@@ -564,8 +569,9 @@ class OpenaiClient(BaseClient):
|
|||||||
:extra_params: 附加的请求参数
|
:extra_params: 附加的请求参数
|
||||||
:return: 音频转录响应
|
:return: 音频转录响应
|
||||||
"""
|
"""
|
||||||
|
client = self._create_client()
|
||||||
try:
|
try:
|
||||||
raw_response = await self.client.audio.transcriptions.create(
|
raw_response = await client.audio.transcriptions.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
|
file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
|
||||||
extra_body=extra_params,
|
extra_body=extra_params,
|
||||||
|
|||||||
Reference in New Issue
Block a user