style: ruff auto-format fixes - fix 180 blank-line and formatting issues

Windpicker-owo
2025-11-01 17:06:40 +08:00
parent ece6a70c65
commit cabaf74194
4 changed files with 173 additions and 179 deletions


@@ -7,19 +7,16 @@
- Join queries
"""
from typing import Any, Generic, Optional, Sequence, Type, TypeVar
from typing import Any, Generic, TypeVar
from sqlalchemy import and_, asc, desc, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.engine import Row
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache, get_preloader
from src.common.logger import get_logger
# Import CRUD helper functions to avoid duplicate definitions
from src.common.database.api.crud import _dict_to_model, _model_to_dict
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache
from src.common.logger import get_logger
logger = get_logger("database.query")
@@ -28,13 +25,13 @@ T = TypeVar("T", bound="Base")
class QueryBuilder(Generic[T]):
"""Query builder
Supports chained calls for building complex queries
"""
def __init__(self, model: Type[T]):
def __init__(self, model: type[T]):
"""Initialize the query builder
Args:
model: SQLAlchemy model class
"""
@@ -46,7 +43,7 @@ class QueryBuilder(Generic[T]):
def filter(self, **conditions: Any) -> "QueryBuilder":
"""Add filter conditions
Supported operators:
- Equals: field=value
- Greater than: field__gt=value
@@ -58,10 +55,10 @@ class QueryBuilder(Generic[T]):
- Not in: field__nin=[values]
- Pattern match: field__like='%pattern%'
- Is null: field__isnull=True
Args:
**conditions: Filter conditions
Returns:
self, for chained calls
"""
@@ -71,13 +68,13 @@ class QueryBuilder(Generic[T]):
field_name, operator = key.rsplit("__", 1)
else:
field_name, operator = key, "eq"
if not hasattr(self.model, field_name):
logger.warning(f"Model {self.model_name} has no field {field_name}")
continue
field = getattr(self.model, field_name)
# Apply the operator
if operator == "eq":
self._stmt = self._stmt.where(field == value)
@@ -104,17 +101,17 @@ class QueryBuilder(Generic[T]):
self._stmt = self._stmt.where(field.isnot(None))
else:
logger.warning(f"Unknown operator: {operator}")
# Update the cache key
self._cache_key_parts.append(f"filter:{str(sorted(conditions.items()))}")
self._cache_key_parts.append(f"filter:{sorted(conditions.items())!s}")
return self
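# Illustrative usage sketch, not part of this commit: chaining the filter operators
# documented above with ordering and a limit. `Messages` and its `user_id`,
# `created_at`, and `content` columns are hypothetical stand-ins for any mapped model.
async def _example_filter_usage(cutoff):
    return await (
        QueryBuilder(Messages)
        .filter(user_id=42, created_at__gt=cutoff, content__like="%error%")
        .order_by("-created_at")
        .limit(50)
        .all()
    )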
def filter_or(self, **conditions: Any) -> "QueryBuilder":
"""Add OR filter conditions
Args:
**conditions: OR conditions
Returns:
self, for chained calls
"""
@@ -123,19 +120,19 @@ class QueryBuilder(Generic[T]):
if hasattr(self.model, key):
field = getattr(self.model, key)
or_conditions.append(field == value)
if or_conditions:
self._stmt = self._stmt.where(or_(*or_conditions))
self._cache_key_parts.append(f"or:{str(sorted(conditions.items()))}")
self._cache_key_parts.append(f"or:{sorted(conditions.items())!s}")
return self
def order_by(self, *fields: str) -> "QueryBuilder":
"""Add ordering
Args:
*fields: Fields to sort by; a '-' prefix means descending
Returns:
self, for chained calls
"""
@@ -147,16 +144,16 @@ class QueryBuilder(Generic[T]):
else:
if hasattr(self.model, field_name):
self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name)))
self._cache_key_parts.append(f"order:{','.join(fields)}")
return self
def limit(self, limit: int) -> "QueryBuilder":
"""Limit the number of results
Args:
limit: Maximum number of results
Returns:
self, for chained calls
"""
@@ -166,10 +163,10 @@ class QueryBuilder(Generic[T]):
def offset(self, offset: int) -> "QueryBuilder":
"""Skip a given number of results
Args:
offset: Number of results to skip
Returns:
self, for chained calls
"""
@@ -179,7 +176,7 @@ class QueryBuilder(Generic[T]):
def no_cache(self) -> "QueryBuilder":
"""Disable caching
Returns:
self, for chained calls
"""
@@ -188,12 +185,12 @@ class QueryBuilder(Generic[T]):
async def all(self) -> list[T]:
"""Get all results
Returns:
List of model instances
"""
cache_key = ":".join(self._cache_key_parts) + ":all"
# Try the cache first (a list of dicts is cached)
if self._use_cache:
cache = await get_cache()
@@ -202,12 +199,12 @@ class QueryBuilder(Generic[T]):
logger.debug(f"Cache hit: {cache_key}")
# Restore the object list from the cached dicts
return [_dict_to_model(self.model, d) for d in cached_dicts]
# Query the database
async with get_db_session() as session:
result = await session.execute(self._stmt)
instances = list(result.scalars().all())
# Preload all columns to avoid lazy-loading issues on detached objects
for instance in instances:
for column in self.model.__table__.columns:
@@ -215,23 +212,23 @@ class QueryBuilder(Generic[T]):
getattr(instance, column.name)
except Exception:
pass
# Convert to a list of dicts and write to the cache
if self._use_cache:
instances_dicts = [_model_to_dict(inst) for inst in instances]
cache = await get_cache()
await cache.set(cache_key, instances_dicts)
return instances
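# For reference, a minimal sketch of what the imported serialization helpers could
# look like, inferred only from their call sites in this module; the real
# implementations live in src.common.database.api.crud and may differ.
def _sketch_model_to_dict(instance):
    # Copy every mapped column into a plain dict so the cache never holds ORM state.
    return {column.name: getattr(instance, column.name) for column in instance.__table__.columns}

def _sketch_dict_to_model(model, data):
    # Rebuild a detached instance from cached column values.
    return model(**data)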
async def first(self) -> Optional[T]:
async def first(self) -> T | None:
"""Get the first result
Returns:
Model instance, or None
"""
cache_key = ":".join(self._cache_key_parts) + ":first"
# Try the cache first (a dict is cached)
if self._use_cache:
cache = await get_cache()
@@ -240,12 +237,12 @@ class QueryBuilder(Generic[T]):
logger.debug(f"Cache hit: {cache_key}")
# Restore the object from the cached dict
return _dict_to_model(self.model, cached_dict)
# Query the database
async with get_db_session() as session:
result = await session.execute(self._stmt)
instance = result.scalars().first()
# Preload all columns to avoid lazy-loading issues on detached objects
if instance is not None:
for column in self.model.__table__.columns:
@@ -253,23 +250,23 @@ class QueryBuilder(Generic[T]):
getattr(instance, column.name)
except Exception:
pass
# Convert to a dict and write to the cache
if instance is not None and self._use_cache:
instance_dict = _model_to_dict(instance)
cache = await get_cache()
await cache.set(cache_key, instance_dict)
return instance
async def count(self) -> int:
"""Count matching records
Returns:
Number of records
"""
cache_key = ":".join(self._cache_key_parts) + ":count"
# Try the cache first
if self._use_cache:
cache = await get_cache()
@@ -277,25 +274,25 @@ class QueryBuilder(Generic[T]):
if cached is not None:
logger.debug(f"Cache hit: {cache_key}")
return cached
# Build the count query
count_stmt = select(func.count()).select_from(self._stmt.subquery())
# Query the database
async with get_db_session() as session:
result = await session.execute(count_stmt)
count = result.scalar() or 0
# Write to the cache
if self._use_cache:
cache = await get_cache()
await cache.set(cache_key, count)
return count
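# Illustrative usage sketch, not part of this commit: a cached count over a filtered
# query on a hypothetical `Messages` model.
async def _example_count_usage(user_id: int) -> int:
    return await QueryBuilder(Messages).filter(user_id=user_id).count()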
async def exists(self) -> bool:
"""Check whether any matching record exists
Returns:
Whether any record exists
"""
@@ -308,38 +305,38 @@ class QueryBuilder(Generic[T]):
page_size: int = 20,
) -> tuple[list[T], int]:
"""Paginated query
Args:
page: Page number, starting from 1
page_size: Number of items per page
Returns:
(list of results, total count)
"""
# Compute the offset
offset = (page - 1) * page_size
# Get the total count
total = await self.count()
# Fetch the current page of data
self._stmt = self._stmt.offset(offset).limit(page_size)
self._cache_key_parts.append(f"page:{page}:{page_size}")
items = await self.all()
return items, total
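# Illustrative usage sketch, not part of this commit: fetching page 2 of a filtered,
# ordered query; `Messages` is again a hypothetical model.
async def _example_paginate_usage():
    items, total = await (
        QueryBuilder(Messages)
        .filter(user_id=42)
        .order_by("-created_at")
        .paginate(page=2, page_size=20)
    )
    return items, total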
class AggregateQuery:
"""Aggregate query
Provides aggregate operations such as sum, avg, max, and min
"""
def __init__(self, model: Type[T]):
def __init__(self, model: type[T]):
"""Initialize the aggregate query
Args:
model: SQLAlchemy model class
"""
@@ -349,10 +346,10 @@ class AggregateQuery:
def filter(self, **conditions: Any) -> "AggregateQuery":
"""Add filter conditions
Args:
**conditions: Filter conditions
Returns:
self, for chained calls
"""
@@ -364,85 +361,85 @@ class AggregateQuery:
async def sum(self, field: str) -> float:
"""Sum of a field
Args:
field: Field name
Returns:
Sum of the field values
"""
if not hasattr(self.model, field):
raise ValueError(f"Field {field} does not exist")
async with get_db_session() as session:
stmt = select(func.sum(getattr(self.model, field)))
if self._conditions:
stmt = stmt.where(and_(*self._conditions))
result = await session.execute(stmt)
return result.scalar() or 0
async def avg(self, field: str) -> float:
"""Average of a field
Args:
field: Field name
Returns:
Average value
"""
if not hasattr(self.model, field):
raise ValueError(f"Field {field} does not exist")
async with get_db_session() as session:
stmt = select(func.avg(getattr(self.model, field)))
if self._conditions:
stmt = stmt.where(and_(*self._conditions))
result = await session.execute(stmt)
return result.scalar() or 0
async def max(self, field: str) -> Any:
"""Maximum of a field
Args:
field: Field name
Returns:
Maximum value
"""
if not hasattr(self.model, field):
raise ValueError(f"Field {field} does not exist")
async with get_db_session() as session:
stmt = select(func.max(getattr(self.model, field)))
if self._conditions:
stmt = stmt.where(and_(*self._conditions))
result = await session.execute(stmt)
return result.scalar()
async def min(self, field: str) -> Any:
"""Minimum of a field
Args:
field: Field name
Returns:
Minimum value
"""
if not hasattr(self.model, field):
raise ValueError(f"Field {field} does not exist")
async with get_db_session() as session:
stmt = select(func.min(getattr(self.model, field)))
if self._conditions:
stmt = stmt.where(and_(*self._conditions))
result = await session.execute(stmt)
return result.scalar()
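# Illustrative usage sketch, not part of this commit: aggregating over a filtered set.
# `Orders` and its `status` and `amount` columns are hypothetical.
async def _example_aggregate_usage():
    query = AggregateQuery(Orders).filter(status="paid")
    total = await query.sum("amount")
    average = await query.avg("amount")
    return total, average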
@@ -451,31 +448,31 @@ class AggregateQuery:
*fields: str,
) -> list[tuple[Any, ...]]:
"""Group-by count
Args:
*fields: Fields to group by
Returns:
[(group value 1, group value 2, ..., count), ...]
"""
if not fields:
raise ValueError("At least one group-by field is required")
group_columns = []
for field_name in fields:
if hasattr(self.model, field_name):
group_columns.append(getattr(self.model, field_name))
if not group_columns:
return []
async with get_db_session() as session:
stmt = select(*group_columns, func.count(self.model.id))
if self._conditions:
stmt = stmt.where(and_(*self._conditions))
stmt = stmt.group_by(*group_columns)
result = await session.execute(stmt)
return [tuple(row) for row in result.all()]
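# Illustrative usage sketch, not part of this commit: counting rows per group on the
# hypothetical `Orders` model; each returned tuple looks like ("paid", 120).
async def _example_group_by_count_usage():
    return await AggregateQuery(Orders).group_by_count("status")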