依旧修pyright喵~
This commit is contained in:
@@ -9,9 +9,10 @@
|
||||
import operator
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.engine import CursorResult, Result
|
||||
|
||||
from src.common.database.core.models import Base
|
||||
from src.common.database.core.session import get_db_session
|
||||
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.crud")
|
||||
|
||||
T = TypeVar("T", bound=Base)
|
||||
T = TypeVar("T", bound=Any)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
|
||||
def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
|
||||
"""获取模型的列名称列表"""
|
||||
return tuple(column.name for column in model.__table__.columns)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
|
||||
def _get_model_field_set(model: type[Any]) -> frozenset[str]:
|
||||
"""获取模型的有效字段集合"""
|
||||
return frozenset(_get_model_column_names(model))
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
|
||||
def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
|
||||
"""为模型准备attrgetter,用于批量获取属性值"""
|
||||
column_names = _get_model_column_names(model)
|
||||
|
||||
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, .
|
||||
if len(column_names) == 1:
|
||||
attr_name = column_names[0]
|
||||
|
||||
def _single(instance: Base) -> tuple[Any, ...]:
|
||||
def _single(instance: Any) -> tuple[Any, ...]:
|
||||
return (getattr(instance, attr_name),)
|
||||
|
||||
return _single
|
||||
|
||||
getter = operator.attrgetter(*column_names)
|
||||
|
||||
def _multi(instance: Base) -> tuple[Any, ...]:
|
||||
def _multi(instance: Any) -> tuple[Any, ...]:
|
||||
values = getter(instance)
|
||||
return values if isinstance(values, tuple) else (values,)
|
||||
|
||||
return _multi
|
||||
|
||||
|
||||
def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
def _model_to_dict(instance: Any) -> dict[str, Any]:
|
||||
"""将 SQLAlchemy 模型实例转换为字典
|
||||
|
||||
Args:
|
||||
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
|
||||
return instance
|
||||
|
||||
|
||||
class CRUDBase:
|
||||
class CRUDBase(Generic[T]):
|
||||
"""基础CRUD操作类
|
||||
|
||||
提供通用的增删改查操作,自动集成缓存和批处理
|
||||
@@ -249,7 +250,7 @@ class CRUDBase:
|
||||
if cached_dicts is not None:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典列表恢复对象列表
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts]
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
@@ -278,7 +279,7 @@ class CRUDBase:
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
|
||||
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts]
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore
|
||||
|
||||
async def create(
|
||||
self,
|
||||
@@ -420,7 +421,7 @@ class CRUDBase:
|
||||
async with get_db_session() as session:
|
||||
stmt = delete(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
success = result.rowcount > 0
|
||||
success = result.rowcount > 0 # type: ignore
|
||||
# 注意:commit在get_db_session的context manager退出时自动执行
|
||||
|
||||
# 清除缓存
|
||||
@@ -455,7 +456,7 @@ class CRUDBase:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar()
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
async def exists(
|
||||
self,
|
||||
@@ -549,7 +550,7 @@ class CRUDBase:
|
||||
.values(**obj_in)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
count += result.rowcount
|
||||
count += result.rowcount # type: ignore
|
||||
|
||||
# 清除缓存
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
|
||||
Reference in New Issue
Block a user