fix(LLM_Client):修复了API 密钥的动态轮询机制

This commit is contained in:
minecraft1024a
2025-10-02 21:24:30 +08:00
parent d5627b0661
commit 15fdb67ef7
2 changed files with 22 additions and 17 deletions

View File

@@ -120,9 +120,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
转换工具选项格式 - 将工具选项转换为Gemini REST API所需的格式 转换工具选项格式 - 将工具选项转换为Gemini REST API所需的格式
""" """
def _convert_tool_param(param: ToolParam) -> dict: def _convert_tool_param(param: ToolParam) -> dict[str, Any]:
"""转换工具参数""" """转换工具参数"""
result = { result: dict[str, Any] = {
"type": param.param_type.value, "type": param.param_type.value,
"description": param.description, "description": param.description,
} }
@@ -130,9 +130,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
result["enum"] = param.enum_values result["enum"] = param.enum_values
return result return result
def _convert_tool_option_item(tool_option: ToolOption) -> dict: def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
"""转换单个工具选项""" """转换单个工具选项"""
function_declaration = { function_declaration: dict[str, Any] = {
"name": tool_option.name, "name": tool_option.name,
"description": tool_option.description, "description": tool_option.description,
} }
@@ -341,7 +341,6 @@ class AiohttpGeminiClient(BaseClient):
super().__init__(api_provider) super().__init__(api_provider)
self.base_url = "https://generativelanguage.googleapis.com/v1beta" self.base_url = "https://generativelanguage.googleapis.com/v1beta"
self.session: aiohttp.ClientSession | None = None self.session: aiohttp.ClientSession | None = None
self.api_key = api_provider.api_key
# 如果提供了自定义base_url使用它 # 如果提供了自定义base_url使用它
if api_provider.base_url: if api_provider.base_url:
@@ -388,11 +387,11 @@ class AiohttpGeminiClient(BaseClient):
self, method: str, endpoint: str, data: dict | None = None, stream: bool = False self, method: str, endpoint: str, data: dict | None = None, stream: bool = False
) -> aiohttp.ClientResponse: ) -> aiohttp.ClientResponse:
"""发起HTTP请求每次都用 with aiohttp.ClientSession() as session""" """发起HTTP请求每次都用 with aiohttp.ClientSession() as session"""
url = f"{self.base_url}/{endpoint}?key={self.api_key}" api_key = self.api_provider.get_api_key()
timeout = aiohttp.ClientTimeout(total=300) url = f"{self.base_url}/{endpoint}?key={api_key}"
try: try:
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
timeout=timeout, timeout=aiohttp.ClientTimeout(total=300),
headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"}, headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"},
) as session: ) as session:
if method.upper() == "POST": if method.upper() == "POST":
@@ -500,7 +499,7 @@ class AiohttpGeminiClient(BaseClient):
# 直接重抛项目定义的异常 # 直接重抛项目定义的异常
raise raise
except Exception as e: except Exception as e:
logger.debug(e) logger.debug(str(e))
# 其他异常转换为网络连接错误 # 其他异常转换为网络连接错误
raise NetworkConnectionError() from e raise NetworkConnectionError() from e

View File

@@ -376,11 +376,14 @@ def _default_normal_response_parser(
class OpenaiClient(BaseClient): class OpenaiClient(BaseClient):
def __init__(self, api_provider: APIProvider): def __init__(self, api_provider: APIProvider):
super().__init__(api_provider) super().__init__(api_provider)
self.client: AsyncOpenAI = AsyncOpenAI(
base_url=api_provider.base_url, def _create_client(self) -> AsyncOpenAI:
api_key=api_provider.api_key, """动态创建OpenAI客户端"""
return AsyncOpenAI(
base_url=self.api_provider.base_url,
api_key=self.api_provider.get_api_key(),
max_retries=0, max_retries=0,
timeout=api_provider.timeout, timeout=self.api_provider.timeout,
) )
async def get_response( async def get_response(
@@ -429,10 +432,11 @@ class OpenaiClient(BaseClient):
# 将tool_options转换为OpenAI API所需的格式 # 将tool_options转换为OpenAI API所需的格式
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
client = self._create_client()
try: try:
if model_info.force_stream_mode: if model_info.force_stream_mode:
req_task = asyncio.create_task( req_task = asyncio.create_task(
self.client.chat.completions.create( client.chat.completions.create(
model=model_info.model_identifier, model=model_info.model_identifier,
messages=messages, messages=messages,
tools=tools, tools=tools,
@@ -455,7 +459,7 @@ class OpenaiClient(BaseClient):
# 发送请求并获取响应 # 发送请求并获取响应
# start_time = time.time() # start_time = time.time()
req_task = asyncio.create_task( req_task = asyncio.create_task(
self.client.chat.completions.create( client.chat.completions.create(
model=model_info.model_identifier, model=model_info.model_identifier,
messages=messages, messages=messages,
tools=tools, tools=tools,
@@ -506,8 +510,9 @@ class OpenaiClient(BaseClient):
:param embedding_input: 嵌入输入文本 :param embedding_input: 嵌入输入文本
:return: 嵌入响应 :return: 嵌入响应
""" """
client = self._create_client()
try: try:
raw_response = await self.client.embeddings.create( raw_response = await client.embeddings.create(
model=model_info.model_identifier, model=model_info.model_identifier,
input=embedding_input, input=embedding_input,
extra_body=extra_params, extra_body=extra_params,
@@ -564,8 +569,9 @@ class OpenaiClient(BaseClient):
:extra_params: 附加的请求参数 :extra_params: 附加的请求参数
:return: 音频转录响应 :return: 音频转录响应
""" """
client = self._create_client()
try: try:
raw_response = await self.client.audio.transcriptions.create( raw_response = await client.audio.transcriptions.create(
model=model_info.model_identifier, model=model_info.model_identifier,
file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))), file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
extra_body=extra_params, extra_body=extra_params,