依旧修pyright喵~

This commit is contained in:
ikun-11451
2025-11-29 21:26:42 +08:00
parent 28719c1c89
commit 72e7492953
25 changed files with 170 additions and 104 deletions

View File

@@ -9,9 +9,10 @@
import operator
from collections.abc import Callable
from functools import lru_cache
from typing import Any, TypeVar
from typing import Any, Generic, TypeVar
from sqlalchemy import delete, func, select, update
from sqlalchemy.engine import CursorResult, Result
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
logger = get_logger("database.crud")
T = TypeVar("T", bound=Base)
T = TypeVar("T", bound=Any)
@lru_cache(maxsize=256)
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
"""获取模型的列名称列表"""
return tuple(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
def _get_model_field_set(model: type[Any]) -> frozenset[str]:
"""获取模型的有效字段集合"""
return frozenset(_get_model_column_names(model))
@lru_cache(maxsize=256)
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
"""为模型准备attrgetter用于批量获取属性值"""
column_names = _get_model_column_names(model)
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, .
if len(column_names) == 1:
attr_name = column_names[0]
def _single(instance: Base) -> tuple[Any, ...]:
def _single(instance: Any) -> tuple[Any, ...]:
return (getattr(instance, attr_name),)
return _single
getter = operator.attrgetter(*column_names)
def _multi(instance: Base) -> tuple[Any, ...]:
def _multi(instance: Any) -> tuple[Any, ...]:
values = getter(instance)
return values if isinstance(values, tuple) else (values,)
return _multi
def _model_to_dict(instance: Base) -> dict[str, Any]:
def _model_to_dict(instance: Any) -> dict[str, Any]:
"""将 SQLAlchemy 模型实例转换为字典
Args:
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
return instance
class CRUDBase:
class CRUDBase(Generic[T]):
"""基础CRUD操作类
提供通用的增删改查操作,自动集成缓存和批处理
@@ -249,7 +250,7 @@ class CRUDBase:
if cached_dicts is not None:
logger.debug(f"缓存命中: {cache_key}")
# 从字典列表恢复对象列表
return [_dict_to_model(self.model, d) for d in cached_dicts]
return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore
# 从数据库查询
async with get_db_session() as session:
@@ -278,7 +279,7 @@ class CRUDBase:
await cache.set(cache_key, instances_dicts)
# 从字典列表重建对象列表返回detached状态所有字段已加载
return [_dict_to_model(self.model, d) for d in instances_dicts]
return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore
async def create(
self,
@@ -420,7 +421,7 @@ class CRUDBase:
async with get_db_session() as session:
stmt = delete(self.model).where(self.model.id == id)
result = await session.execute(stmt)
success = result.rowcount > 0
success = result.rowcount > 0 # type: ignore
# 注意commit在get_db_session的context manager退出时自动执行
# 清除缓存
@@ -455,7 +456,7 @@ class CRUDBase:
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
return result.scalar()
return int(result.scalar() or 0)
async def exists(
self,
@@ -549,7 +550,7 @@ class CRUDBase:
.values(**obj_in)
)
result = await session.execute(stmt)
count += result.rowcount
count += result.rowcount # type: ignore
# 清除缓存
cache_key = f"{self.model_name}:id:{id}"