依旧修pyright喵~

This commit is contained in:
ikun-11451
2025-11-29 21:26:42 +08:00
parent 28719c1c89
commit 72e7492953
25 changed files with 170 additions and 104 deletions

View File

@@ -11,7 +11,7 @@ import functools
import hashlib
import time
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from typing import Any, ParamSpec, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
@@ -56,8 +56,9 @@ def generate_cache_key(
return ":".join(cache_key_parts)
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
P = ParamSpec("P")
R = TypeVar("R")
def retry(
@@ -77,14 +78,13 @@ def retry(
exceptions: 需要重试的异常类型
Example:
@retry(max_attempts=3, delay=1.0)
async def query_data():
return await session.execute(stmt)
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> T:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception = None
current_delay = delay
@@ -107,7 +107,9 @@ def retry(
)
# 所有尝试都失败
raise last_exception
if last_exception:
raise last_exception
raise RuntimeError(f"Retry failed after {max_attempts} attempts")
return wrapper
@@ -128,9 +130,9 @@ def timeout(seconds: float):
return await session.execute(complex_stmt)
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> T:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
except asyncio.TimeoutError:
@@ -164,9 +166,9 @@ def cached(
return await query_user(user_id)
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> T:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 延迟导入避免循环依赖
from src.common.database.optimization import get_cache
@@ -226,9 +228,9 @@ def measure_time(log_slow: float | None = None):
return await session.execute(stmt)
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> T:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.perf_counter()
try:
@@ -268,21 +270,23 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
函数需要接受session参数
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> T:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 查找session参数
session = None
if args:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession
session: AsyncSession | None = None
if args:
for arg in args:
if isinstance(arg, AsyncSession):
session = arg
break
if not session and "session" in kwargs:
session = kwargs["session"]
possible_session = kwargs["session"]
if isinstance(possible_session, AsyncSession):
session = possible_session
if not session:
logger.warning(f"{func.__name__} 未找到session参数跳过事务管理")
@@ -331,7 +335,7 @@ def db_operation(
return await complex_operation()
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
# 从内到外应用装饰器
wrapped = func