Merge branch 'dev' into feature/kfc

This commit is contained in:
拾风
2025-12-01 16:06:47 +08:00
committed by GitHub
87 changed files with 6181 additions and 2355 deletions

View File

@@ -9,9 +9,10 @@
import operator
from collections.abc import Callable
from functools import lru_cache
from typing import Any, TypeVar
from typing import Any, Generic, TypeVar
from sqlalchemy import delete, func, select, update
from sqlalchemy.engine import CursorResult, Result
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
logger = get_logger("database.crud")
T = TypeVar("T", bound=Base)
T = TypeVar("T", bound=Any)
@lru_cache(maxsize=256)
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
"""获取模型的列名称列表"""
return tuple(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
def _get_model_field_set(model: type[Any]) -> frozenset[str]:
"""获取模型的有效字段集合"""
return frozenset(_get_model_column_names(model))
@lru_cache(maxsize=256)
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
"""为模型准备attrgetter用于批量获取属性值"""
column_names = _get_model_column_names(model)
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, .
if len(column_names) == 1:
attr_name = column_names[0]
def _single(instance: Base) -> tuple[Any, ...]:
def _single(instance: Any) -> tuple[Any, ...]:
return (getattr(instance, attr_name),)
return _single
getter = operator.attrgetter(*column_names)
def _multi(instance: Base) -> tuple[Any, ...]:
def _multi(instance: Any) -> tuple[Any, ...]:
values = getter(instance)
return values if isinstance(values, tuple) else (values,)
return _multi
def _model_to_dict(instance: Base) -> dict[str, Any]:
def _model_to_dict(instance: Any) -> dict[str, Any]:
"""将 SQLAlchemy 模型实例转换为字典
Args:
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
return instance
class CRUDBase:
class CRUDBase(Generic[T]):
"""基础CRUD操作类
提供通用的增删改查操作,自动集成缓存和批处理
@@ -246,7 +247,7 @@ class CRUDBase:
cached_dicts = await cache.get(cache_key)
if cached_dicts is not None:
# 从字典列表恢复对象列表
return [_dict_to_model(self.model, d) for d in cached_dicts]
return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore
# 从数据库查询
async with get_db_session() as session:
@@ -275,7 +276,7 @@ class CRUDBase:
await cache.set(cache_key, instances_dicts)
# 从字典列表重建对象列表返回detached状态所有字段已加载
return [_dict_to_model(self.model, d) for d in instances_dicts]
return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore
async def create(
self,
@@ -417,7 +418,7 @@ class CRUDBase:
async with get_db_session() as session:
stmt = delete(self.model).where(self.model.id == id)
result = await session.execute(stmt)
success = result.rowcount > 0
success = result.rowcount > 0 # type: ignore
# 注意commit在get_db_session的context manager退出时自动执行
# 清除缓存
@@ -452,7 +453,7 @@ class CRUDBase:
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
return result.scalar()
return int(result.scalar() or 0)
async def exists(
self,
@@ -546,7 +547,7 @@ class CRUDBase:
.values(**obj_in)
)
result = await session.execute(stmt)
count += result.rowcount
count += result.rowcount # type: ignore
# 清除缓存
cache_key = f"{self.model_name}:id:{id}"