feat: 支持多个API Key,增强错误处理和负载均衡机制

This commit is contained in:
墨梓柒
2025-07-27 13:55:18 +08:00
parent e240fb92ca
commit 16931ef7b4
6 changed files with 391 additions and 44 deletions

View File

@@ -1,5 +1,7 @@
from dataclasses import dataclass, field
from typing import List, Dict
from typing import List, Dict, Union
import threading
import time
from packaging.version import Version
@@ -9,8 +11,106 @@ NEWEST_VER = "0.1.1" # 当前支持的最新版本
class APIProvider:
name: str = "" # API提供商名称
base_url: str = "" # API基础URL
api_key: str = field(repr=False, default="") # API密钥
api_key: str = field(repr=False, default="") # API密钥(向后兼容)
api_keys: List[str] = field(repr=False, default_factory=list) # API密钥列表新格式
client_type: str = "openai" # 客户端类型如openai/google等默认为openai
# 多API Key管理相关属性
_current_key_index: int = field(default=0, init=False, repr=False) # 当前使用的key索引
_key_failure_count: Dict[int, int] = field(default_factory=dict, init=False, repr=False) # 每个key的失败次数
_key_last_failure_time: Dict[int, float] = field(default_factory=dict, init=False, repr=False) # 每个key最后失败时间
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) # 线程锁
def __post_init__(self):
"""初始化后处理确保API keys列表正确"""
# 向后兼容如果只设置了api_key将其添加到api_keys列表
if self.api_key and not self.api_keys:
self.api_keys = [self.api_key]
# 如果api_keys不为空但api_key为空设置api_key为第一个
elif self.api_keys and not self.api_key:
self.api_key = self.api_keys[0]
# 初始化失败计数器
for i in range(len(self.api_keys)):
self._key_failure_count[i] = 0
self._key_last_failure_time[i] = 0
def get_current_api_key(self) -> str:
"""获取当前应该使用的API Key"""
with self._lock:
if not self.api_keys:
return ""
# 确保索引在有效范围内
if self._current_key_index >= len(self.api_keys):
self._current_key_index = 0
return self.api_keys[self._current_key_index]
def get_next_api_key(self) -> Union[str, None]:
"""获取下一个可用的API Key负载均衡"""
with self._lock:
if not self.api_keys:
return None
# 如果只有一个key直接返回
if len(self.api_keys) == 1:
return self.api_keys[0]
# 轮询到下一个key
self._current_key_index = (self._current_key_index + 1) % len(self.api_keys)
return self.api_keys[self._current_key_index]
def mark_key_failed(self, api_key: str) -> Union[str, None]:
"""标记某个API Key失败返回下一个可用的key"""
with self._lock:
if not self.api_keys or api_key not in self.api_keys:
return None
key_index = self.api_keys.index(api_key)
self._key_failure_count[key_index] += 1
self._key_last_failure_time[key_index] = time.time()
# 寻找下一个可用的key
current_time = time.time()
for _ in range(len(self.api_keys)):
self._current_key_index = (self._current_key_index + 1) % len(self.api_keys)
next_key_index = self._current_key_index
# 检查该key是否最近失败过5分钟内失败超过3次则暂时跳过
if (self._key_failure_count[next_key_index] <= 3 or
current_time - self._key_last_failure_time[next_key_index] > 300): # 5分钟后重试
return self.api_keys[next_key_index]
# 如果所有key都不可用返回当前key让上层处理
return api_key
def reset_key_failures(self, api_key: str = None):
"""重置失败计数(成功调用后调用)"""
with self._lock:
if api_key and api_key in self.api_keys:
key_index = self.api_keys.index(api_key)
self._key_failure_count[key_index] = 0
self._key_last_failure_time[key_index] = 0
else:
# 重置所有key的失败计数
for i in range(len(self.api_keys)):
self._key_failure_count[i] = 0
self._key_last_failure_time[i] = 0
def get_api_key_stats(self) -> Dict[str, Dict[str, Union[int, float]]]:
"""获取API Key使用统计"""
with self._lock:
stats = {}
for i, key in enumerate(self.api_keys):
# 只显示key的前8位和后4位中间用*代替
masked_key = f"{key[:8]}***{key[-4:]}" if len(key) > 12 else "***"
stats[masked_key] = {
"failure_count": self._key_failure_count.get(i, 0),
"last_failure_time": self._key_last_failure_time.get(i, 0),
"is_current": i == self._current_key_index
}
return stats
@dataclass