依旧修pyright喵~

This commit is contained in:
ikun-11451
2025-11-29 21:26:42 +08:00
parent 28719c1c89
commit 72e7492953
25 changed files with 170 additions and 104 deletions

View File

@@ -13,7 +13,7 @@ from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, TypeVar
from typing import Any
from sqlalchemy import delete, insert, select, update
@@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart
logger = get_logger("batch_scheduler")
T = TypeVar("T")
class Priority(IntEnum):
"""操作优先级"""
@@ -429,7 +427,7 @@ class AdaptiveBatchScheduler:
# 执行更新但不commit
result = await session.execute(stmt)
results.append((op, result.rowcount))
results.append((op, result.rowcount)) # type: ignore
# 注意commit 由 get_db_session_direct 上下文管理器自动处理
@@ -471,7 +469,7 @@ class AdaptiveBatchScheduler:
# 执行删除但不commit
result = await session.execute(stmt)
results.append((op, result.rowcount))
results.append((op, result.rowcount)) # type: ignore
# 注意commit 由 get_db_session_direct 上下文管理器自动处理

View File

@@ -398,47 +398,48 @@ class MultiLevelCache:
l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))
# 使用超时避免死锁
try:
l1_stats, l2_stats = await asyncio.gather(
asyncio.wait_for(l1_stats_task, timeout=1.0),
asyncio.wait_for(l2_stats_task, timeout=1.0),
return_exceptions=True
)
except asyncio.TimeoutError:
logger.warning("缓存统计获取超时,使用基本统计")
l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats()
results = await asyncio.gather(
asyncio.wait_for(l1_stats_task, timeout=1.0),
asyncio.wait_for(l2_stats_task, timeout=1.0),
return_exceptions=True
)
l1_stats = results[0]
l2_stats = results[1]
# 处理异常情况
if isinstance(l1_stats, Exception):
if isinstance(l1_stats, BaseException):
logger.error(f"L1统计获取失败: {l1_stats}")
l1_stats = CacheStats()
if isinstance(l2_stats, Exception):
if isinstance(l2_stats, BaseException):
logger.error(f"L2统计获取失败: {l2_stats}")
l2_stats = CacheStats()
assert isinstance(l1_stats, CacheStats)
assert isinstance(l2_stats, CacheStats)
# 🔧 修复:并行获取键集合,避免锁嵌套
l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))
try:
l1_keys, l2_keys = await asyncio.gather(
asyncio.wait_for(l1_keys_task, timeout=1.0),
asyncio.wait_for(l2_keys_task, timeout=1.0),
return_exceptions=True
)
except asyncio.TimeoutError:
logger.warning("缓存键获取超时,使用默认值")
l1_keys, l2_keys = set(), set()
results = await asyncio.gather(
asyncio.wait_for(l1_keys_task, timeout=1.0),
asyncio.wait_for(l2_keys_task, timeout=1.0),
return_exceptions=True
)
l1_keys = results[0]
l2_keys = results[1]
# 处理异常情况
if isinstance(l1_keys, Exception):
if isinstance(l1_keys, BaseException):
logger.warning(f"L1键获取失败: {l1_keys}")
l1_keys = set()
if isinstance(l2_keys, Exception):
if isinstance(l2_keys, BaseException):
logger.warning(f"L2键获取失败: {l2_keys}")
l2_keys = set()
assert isinstance(l1_keys, set)
assert isinstance(l2_keys, set)
# 计算共享键和独占键
shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys
@@ -448,24 +449,25 @@ class MultiLevelCache:
l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))
try:
l1_size, l2_size = await asyncio.gather(
asyncio.wait_for(l1_size_task, timeout=1.0),
asyncio.wait_for(l2_size_task, timeout=1.0),
return_exceptions=True
)
except asyncio.TimeoutError:
logger.warning("内存计算超时,使用统计值")
l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
results = await asyncio.gather(
asyncio.wait_for(l1_size_task, timeout=1.0),
asyncio.wait_for(l2_size_task, timeout=1.0),
return_exceptions=True
)
l1_size = results[0]
l2_size = results[1]
# 处理异常情况
if isinstance(l1_size, Exception):
if isinstance(l1_size, BaseException):
logger.warning(f"L1内存计算失败: {l1_size}")
l1_size = l1_stats.total_size
if isinstance(l2_size, Exception):
if isinstance(l2_size, BaseException):
logger.warning(f"L2内存计算失败: {l2_size}")
l2_size = l2_stats.total_size
assert isinstance(l1_size, int)
assert isinstance(l2_size, int)
# 计算实际总内存(避免重复计数)
actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
@@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache:
try:
from src.config.config import global_config
assert global_config is not None
db_config = global_config.database
# 检查是否启用缓存