Files
Mofox-Core/src/common/database/api/crud.py

560 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""基础CRUD API
提供通用的数据库CRUD操作集成优化层功能
- 自动缓存:查询结果自动缓存
- 批量处理:写操作自动批处理
- 智能预加载:关联数据自动预加载
"""
import operator
from collections.abc import Callable
from functools import lru_cache
from typing import Any, TypeVar
from sqlalchemy import delete, func, select, update
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
from src.common.database.optimization import (
BatchOperation,
Priority,
get_batch_scheduler,
get_cache,
)
from src.common.logger import get_logger
# Module-level logger for this CRUD layer (cache hits, conversion failures, bulk ops).
logger = get_logger("database.crud")
# Generic type variable for CRUD helpers; constrained to SQLAlchemy declarative models.
T = TypeVar("T", bound=Base)
@lru_cache(maxsize=256)
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
"""获取模型的列名称列表"""
return tuple(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
    """Return the set of assignable column names for *model* (cached per class)."""
    return frozenset(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
    """Build and cache a callable that reads all column values off an instance.

    The returned callable yields values in the same order as
    ``_get_model_column_names(model)``.
    """
    names = _get_model_column_names(model)
    if not names:
        # Degenerate table with no columns: always produce an empty tuple.
        return lambda _instance: ()
    if len(names) == 1:
        # Read the lone attribute explicitly and wrap it, so a value that
        # happens to be a tuple is not mistaken for the result tuple itself.
        only = names[0]
        return lambda instance: (getattr(instance, only),)
    reader = operator.attrgetter(*names)

    def _read_all(instance: Base) -> tuple[Any, ...]:
        # attrgetter with 2+ names already yields a tuple; the isinstance
        # check keeps the contract explicit.
        values = reader(instance)
        return values if isinstance(values, tuple) else (values,)

    return _read_all
def _model_to_dict(instance: Base) -> dict[str, Any]:
    """Convert a SQLAlchemy model instance into a plain column->value dict.

    Args:
        instance: Model instance to convert; ``None`` yields an empty dict.

    Returns:
        Mapping of column name to attribute value. If bulk attribute access
        fails, falls back to reading columns one by one, substituting None
        for any attribute that cannot be read.
    """
    if instance is None:
        return {}
    model = type(instance)
    column_names = _get_model_column_names(model)
    fetch = _get_model_value_fetcher(model)
    try:
        return dict(zip(column_names, fetch(instance)))
    except Exception as exc:
        logger.warning(f"无法转换模型 {model.__name__}: {exc}")
        salvage: dict[str, Any] = {}
        for name in column_names:
            try:
                salvage[name] = getattr(instance, name)
            except Exception:
                salvage[name] = None
        return salvage
def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
    """Build a detached *model_class* instance from a plain dict.

    Args:
        model_class: SQLAlchemy declarative model class.
        data: Column values; keys that are not mapped columns are dropped.

    Returns:
        A new instance, never attached to any session, with all provided
        column values assigned.
    """
    valid = _get_model_field_set(model_class)
    obj = model_class()
    for field, value in data.items():
        if field in valid:
            setattr(obj, field, value)
    return obj
class CRUDBase:
    """Generic CRUD operations bound to one SQLAlchemy model.

    Wraps plain session-based create/read/update/delete with automatic
    result caching and optional write batching.
    """
def __init__(self, model: type[T]):
    """Bind this CRUD helper to a single model class.

    Args:
        model: SQLAlchemy declarative model whose table this helper manages.
    """
    # Table name doubles as the prefix for every cache key this helper builds.
    self.model, self.model_name = model, model.__tablename__
async def get(
    self,
    id: int,
    use_cache: bool = True,
) -> T | None:
    """Fetch a single record by primary key.

    Args:
        id: Primary-key value.
        use_cache: Whether to consult and populate the query cache.

    Returns:
        A detached model instance with all columns loaded, or None if the
        record does not exist.
    """
    cache_key = f"{self.model_name}:id:{id}"
    # Acquire the cache handle once up front; previously it was fetched a
    # second time on every cache miss.
    cache = await get_cache() if use_cache else None
    # Cached entries are plain dicts, not ORM instances.
    if cache is not None:
        cached_dict = await cache.get(cache_key)
        if cached_dict is not None:
            logger.debug(f"缓存命中: {cache_key}")
            # Rebuild a detached instance from the cached column dict.
            return _dict_to_model(self.model, cached_dict)
    async with get_db_session() as session:
        stmt = select(self.model).where(self.model.id == id)
        result = await session.execute(stmt)
        instance = result.scalar_one_or_none()
        if instance is None:
            return None
        # Convert inside the session, while every attribute is safe to read.
        instance_dict = _model_to_dict(instance)
    if cache is not None:
        await cache.set(cache_key, instance_dict)
    # Return a detached copy rebuilt from the dict (all fields populated).
    return _dict_to_model(self.model, instance_dict)
async def get_by(
    self,
    use_cache: bool = True,
    **filters: Any,
) -> T | None:
    """Fetch a single record matching equality filters.

    Args:
        use_cache: Whether to consult and populate the query cache.
        **filters: attribute=value equality conditions; names that are not
            model attributes are silently ignored.

    Returns:
        A detached model instance, or None if no row matches.
        NOTE(review): underlying scalar_one_or_none raises
        MultipleResultsFound when several rows match — confirm callers
        only use this with unique filters.
    """
    cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
    # Acquire the cache handle once; avoids a second get_cache() on a miss.
    cache = await get_cache() if use_cache else None
    if cache is not None:
        cached_dict = await cache.get(cache_key)
        if cached_dict is not None:
            logger.debug(f"缓存命中: {cache_key}")
            return _dict_to_model(self.model, cached_dict)
    async with get_db_session() as session:
        stmt = select(self.model)
        for key, value in filters.items():
            if hasattr(self.model, key):
                stmt = stmt.where(getattr(self.model, key) == value)
        result = await session.execute(stmt)
        instance = result.scalar_one_or_none()
        if instance is None:
            return None
        # Convert inside the session so all attributes are loaded safely.
        instance_dict = _model_to_dict(instance)
    if cache is not None:
        await cache.set(cache_key, instance_dict)
    return _dict_to_model(self.model, instance_dict)
async def get_multi(
    self,
    skip: int = 0,
    limit: int = 100,
    use_cache: bool = True,
    **filters: Any,
) -> list[T]:
    """Fetch a page of records matching the given filters.

    Args:
        skip: Number of rows to skip (offset).
        limit: Maximum number of rows to return.
        use_cache: Whether to consult and populate the query cache.
        **filters: attribute=value conditions; list/tuple/set values become
            IN clauses, others equality. Unknown attributes are ignored.

    Returns:
        A list of detached model instances (possibly empty).
    """
    cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
    # Acquire the cache handle once; avoids a second get_cache() on a miss.
    cache = await get_cache() if use_cache else None
    # Cached entries are lists of plain dicts.
    if cache is not None:
        cached_dicts = await cache.get(cache_key)
        if cached_dicts is not None:
            logger.debug(f"缓存命中: {cache_key}")
            return [_dict_to_model(self.model, d) for d in cached_dicts]
    async with get_db_session() as session:
        stmt = select(self.model)
        for key, value in filters.items():
            if hasattr(self.model, key):
                if isinstance(value, list | tuple | set):
                    stmt = stmt.where(getattr(self.model, key).in_(value))
                else:
                    stmt = stmt.where(getattr(self.model, key) == value)
        stmt = stmt.offset(skip).limit(limit)
        result = await session.execute(stmt)
        # Convert inside the session, while attributes are safe to read.
        instances_dicts = [_model_to_dict(inst) for inst in result.scalars().all()]
    if cache is not None:
        # Empty result pages are cached too, matching prior behavior.
        await cache.set(cache_key, instances_dicts)
    return [_dict_to_model(self.model, d) for d in instances_dicts]
async def create(
    self,
    obj_in: dict[str, Any],
    use_batch: bool = False,
) -> T:
    """Insert a new record.

    Args:
        obj_in: Column values for the new row.
        use_batch: Route the insert through the batch scheduler instead of
            writing directly.

    Returns:
        The created model instance. NOTE(review): on the batch path the
        returned object is built locally and presumably lacks DB-generated
        fields such as the primary key — confirm callers do not rely on
        them there.
    """
    if not use_batch:
        # Direct insert: flush + refresh so generated fields are populated;
        # the get_db_session context manager commits on exit.
        async with get_db_session() as session:
            instance = self.model(**obj_in)
            session.add(instance)
            await session.flush()
            await session.refresh(instance)
            # Caches are deliberately NOT invalidated on create:
            # 1) new rows cannot affect existing get/get_by entries,
            # 2) get_multi entries expire naturally via TTL,
            # 3) clearing everything would be too costly.
            # Callers needing strict consistency should pass use_cache=False.
            return instance
    # Batched insert: enqueue the operation and wait for it to be applied.
    scheduler = await get_batch_scheduler()
    op = BatchOperation(
        operation_type="insert",
        model_class=self.model,
        data=obj_in,
        priority=Priority.NORMAL,
    )
    await (await scheduler.add_operation(op))
    return self.model(**obj_in)
async def update(
    self,
    id: int,
    obj_in: dict[str, Any],
    use_batch: bool = False,
) -> T | None:
    """Update a record by primary key.

    Args:
        id: Primary-key value.
        obj_in: Column values to assign; keys that are not model attributes
            are ignored.
        use_batch: Route the update through the batch scheduler.

    Returns:
        The updated instance, or None if the record does not exist.
    """
    # Existence pre-check; bypass the cache so we see the latest state.
    instance = await self.get(id, use_cache=False)
    if instance is None:
        return None
    if use_batch:
        scheduler = await get_batch_scheduler()
        operation = BatchOperation(
            operation_type="update",
            model_class=self.model,
            conditions={"id": id},
            data=obj_in,
            priority=Priority.NORMAL,
        )
        future = await scheduler.add_operation(operation)
        await future
        # Mirror the queued changes onto the detached instance we return.
        for key, value in obj_in.items():
            if hasattr(instance, key):
                setattr(instance, key, value)
    else:
        async with get_db_session() as session:
            # Re-select inside this session so changes are tracked/committed.
            stmt = select(self.model).where(self.model.id == id)
            result = await session.execute(stmt)
            db_instance = result.scalar_one_or_none()
            if db_instance is None:
                # BUGFIX: the row vanished between the pre-check and this
                # query. Previously the stale, un-updated instance was
                # returned as if the update succeeded; report "not found"
                # instead, after dropping any stale cache entry.
                cache = await get_cache()
                await cache.delete(f"{self.model_name}:id:{id}")
                return None
            for key, value in obj_in.items():
                if hasattr(db_instance, key):
                    setattr(db_instance, key, value)
            await session.flush()
            await session.refresh(db_instance)
            instance = db_instance
            # Commit happens when the get_db_session context exits.
    # Drop the per-id cache entry so the next read sees fresh data.
    cache = await get_cache()
    await cache.delete(f"{self.model_name}:id:{id}")
    return instance
async def delete(
    self,
    id: int,
    use_batch: bool = False,
) -> bool:
    """Delete a record by primary key.

    Args:
        id: Primary-key value.
        use_batch: Route the delete through the batch scheduler.

    Returns:
        True if a row was removed.
    """
    if use_batch:
        # Batched delete: the scheduler's future resolves to a row count.
        scheduler = await get_batch_scheduler()
        op = BatchOperation(
            operation_type="delete",
            model_class=self.model,
            conditions={"id": id},
            priority=Priority.NORMAL,
        )
        deleted_rows = await (await scheduler.add_operation(op))
        removed = deleted_rows > 0
    else:
        async with get_db_session() as session:
            outcome = await session.execute(
                delete(self.model).where(self.model.id == id)
            )
            removed = outcome.rowcount > 0
            # Commit is handled by the get_db_session context on exit.
    if removed:
        # Invalidate the per-id cache entry for the deleted row.
        cache = await get_cache()
        await cache.delete(f"{self.model_name}:id:{id}")
    return removed
async def count(
    self,
    **filters: Any,
) -> int:
    """Count rows matching the given filters.

    Args:
        **filters: attribute=value conditions; list/tuple/set values become
            IN clauses, others equality. Unknown attributes are ignored.

    Returns:
        Number of matching rows.
    """
    async with get_db_session() as session:
        stmt = select(func.count(self.model.id))
        for attr, expected in filters.items():
            if not hasattr(self.model, attr):
                continue
            column = getattr(self.model, attr)
            if isinstance(expected, list | tuple | set):
                stmt = stmt.where(column.in_(expected))
            else:
                stmt = stmt.where(column == expected)
        return (await session.execute(stmt)).scalar()
async def exists(
    self,
    **filters: Any,
) -> bool:
    """Return True if at least one row matches *filters*.

    Args:
        **filters: Conditions forwarded unchanged to :meth:`count`.
    """
    return await self.count(**filters) > 0
async def get_or_create(
    self,
    defaults: dict[str, Any] | None = None,
    **filters: Any,
) -> tuple[T, bool]:
    """Look up a record by *filters*, inserting it when absent.

    Args:
        defaults: Extra column values applied only when creating.
        **filters: Lookup conditions, also used as creation values.

    Returns:
        ``(instance, created)`` where ``created`` is True for a new row.

    NOTE(review): lookup and insert are not atomic here; concurrent callers
    may both insert — confirm a unique constraint backs these filters.
    """
    existing = await self.get_by(use_cache=False, **filters)
    if existing is not None:
        return existing, False
    payload = dict(filters)
    if defaults:
        payload.update(defaults)
    return await self.create(payload), True
async def bulk_create(
    self,
    objs_in: list[dict[str, Any]],
) -> list[T]:
    """Insert many records in one session.

    Args:
        objs_in: One dict of column values per row to create.

    Returns:
        The created instances, refreshed so generated fields are loaded.
    """
    async with get_db_session() as session:
        instances = [self.model(**payload) for payload in objs_in]
        session.add_all(instances)
        await session.flush()
        for row in instances:
            await session.refresh(row)
        # Bulk imports can stale many cached list queries at once, so the
        # whole cache is cleared instead of deleting individual keys.
        cache = await get_cache()
        await cache.clear()
        logger.info(f"批量创建{len(instances)}{self.model_name}记录后已清除缓存")
        return instances
async def bulk_update(
    self,
    updates: list[tuple[int, dict[str, Any]]],
) -> int:
    """Apply many per-id updates in one session.

    Args:
        updates: ``(id, values)`` pairs; each pair issues one UPDATE
            statement against the model's primary key.

    Returns:
        Total number of rows affected across all statements.
    """
    # Hoisted out of the loop: previously the cache handle was re-awaited
    # once per updated row.
    cache = await get_cache()
    async with get_db_session() as session:
        count = 0
        for row_id, values in updates:  # renamed from `id` (shadowed builtin)
            stmt = (
                update(self.model)
                .where(self.model.id == row_id)
                .values(**values)
            )
            result = await session.execute(stmt)
            count += result.rowcount
            # Invalidate the per-id cache entry for this row.
            await cache.delete(f"{self.model_name}:id:{row_id}")
        return count