"""高级查询API

提供复杂的查询操作:
- MongoDB风格的查询操作符
- 聚合查询
- 排序和分页
- 关联查询
"""
|
||
|
||
from typing import Any, Generic, TypeVar
|
||
|
||
from sqlalchemy import and_, asc, desc, func, or_, select
|
||
|
||
# 导入 CRUD 辅助函数以避免重复定义
|
||
from src.common.database.api.crud import _dict_to_model, _model_to_dict
|
||
from src.common.database.core.models import Base
|
||
from src.common.database.core.session import get_db_session
|
||
from src.common.database.optimization import get_cache
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("database.query")

# Model type handled by the query helpers.  Bound to the declarative ``Base``
# (instead of the no-op ``bound=Any``) because the builders read
# ``model.__tablename__`` and build ``select(model)``; presumably every mapped
# model in this project derives from ``Base`` -- confirm if not.
T = TypeVar("T", bound=Base)
|
||
|
||
|
||
class QueryBuilder(Generic[T]):
|
||
"""查询构建器
|
||
|
||
支持链式调用,构建复杂查询
|
||
"""
|
||
|
||
def __init__(self, model: type[T]):
|
||
"""初始化查询构建器
|
||
|
||
Args:
|
||
model: SQLAlchemy模型类
|
||
"""
|
||
self.model = model
|
||
self.model_name = model.__tablename__
|
||
self._stmt = select(model)
|
||
self._use_cache = True
|
||
self._cache_key_parts: list[str] = [self.model_name]
|
||
|
||
def filter(self, **conditions: Any) -> "QueryBuilder":
|
||
"""添加过滤条件
|
||
|
||
支持的操作符:
|
||
- 直接相等: field=value
|
||
- 大于: field__gt=value
|
||
- 小于: field__lt=value
|
||
- 大于等于: field__gte=value
|
||
- 小于等于: field__lte=value
|
||
- 不等于: field__ne=value
|
||
- 包含: field__in=[values]
|
||
- 不包含: field__nin=[values]
|
||
- 模糊匹配: field__like='%pattern%'
|
||
- 为空: field__isnull=True
|
||
|
||
Args:
|
||
**conditions: 过滤条件
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
for key, value in conditions.items():
|
||
# 解析字段和操作符
|
||
if "__" in key:
|
||
field_name, operator = key.rsplit("__", 1)
|
||
else:
|
||
field_name, operator = key, "eq"
|
||
|
||
if not hasattr(self.model, field_name):
|
||
logger.warning(f"模型 {self.model_name} 没有字段 {field_name}")
|
||
continue
|
||
|
||
field = getattr(self.model, field_name)
|
||
|
||
# 应用操作符
|
||
if operator == "eq":
|
||
self._stmt = self._stmt.where(field == value)
|
||
elif operator == "gt":
|
||
self._stmt = self._stmt.where(field > value)
|
||
elif operator == "lt":
|
||
self._stmt = self._stmt.where(field < value)
|
||
elif operator == "gte":
|
||
self._stmt = self._stmt.where(field >= value)
|
||
elif operator == "lte":
|
||
self._stmt = self._stmt.where(field <= value)
|
||
elif operator == "ne":
|
||
self._stmt = self._stmt.where(field != value)
|
||
elif operator == "in":
|
||
self._stmt = self._stmt.where(field.in_(value))
|
||
elif operator == "nin":
|
||
self._stmt = self._stmt.where(~field.in_(value))
|
||
elif operator == "like":
|
||
self._stmt = self._stmt.where(field.like(value))
|
||
elif operator == "isnull":
|
||
if value:
|
||
self._stmt = self._stmt.where(field.is_(None))
|
||
else:
|
||
self._stmt = self._stmt.where(field.isnot(None))
|
||
else:
|
||
logger.warning(f"未知操作符: {operator}")
|
||
|
||
# 更新缓存键
|
||
self._cache_key_parts.append(f"filter:{sorted(conditions.items())!s}")
|
||
return self
|
||
|
||
def filter_or(self, **conditions: Any) -> "QueryBuilder":
|
||
"""添加OR过滤条件
|
||
|
||
Args:
|
||
**conditions: OR条件
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
or_conditions = []
|
||
for key, value in conditions.items():
|
||
if hasattr(self.model, key):
|
||
field = getattr(self.model, key)
|
||
or_conditions.append(field == value)
|
||
|
||
if or_conditions:
|
||
self._stmt = self._stmt.where(or_(*or_conditions))
|
||
self._cache_key_parts.append(f"or:{sorted(conditions.items())!s}")
|
||
|
||
return self
|
||
|
||
def order_by(self, *fields: str) -> "QueryBuilder":
|
||
"""添加排序
|
||
|
||
Args:
|
||
*fields: 排序字段,'-'前缀表示降序
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
for field_name in fields:
|
||
if field_name.startswith("-"):
|
||
field_name = field_name[1:]
|
||
if hasattr(self.model, field_name):
|
||
self._stmt = self._stmt.order_by(desc(getattr(self.model, field_name)))
|
||
else:
|
||
if hasattr(self.model, field_name):
|
||
self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name)))
|
||
|
||
self._cache_key_parts.append(f"order:{','.join(fields)}")
|
||
return self
|
||
|
||
def limit(self, limit: int) -> "QueryBuilder":
|
||
"""限制结果数量
|
||
|
||
Args:
|
||
limit: 最大数量
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
self._stmt = self._stmt.limit(limit)
|
||
self._cache_key_parts.append(f"limit:{limit}")
|
||
return self
|
||
|
||
def offset(self, offset: int) -> "QueryBuilder":
|
||
"""跳过指定数量
|
||
|
||
Args:
|
||
offset: 跳过数量
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
self._stmt = self._stmt.offset(offset)
|
||
self._cache_key_parts.append(f"offset:{offset}")
|
||
return self
|
||
|
||
def no_cache(self) -> "QueryBuilder":
|
||
"""禁用缓存
|
||
|
||
Returns:
|
||
self,支持链式调用
|
||
"""
|
||
self._use_cache = False
|
||
return self
|
||
|
||
async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
|
||
"""获取所有结果
|
||
|
||
Args:
|
||
as_dict: 为True时返回字典格式
|
||
|
||
Returns:
|
||
模型实例列表或字典列表
|
||
"""
|
||
cache_key = ":".join(self._cache_key_parts) + ":all"
|
||
|
||
# 尝试从缓存获取 (缓存的是字典列表)
|
||
if self._use_cache:
|
||
cache = await get_cache()
|
||
cached_dicts = await cache.get(cache_key)
|
||
if cached_dicts is not None:
|
||
dict_rows = [dict(row) for row in cached_dicts]
|
||
if as_dict:
|
||
return dict_rows
|
||
return [_dict_to_model(self.model, row) for row in dict_rows]
|
||
|
||
# 从数据库查询
|
||
async with get_db_session() as session:
|
||
result = await session.execute(self._stmt)
|
||
instances = list(result.scalars().all())
|
||
|
||
# 在 session 内部转换为字典列表,此时所有字段都可安全访问
|
||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||
|
||
if self._use_cache:
|
||
cache = await get_cache()
|
||
cache_payload = [dict(row) for row in instances_dicts]
|
||
await cache.set(cache_key, cache_payload)
|
||
|
||
if as_dict:
|
||
return instances_dicts
|
||
return [_dict_to_model(self.model, row) for row in instances_dicts]
|
||
|
||
async def first(self, *, as_dict: bool = False) -> T | dict[str, Any] | None:
|
||
"""获取第一条结果
|
||
|
||
Args:
|
||
as_dict: 为True时返回字典格式
|
||
|
||
Returns:
|
||
模型实例或None
|
||
"""
|
||
cache_key = ":".join(self._cache_key_parts) + ":first"
|
||
|
||
# 尝试从缓存获取 (缓存的是字典)
|
||
if self._use_cache:
|
||
cache = await get_cache()
|
||
cached_dict = await cache.get(cache_key)
|
||
if cached_dict is not None:
|
||
row = dict(cached_dict)
|
||
if as_dict:
|
||
return row
|
||
return _dict_to_model(self.model, row)
|
||
|
||
# 从数据库查询
|
||
async with get_db_session() as session:
|
||
result = await session.execute(self._stmt)
|
||
instance = result.scalars().first()
|
||
|
||
if instance is not None:
|
||
# 在 session 内部转换为字典,此时所有字段都可安全访问
|
||
instance_dict = _model_to_dict(instance)
|
||
|
||
# 写入缓存
|
||
if self._use_cache:
|
||
cache = await get_cache()
|
||
await cache.set(cache_key, dict(instance_dict))
|
||
|
||
if as_dict:
|
||
return instance_dict
|
||
return _dict_to_model(self.model, instance_dict)
|
||
|
||
return None
|
||
|
||
async def count(self) -> int:
|
||
"""统计数量
|
||
|
||
Returns:
|
||
记录数量
|
||
"""
|
||
cache_key = ":".join(self._cache_key_parts) + ":count"
|
||
|
||
# 尝试从缓存获取
|
||
if self._use_cache:
|
||
cache = await get_cache()
|
||
cached = await cache.get(cache_key)
|
||
if cached is not None:
|
||
return cached
|
||
|
||
# 构建count查询
|
||
count_stmt = select(func.count()).select_from(self._stmt.subquery())
|
||
|
||
# 从数据库查询
|
||
async with get_db_session() as session:
|
||
result = await session.execute(count_stmt)
|
||
count = result.scalar() or 0
|
||
|
||
# 写入缓存
|
||
if self._use_cache:
|
||
cache = await get_cache()
|
||
await cache.set(cache_key, count)
|
||
|
||
return count
|
||
|
||
async def exists(self) -> bool:
|
||
"""检查是否存在
|
||
|
||
Returns:
|
||
是否存在记录
|
||
"""
|
||
count = await self.count()
|
||
return count > 0
|
||
|
||
async def paginate(
|
||
self,
|
||
page: int = 1,
|
||
page_size: int = 20,
|
||
) -> tuple[list[T], int]:
|
||
"""分页查询
|
||
|
||
Args:
|
||
page: 页码(从1开始)
|
||
page_size: 每页数量
|
||
|
||
Returns:
|
||
(结果列表, 总数量)
|
||
"""
|
||
# 计算偏移量
|
||
offset = (page - 1) * page_size
|
||
|
||
# 获取总数
|
||
total = await self.count()
|
||
|
||
# 获取当前页数据
|
||
self._stmt = self._stmt.offset(offset).limit(page_size)
|
||
self._cache_key_parts.append(f"page:{page}:{page_size}")
|
||
|
||
items = await self.all()
|
||
|
||
return items, total # type: ignore
|
||
|
||
|
||
class AggregateQuery:
    """Aggregate queries (SUM / AVG / MAX / MIN / grouped COUNT) over one model.

    Equality filters added with ``filter()`` are AND-ed into every aggregate
    statement.  Results are not cached.
    """

    def __init__(self, model: type[Any]):
        """Create an aggregate query for *model*.

        Args:
            model: SQLAlchemy model class.
        """
        self.model = model
        self.model_name = model.__tablename__
        # SQLAlchemy boolean expressions, AND-ed together at execution time.
        self._conditions: list[Any] = []

    def filter(self, **conditions: Any) -> "AggregateQuery":
        """Add equality filters (``field=value``).

        Unknown fields are logged and skipped.  A silently dropped filter would
        widen the aggregate to unrelated rows, so it is made visible in the log.

        Args:
            **conditions: field/value equality pairs.

        Returns:
            self, for chaining.
        """
        for key, value in conditions.items():
            if not hasattr(self.model, key):
                logger.warning(f"模型 {self.model_name} 没有字段 {key}")
                continue
            self._conditions.append(getattr(self.model, key) == value)
        return self

    async def _aggregate(self, func_name: str, field: str) -> Any:
        """Run ``SELECT <func_name>(field)`` with the current filters.

        Args:
            func_name: SQL aggregate name looked up on ``sqlalchemy.func``.
            field: model attribute to aggregate.

        Returns:
            The scalar result (None when the database returns NULL).

        Raises:
            ValueError: if the model has no attribute named *field*.
        """
        if not hasattr(self.model, field):
            raise ValueError(f"字段 {field} 不存在")

        aggregate = getattr(func, func_name)
        stmt = select(aggregate(getattr(self.model, field)))
        if self._conditions:
            stmt = stmt.where(and_(*self._conditions))

        async with get_db_session() as session:
            result = await session.execute(stmt)
            return result.scalar()

    async def sum(self, field: str) -> float:
        """Return SUM(field), or 0 when there are no rows.

        Raises:
            ValueError: if the field does not exist.
        """
        return await self._aggregate("sum", field) or 0

    async def avg(self, field: str) -> float:
        """Return AVG(field), or 0 when there are no rows.

        Raises:
            ValueError: if the field does not exist.
        """
        return await self._aggregate("avg", field) or 0

    async def max(self, field: str) -> Any:
        """Return MAX(field), or None when there are no rows.

        Raises:
            ValueError: if the field does not exist.
        """
        return await self._aggregate("max", field)

    async def min(self, field: str) -> Any:
        """Return MIN(field), or None when there are no rows.

        Raises:
            ValueError: if the field does not exist.
        """
        return await self._aggregate("min", field)

    async def group_by_count(
        self,
        *fields: str,
    ) -> list[tuple[Any, ...]]:
        """Count rows per distinct combination of the given fields.

        Unknown fields are logged and skipped; if none remain, ``[]`` is
        returned.

        Args:
            *fields: grouping field names.

        Returns:
            ``[(group_value_1, ..., group_value_n, count), ...]``

        Raises:
            ValueError: if no field name is given.
        """
        if not fields:
            raise ValueError("至少需要一个分组字段")

        group_columns = []
        for field_name in fields:
            if hasattr(self.model, field_name):
                group_columns.append(getattr(self.model, field_name))
            else:
                logger.warning(f"模型 {self.model_name} 没有字段 {field_name}")

        if not group_columns:
            return []

        # COUNT(*) instead of COUNT(model.id): works for models without an
        # ``id`` column; the FROM clause is inferred from the group columns.
        stmt = select(*group_columns, func.count())
        if self._conditions:
            stmt = stmt.where(and_(*self._conditions))
        stmt = stmt.group_by(*group_columns)

        async with get_db_session() as session:
            result = await session.execute(stmt)
            return [tuple(row) for row in result.all()]
|