Still fixing pyright, meow~
@@ -13,7 +13,7 @@ from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import IntEnum
-from typing import Any, TypeVar
+from typing import Any

 from sqlalchemy import delete, insert, select, update

@@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart

 logger = get_logger("batch_scheduler")

-T = TypeVar("T")
-

 class Priority(IntEnum):
     """Operation priority."""
@@ -429,7 +427,7 @@ class AdaptiveBatchScheduler:

                 # Execute the update (without committing)
                 result = await session.execute(stmt)
-                results.append((op, result.rowcount))
+                results.append((op, result.rowcount))  # type: ignore

         # Note: the commit is handled automatically by the get_db_session_direct context manager

@@ -471,7 +469,7 @@ class AdaptiveBatchScheduler:

                 # Execute the delete (without committing)
                 result = await session.execute(stmt)
-                results.append((op, result.rowcount))
+                results.append((op, result.rowcount))  # type: ignore

         # Note: the commit is handled automatically by the get_db_session_direct context manager

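If the pyright complaint in the two hunks above is that `rowcount` is not declared on the generic `Result` that `session.execute()` is annotated to return, a cast is one alternative to `# type: ignore`. A minimal sketch under that assumption; the `Record` model and `mark_done` helper are made up for illustration and are not part of the project:

from typing import Any, cast

from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Record(Base):
    __tablename__ = "records"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="new")


async def mark_done(session: AsyncSession, record_id: int) -> int:
    """Run an UPDATE and return the affected row count without a type: ignore."""
    stmt = update(Record).where(Record.id == record_id).values(status="done")
    result = await session.execute(stmt)
    # session.execute() is annotated as returning a generic Result; at runtime an
    # UPDATE/DELETE yields a CursorResult, which is where rowcount lives. Casting
    # to CursorResult[Any] keeps the attribute access visible to pyright.
    return cast(CursorResult[Any], result).rowcount

Whether the cast or the `# type: ignore` is preferable is a style call; the ignore comment is the smaller diff, the cast records why the access is safe.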
@@ -398,47 +398,48 @@ class MultiLevelCache:
         l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))

         # Use a timeout to avoid deadlocks
-        try:
-            l1_stats, l2_stats = await asyncio.gather(
-                asyncio.wait_for(l1_stats_task, timeout=1.0),
-                asyncio.wait_for(l2_stats_task, timeout=1.0),
-                return_exceptions=True
-            )
-        except asyncio.TimeoutError:
-            logger.warning("Cache stats retrieval timed out, falling back to basic stats")
-            l1_stats = await self.l1_cache.get_stats()
-            l2_stats = await self.l2_cache.get_stats()
+        results = await asyncio.gather(
+            asyncio.wait_for(l1_stats_task, timeout=1.0),
+            asyncio.wait_for(l2_stats_task, timeout=1.0),
+            return_exceptions=True
+        )
+        l1_stats = results[0]
+        l2_stats = results[1]

         # Handle failure cases
-        if isinstance(l1_stats, Exception):
+        if isinstance(l1_stats, BaseException):
             logger.error(f"Failed to get L1 stats: {l1_stats}")
             l1_stats = CacheStats()
-        if isinstance(l2_stats, Exception):
+        if isinstance(l2_stats, BaseException):
             logger.error(f"Failed to get L2 stats: {l2_stats}")
             l2_stats = CacheStats()

+        assert isinstance(l1_stats, CacheStats)
+        assert isinstance(l2_stats, CacheStats)
+
         # 🔧 Fix: fetch the key sets in parallel to avoid nested locks
         l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
         l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))

-        try:
-            l1_keys, l2_keys = await asyncio.gather(
-                asyncio.wait_for(l1_keys_task, timeout=1.0),
-                asyncio.wait_for(l2_keys_task, timeout=1.0),
-                return_exceptions=True
-            )
-        except asyncio.TimeoutError:
-            logger.warning("Cache key retrieval timed out, using defaults")
-            l1_keys, l2_keys = set(), set()
+        results = await asyncio.gather(
+            asyncio.wait_for(l1_keys_task, timeout=1.0),
+            asyncio.wait_for(l2_keys_task, timeout=1.0),
+            return_exceptions=True
+        )
+        l1_keys = results[0]
+        l2_keys = results[1]

         # Handle failure cases
-        if isinstance(l1_keys, Exception):
+        if isinstance(l1_keys, BaseException):
             logger.warning(f"Failed to get L1 keys: {l1_keys}")
             l1_keys = set()
-        if isinstance(l2_keys, Exception):
+        if isinstance(l2_keys, BaseException):
             logger.warning(f"Failed to get L2 keys: {l2_keys}")
             l2_keys = set()

+        assert isinstance(l1_keys, set)
+        assert isinstance(l2_keys, set)
+
         # Compute shared and exclusive keys
         shared_keys = l1_keys & l2_keys
         l1_only_keys = l1_keys - l2_keys
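The pattern above follows from how asyncio.gather is typed: with return_exceptions=True each result slot is annotated as the task's result or a BaseException (cancellation, for example, has been a BaseException since Python 3.8), so an isinstance check against Exception alone does not fully narrow the value, while the follow-up asserts pin the type for the code below. A self-contained sketch of the same idea, using a stand-in fetch_stats coroutine rather than the project's real cache API:

import asyncio


async def fetch_stats(name: str) -> dict[str, int]:
    """Stand-in for a cache-stats coroutine (not the project's real API)."""
    await asyncio.sleep(0.01)
    if name == "L2":
        raise RuntimeError("backend unavailable")
    return {"hits": 10, "misses": 2}


async def main() -> None:
    l1_task = asyncio.create_task(fetch_stats("L1"))
    l2_task = asyncio.create_task(fetch_stats("L2"))

    # With return_exceptions=True, gather does not raise for failed awaitables;
    # each slot holds either a result or the exception object, and the stubs
    # type that slot as "result | BaseException".
    results = await asyncio.gather(
        asyncio.wait_for(l1_task, timeout=1.0),
        asyncio.wait_for(l2_task, timeout=1.0),
        return_exceptions=True,
    )
    l1_stats, l2_stats = results[0], results[1]

    # Checking against Exception alone would leave "BaseException minus
    # Exception" unhandled in the checker's view; checking BaseException
    # covers everything gather can place in the slot.
    if isinstance(l1_stats, BaseException):
        l1_stats = {"hits": 0, "misses": 0}
    if isinstance(l2_stats, BaseException):
        l2_stats = {"hits": 0, "misses": 0}

    # Mirrors the asserts in the diff: make the narrowed type explicit so the
    # rest of the function can use the values as plain dicts.
    assert isinstance(l1_stats, dict)
    assert isinstance(l2_stats, dict)
    print(l1_stats, l2_stats)


asyncio.run(main())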
@@ -448,24 +449,25 @@ class MultiLevelCache:
         l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
         l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))

-        try:
-            l1_size, l2_size = await asyncio.gather(
-                asyncio.wait_for(l1_size_task, timeout=1.0),
-                asyncio.wait_for(l2_size_task, timeout=1.0),
-                return_exceptions=True
-            )
-        except asyncio.TimeoutError:
-            logger.warning("Memory calculation timed out, using stats values")
-            l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
+        results = await asyncio.gather(
+            asyncio.wait_for(l1_size_task, timeout=1.0),
+            asyncio.wait_for(l2_size_task, timeout=1.0),
+            return_exceptions=True
+        )
+        l1_size = results[0]
+        l2_size = results[1]

         # Handle failure cases
-        if isinstance(l1_size, Exception):
+        if isinstance(l1_size, BaseException):
             logger.warning(f"L1 memory calculation failed: {l1_size}")
             l1_size = l1_stats.total_size
-        if isinstance(l2_size, Exception):
+        if isinstance(l2_size, BaseException):
             logger.warning(f"L2 memory calculation failed: {l2_size}")
             l2_size = l2_stats.total_size

+        assert isinstance(l1_size, int)
+        assert isinstance(l2_size, int)
+
         # Compute the actual total memory (avoid double counting)
         actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)

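The subtraction of min(l1_stats.total_size, l2_stats.total_size) in the last context line reads as treating the smaller tier as fully duplicated inside the larger one. A tiny worked example with made-up numbers:

# Made-up sizes purely to illustrate the adjustment above.
l1_size = 10 * 1024 * 1024   # bytes measured for L1 entries
l2_size = 30 * 1024 * 1024   # bytes measured for L2 entries
l1_total = 10 * 1024 * 1024  # l1_stats.total_size
l2_total = 30 * 1024 * 1024  # l2_stats.total_size

# The smaller tier is assumed to overlap the larger one entirely, so its
# reported size is subtracted once to avoid counting shared entries twice.
actual_total_size = l1_size + l2_size - min(l1_total, l2_total)
print(actual_total_size // (1024 * 1024))  # 30 MiB rather than 40 MiB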
@@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache:
     try:
         from src.config.config import global_config

+        assert global_config is not None
        db_config = global_config.database

         # Check whether caching is enabled
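The added assert is the usual way to let pyright narrow an Optional module-level config before attribute access (otherwise reportOptionalMemberAccess fires on the next line). A minimal sketch with stand-in config classes rather than the real src.config.config module:

from typing import Optional


class DatabaseConfig:
    url: str = "sqlite+aiosqlite:///cache.db"


class GlobalConfig:
    database: DatabaseConfig = DatabaseConfig()


# Typically declared Optional because it is only populated during startup.
global_config: Optional[GlobalConfig] = None


def init_config() -> None:
    global global_config
    global_config = GlobalConfig()


def get_db_url() -> str:
    # Without this assert, pyright treats global_config as possibly None and
    # flags the attribute access below; the assert narrows it to GlobalConfig
    # for the remainder of the function.
    assert global_config is not None
    return global_config.database.url


init_config()
print(get_db_url())  # sqlite+aiosqlite:///cache.db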