fix(LLM_Client):修复了API 密钥的动态轮询机制

This commit is contained in:
minecraft1024a
2025-10-02 21:24:30 +08:00
parent d5627b0661
commit 15fdb67ef7
2 changed files with 22 additions and 17 deletions

View File

@@ -120,9 +120,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
转换工具选项格式 - 将工具选项转换为Gemini REST API所需的格式
"""
def _convert_tool_param(param: ToolParam) -> dict:
def _convert_tool_param(param: ToolParam) -> dict[str, Any]:
"""转换工具参数"""
result = {
result: dict[str, Any] = {
"type": param.param_type.value,
"description": param.description,
}
@@ -130,9 +130,9 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
result["enum"] = param.enum_values
return result
def _convert_tool_option_item(tool_option: ToolOption) -> dict:
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
"""转换单个工具选项"""
function_declaration = {
function_declaration: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
@@ -341,7 +341,6 @@ class AiohttpGeminiClient(BaseClient):
super().__init__(api_provider)
self.base_url = "https://generativelanguage.googleapis.com/v1beta"
self.session: aiohttp.ClientSession | None = None
self.api_key = api_provider.api_key
# 如果提供了自定义base_url使用它
if api_provider.base_url:
@@ -388,11 +387,11 @@ class AiohttpGeminiClient(BaseClient):
self, method: str, endpoint: str, data: dict | None = None, stream: bool = False
) -> aiohttp.ClientResponse:
"""发起HTTP请求每次都用 with aiohttp.ClientSession() as session"""
url = f"{self.base_url}/{endpoint}?key={self.api_key}"
timeout = aiohttp.ClientTimeout(total=300)
api_key = self.api_provider.get_api_key()
url = f"{self.base_url}/{endpoint}?key={api_key}"
try:
async with aiohttp.ClientSession(
timeout=timeout,
timeout=aiohttp.ClientTimeout(total=300),
headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"},
) as session:
if method.upper() == "POST":
@@ -500,7 +499,7 @@ class AiohttpGeminiClient(BaseClient):
# 直接重抛项目定义的异常
raise
except Exception as e:
logger.debug(e)
logger.debug(str(e))
# 其他异常转换为网络连接错误
raise NetworkConnectionError() from e

View File

@@ -376,11 +376,14 @@ def _default_normal_response_parser(
class OpenaiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
self.client: AsyncOpenAI = AsyncOpenAI(
base_url=api_provider.base_url,
api_key=api_provider.api_key,
def _create_client(self) -> AsyncOpenAI:
"""动态创建OpenAI客户端"""
return AsyncOpenAI(
base_url=self.api_provider.base_url,
api_key=self.api_provider.get_api_key(),
max_retries=0,
timeout=api_provider.timeout,
timeout=self.api_provider.timeout,
)
async def get_response(
@@ -429,10 +432,11 @@ class OpenaiClient(BaseClient):
# 将tool_options转换为OpenAI API所需的格式
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
client = self._create_client()
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
self.client.chat.completions.create(
client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
@@ -455,7 +459,7 @@ class OpenaiClient(BaseClient):
# 发送请求并获取响应
# start_time = time.time()
req_task = asyncio.create_task(
self.client.chat.completions.create(
client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
@@ -506,8 +510,9 @@ class OpenaiClient(BaseClient):
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
client = self._create_client()
try:
raw_response = await self.client.embeddings.create(
raw_response = await client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
extra_body=extra_params,
@@ -564,8 +569,9 @@ class OpenaiClient(BaseClient):
:extra_params: 附加的请求参数
:return: 音频转录响应
"""
client = self._create_client()
try:
raw_response = await self.client.audio.transcriptions.create(
raw_response = await client.audio.transcriptions.create(
model=model_info.model_identifier,
file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
extra_body=extra_params,