fix(cache-manager): fetch cache statistics and memory usage in parallel to avoid deadlocks and double counting

fix(scheduler): improve the multi-stage cancellation mechanism for scheduled tasks to fully eliminate deadlock risk
Windpicker-owo
2025-11-08 22:17:12 +08:00
parent e716dee371
commit ce558514c3
2 changed files with 320 additions and 92 deletions
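The scheduler half of this commit is not shown in the hunk below. As a rough illustration of the multi-stage cancellation the message describes, a minimal sketch under assumed names follows; shutdown_task, its grace parameter, and the log text are hypothetical and not taken from this commit.

import asyncio
import logging

logger = logging.getLogger(__name__)

async def shutdown_task(task: asyncio.Task, grace: float = 2.0) -> None:
    """Hypothetical multi-stage cancellation for a scheduled task."""
    if task.done():
        return
    # Stage 1: give the task a short grace period to finish on its own.
    done, _ = await asyncio.wait({task}, timeout=grace)
    if done:
        return
    # Stage 2: request cancellation and wait briefly for it to take effect.
    task.cancel()
    done, _ = await asyncio.wait({task}, timeout=grace)
    if not done:
        # Stage 3: give up rather than block shutdown forever.
        logger.warning("Scheduled task did not exit after cancellation; abandoning it")

The ordering (wait, cancel, wait, give up) is the point of the sketch; the concrete timeouts and what the real scheduler does at each stage will differ.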


@@ -391,76 +391,213 @@ class MultiLevelCache:
logger.info("All caches cleared")
async def get_stats(self) -> dict[str, Any]:
"""Get statistics for all cache layers (revised: avoids double counting)"""
l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats()
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时"""
# 🔧 修复:并行获取统计信息,避免锁嵌套
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))
# 🔧 Fix: compute the memory each layer actually holds exclusively, so data shared between L1 and L2 is not double-counted
l1_keys = set(self.l1_cache._cache.keys())
l2_keys = set(self.l2_cache._cache.keys())
# Use timeouts to avoid deadlocks
try:
l1_stats, l2_stats = await asyncio.gather(
asyncio.wait_for(l1_stats_task, timeout=1.0),
asyncio.wait_for(l2_stats_task, timeout=1.0),
return_exceptions=True
)
except asyncio.TimeoutError:
logger.warning("缓存统计获取超时,使用基本统计")
l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats()
# Handle exceptions returned by gather
if isinstance(l1_stats, Exception):
logger.error(f"Failed to fetch L1 stats: {l1_stats}")
l1_stats = CacheStats()
if isinstance(l2_stats, Exception):
logger.error(f"Failed to fetch L2 stats: {l2_stats}")
l2_stats = CacheStats()
# 🔧 Fix: fetch key sets in parallel to avoid nested locks
l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))
try:
l1_keys, l2_keys = await asyncio.gather(
asyncio.wait_for(l1_keys_task, timeout=1.0),
asyncio.wait_for(l2_keys_task, timeout=1.0),
return_exceptions=True
)
except asyncio.TimeoutError:
logger.warning("缓存键获取超时,使用默认值")
l1_keys, l2_keys = set(), set()
# Handle exceptions returned by gather
if isinstance(l1_keys, Exception):
logger.warning(f"Failed to fetch L1 keys: {l1_keys}")
l1_keys = set()
if isinstance(l2_keys, Exception):
logger.warning(f"Failed to fetch L2 keys: {l2_keys}")
l2_keys = set()
# Compute shared and exclusive key sets
shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys
l2_only_keys = l2_keys - l1_keys
# Compute the actual total memory (avoiding double counting)
# Memory held only by L1
l1_only_size = sum(
self.l1_cache._cache[k].size
for k in l1_only_keys
if k in self.l1_cache._cache
)
# Memory held only by L2
l2_only_size = sum(
self.l2_cache._cache[k].size
for k in l2_only_keys
if k in self.l2_cache._cache
)
# Shared memory is counted only once, using the L1 entries
shared_size = sum(
self.l1_cache._cache[k].size
for k in shared_keys
if k in self.l1_cache._cache
)
# 🔧 Fix: compute memory usage in parallel to avoid nested locks
l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))
actual_total_size = l1_only_size + l2_only_size + shared_size
try:
l1_size, l2_size = await asyncio.gather(
asyncio.wait_for(l1_size_task, timeout=1.0),
asyncio.wait_for(l2_size_task, timeout=1.0),
return_exceptions=True
)
except asyncio.TimeoutError:
logger.warning("内存计算超时,使用统计值")
l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
# Handle exceptions returned by gather
if isinstance(l1_size, Exception):
logger.warning(f"L1 memory calculation failed: {l1_size}")
l1_size = l1_stats.total_size
if isinstance(l2_size, Exception):
logger.warning(f"L2 memory calculation failed: {l2_size}")
l2_size = l2_stats.total_size
# Compute the actual total memory (avoiding double counting)
actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
return {
"l1": l1_stats,
"l2": l2_stats,
"total_memory_mb": actual_total_size / (1024 * 1024),
"l1_only_mb": l1_only_size / (1024 * 1024),
"l2_only_mb": l2_only_size / (1024 * 1024),
"shared_mb": shared_size / (1024 * 1024),
"l1_only_mb": l1_size / (1024 * 1024),
"l2_only_mb": l2_size / (1024 * 1024),
"shared_mb": min(l1_stats.total_size, l2_stats.total_size) / (1024 * 1024),
"shared_keys_count": len(shared_keys),
"dedup_savings_mb": (l1_stats.total_size + l2_stats.total_size - actual_total_size) / (1024 * 1024),
"max_memory_mb": self.max_memory_bytes / (1024 * 1024),
"memory_usage_percent": (actual_total_size / self.max_memory_bytes * 100) if self.max_memory_bytes > 0 else 0,
}
async def check_memory_limit(self) -> None:
"""检查并强制清理超出内存限制的缓存"""
stats = await self.get_stats()
total_size = stats["l1"].total_size + stats["l2"].total_size
async def _get_cache_stats_safe(self, cache, cache_name: str) -> CacheStats:
"""安全获取缓存统计信息(带超时)"""
try:
return await asyncio.wait_for(cache.get_stats(), timeout=0.5)
except asyncio.TimeoutError:
logger.warning(f"{cache_name}统计获取超时")
return CacheStats()
except Exception as e:
logger.error(f"{cache_name}统计获取异常: {e}")
return CacheStats()
if total_size > self.max_memory_bytes:
memory_mb = total_size / (1024 * 1024)
max_mb = self.max_memory_bytes / (1024 * 1024)
logger.warning(
f"缓存内存超限: {memory_mb:.2f}MB / {max_mb:.2f}MB "
f"({stats['memory_usage_percent']:.1f}%)开始强制清理L2缓存"
async def _get_cache_keys_safe(self, cache) -> set[str]:
"""安全获取缓存键集合(带超时)"""
try:
# Grab the key set quickly, with a timeout to avoid deadlocks
return await asyncio.wait_for(
self._extract_keys_with_lock(cache),
timeout=0.5
)
# Clear the L2 cache (warm data) first
await self.l2_cache.clear()
except asyncio.TimeoutError:
logger.warning(f"缓存键获取超时: {cache.name}")
return set()
except Exception as e:
logger.error(f"缓存键获取异常: {e}")
return set()
# If still over the limit after clearing L2, clear L1
stats_after_l2 = await self.get_stats()
total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size
if total_after_l2 > self.max_memory_bytes:
logger.warning("清理L2后仍超限继续清理L1缓存")
await self.l1_cache.clear()
async def _extract_keys_with_lock(self, cache) -> set[str]:
"""在锁保护下提取键集合"""
async with cache._lock:
return set(cache._cache.keys())
logger.info("缓存强制清理完成")
async def _calculate_memory_usage_safe(self, cache, keys: set[str]) -> int:
"""安全计算内存使用(带超时)"""
if not keys:
return 0
try:
return await asyncio.wait_for(
self._calc_memory_with_lock(cache, keys),
timeout=0.5
)
except asyncio.TimeoutError:
logger.warning(f"内存计算超时: {cache.name}")
return 0
except Exception as e:
logger.error(f"内存计算异常: {e}")
return 0
async def _calc_memory_with_lock(self, cache, keys: set[str]) -> int:
"""在锁保护下计算内存使用"""
total_size = 0
async with cache._lock:
for key in keys:
entry = cache._cache.get(key)
if entry:
total_size += entry.size
return total_size
async def check_memory_limit(self) -> None:
"""检查并强制清理超出内存限制的缓存(修复版:避免嵌套锁)"""
try:
# 🔧 Fix: fetch stats with a timeout to avoid deadlocks
stats = await asyncio.wait_for(self.get_stats(), timeout=2.0)
total_size = stats["total_memory_mb"] * (1024 * 1024) # 转换回字节
if total_size > self.max_memory_bytes:
memory_mb = total_size / (1024 * 1024)
max_mb = self.max_memory_bytes / (1024 * 1024)
logger.warning(
f"缓存内存超限: {memory_mb:.2f}MB / {max_mb:.2f}MB "
f"({stats['memory_usage_percent']:.1f}%),开始分阶段清理"
)
# 🔧 Fix: clean up in stages, each protected by a timeout
cleanup_success = False
# Stage 1: clean expired entries
try:
await asyncio.wait_for(self._clean_expired_entries(), timeout=3.0)
# Re-check memory usage
stats_after_clean = await asyncio.wait_for(self.get_stats(), timeout=1.0)
total_after_clean = stats_after_clean["total_memory_mb"] * (1024 * 1024)
if total_after_clean <= self.max_memory_bytes:
logger.info("清理过期条目后内存使用正常")
cleanup_success = True
except asyncio.TimeoutError:
logger.warning("清理过期条目超时,跳到强制清理")
# 阶段2: 如果过期清理不够清理L2缓存
if not cleanup_success:
try:
logger.info("开始清理L2缓存")
await asyncio.wait_for(self.l2_cache.clear(), timeout=2.0)
logger.info("L2缓存清理完成")
# 检查L1缓存是否还需要清理
stats_after_l2 = await asyncio.wait_for(self.get_stats(), timeout=1.0)
total_after_l2 = stats_after_l2["total_memory_mb"] * (1024 * 1024)
if total_after_l2 > self.max_memory_bytes:
logger.warning("清理L2后仍超限继续清理L1缓存")
await asyncio.wait_for(self.l1_cache.clear(), timeout=2.0)
logger.info("L1缓存清理完成")
except asyncio.TimeoutError:
logger.error("强制清理超时,内存可能仍有问题")
except Exception as e:
logger.error(f"强制清理失败: {e}")
logger.info("缓存内存限制检查完成")
except asyncio.TimeoutError:
logger.warning("内存限制检查超时,跳过本次检查")
except Exception as e:
logger.error(f"内存限制检查失败: {e}", exc_info=True)
async def start_cleanup_task(self, interval: float = 60) -> None:
"""启动定期清理任务
@@ -522,44 +659,94 @@ class MultiLevelCache:
logger.info("缓存清理任务已停止")
async def _clean_expired_entries(self) -> None:
"""清理过期的缓存条目"""
"""清理过期的缓存条目(修复版:并行清理,避免锁嵌套)"""
try:
current_time = time.time()
# Clean expired L1 entries
async with self.l1_cache._lock:
expired_keys = [
key for key, entry in self.l1_cache._cache.items()
if current_time - entry.created_at > self.l1_cache.ttl
]
# 🔧 Fix: clean L1 and L2 in parallel, with timeouts to avoid deadlocks
async def clean_l1_expired():
"""Clean expired L1 entries"""
try:
# Use a timeout to avoid holding the lock for too long
await asyncio.wait_for(
self._clean_cache_layer_expired(self.l1_cache, current_time, "L1"),
timeout=2.0
)
except asyncio.TimeoutError:
logger.warning("L1缓存清理超时跳过本次清理")
except Exception as e:
logger.error(f"L1缓存清理异常: {e}")
for key in expired_keys:
entry = self.l1_cache._cache.pop(key, None)
if entry:
self.l1_cache._stats.evictions += 1
self.l1_cache._stats.item_count -= 1
self.l1_cache._stats.total_size -= entry.size
async def clean_l2_expired():
"""清理L2过期条目"""
try:
# Use a timeout to avoid holding the lock for too long
await asyncio.wait_for(
self._clean_cache_layer_expired(self.l2_cache, current_time, "L2"),
timeout=2.0
)
except asyncio.TimeoutError:
logger.warning("L2缓存清理超时跳过本次清理")
except Exception as e:
logger.error(f"L2缓存清理异常: {e}")
# 清理 L2 过期条目
async with self.l2_cache._lock:
expired_keys = [
key for key, entry in self.l2_cache._cache.items()
if current_time - entry.created_at > self.l2_cache.ttl
]
# 🔧 Key fix: run both cleanups in parallel instead of waiting serially
l1_task = asyncio.create_task(clean_l1_expired())
l2_task = asyncio.create_task(clean_l2_expired())
for key in expired_keys:
entry = self.l2_cache._cache.pop(key, None)
if entry:
self.l2_cache._stats.evictions += 1
self.l2_cache._stats.item_count -= 1
self.l2_cache._stats.total_size -= entry.size
# Wait for both cleanup tasks; return_exceptions keeps one failure from affecting the other
results = await asyncio.gather(l1_task, l2_task, return_exceptions=True)
if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目")
# 检查清理结果
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"缓存清理任务 {'L1' if i == 0 else 'L2'} 失败: {result}")
else:
logger.debug(f"缓存清理任务 {'L1' if i == 0 else 'L2'} 完成")
except Exception as e:
logger.error(f"清理过期条目失败: {e}", exc_info=True)
async def _clean_cache_layer_expired(self, cache_layer, current_time: float, layer_name: str) -> int:
"""清理单个缓存层的过期条目(避免锁嵌套)"""
expired_keys = []
cleaned_count = 0
try:
# Quickly scan for expired keys (holds the lock only briefly)
async with cache_layer._lock:
expired_keys = [
key for key, entry in cache_layer._cache.items()
if current_time - entry.created_at > cache_layer.ttl
]
# Delete expired keys in batches to avoid holding the lock for long
batch_size = 50  # Process 50 keys per batch
for i in range(0, len(expired_keys), batch_size):
batch = expired_keys[i:i + batch_size]
async with cache_layer._lock:
for key in batch:
entry = cache_layer._cache.pop(key, None)
if entry:
cache_layer._stats.evictions += 1
cache_layer._stats.item_count -= 1
cache_layer._stats.total_size -= entry.size
cleaned_count += 1
# Briefly yield control between batches to avoid blocking for long
if i + batch_size < len(expired_keys):
await asyncio.sleep(0.001) # 1ms
if cleaned_count > 0:
logger.debug(f"{layer_name}缓存清理完成: {cleaned_count} 个过期条目")
except Exception as e:
logger.error(f"{layer_name}缓存层清理失败: {e}")
raise
return cleaned_count
# Global cache instance
_global_cache: MultiLevelCache | None = None
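For orientation only, here is a hedged sketch of how a caller might consume the revised statistics and trigger the staged cleanup. The monitor_cache coroutine and its interval parameter are invented for this example; the stats keys and check_memory_limit() come from the diff above.

import asyncio
import logging

logger = logging.getLogger(__name__)

async def monitor_cache(cache: "MultiLevelCache", interval: float = 30.0) -> None:
    """Hypothetical monitoring loop over the revised stats/limit API."""
    while True:
        # get_stats() is now bounded by internal timeouts, so this call
        # should not hang even if a cache layer's lock is contended.
        stats = await cache.get_stats()
        logger.info(
            f"cache memory {stats['total_memory_mb']:.1f}MB "
            f"({stats['memory_usage_percent']:.1f}%), "
            f"dedup saved {stats['dedup_savings_mb']:.1f}MB "
            f"over {stats['shared_keys_count']} shared keys"
        )
        # Runs the staged cleanup if the configured limit is exceeded.
        await cache.check_memory_limit()
        await asyncio.sleep(interval)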