style: ruff自动格式化修复 - 修复180个空白行和格式问题

This commit is contained in:
Windpicker-owo
2025-11-01 17:06:40 +08:00
parent ece6a70c65
commit cabaf74194
4 changed files with 173 additions and 179 deletions

View File

@@ -6,11 +6,9 @@
- 智能预加载:关联数据自动预加载
"""
from typing import Any, Optional, Type, TypeVar
from typing import Any, TypeVar
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.inspection import inspect
from sqlalchemy import delete, func, select, update
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
@@ -19,7 +17,6 @@ from src.common.database.optimization import (
Priority,
get_batch_scheduler,
get_cache,
get_preloader,
)
from src.common.logger import get_logger
@@ -30,10 +27,10 @@ T = TypeVar("T", bound=Base)
def _model_to_dict(instance: Base) -> dict[str, Any]:
"""将 SQLAlchemy 模型实例转换为字典
Args:
instance: SQLAlchemy 模型实例
Returns:
字典表示,包含所有列的值
"""
@@ -47,13 +44,13 @@ def _model_to_dict(instance: Base) -> dict[str, Any]:
return result
def _dict_to_model(model_class: Type[T], data: dict[str, Any]) -> T:
def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
"""从字典创建 SQLAlchemy 模型实例 (detached状态)
Args:
model_class: SQLAlchemy 模型类
data: 字典数据
Returns:
模型实例 (detached, 所有字段已加载)
"""
@@ -66,13 +63,13 @@ def _dict_to_model(model_class: Type[T], data: dict[str, Any]) -> T:
class CRUDBase:
"""基础CRUD操作类
提供通用的增删改查操作,自动集成缓存和批处理
"""
def __init__(self, model: Type[T]):
def __init__(self, model: type[T]):
"""初始化CRUD操作
Args:
model: SQLAlchemy模型类
"""
@@ -83,18 +80,18 @@ class CRUDBase:
self,
id: int,
use_cache: bool = True,
) -> Optional[T]:
) -> T | None:
"""根据ID获取单条记录
Args:
id: 记录ID
use_cache: 是否使用缓存
Returns:
模型实例或None
"""
cache_key = f"{self.model_name}:id:{id}"
# 尝试从缓存获取 (缓存的是字典)
if use_cache:
cache = await get_cache()
@@ -103,13 +100,13 @@ class CRUDBase:
logger.debug(f"缓存命中: {cache_key}")
# 从字典恢复对象
return _dict_to_model(self.model, cached_dict)
# 从数据库查询
async with get_db_session() as session:
stmt = select(self.model).where(self.model.id == id)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
if instance is not None:
# 预加载所有字段
for column in self.model.__table__.columns:
@@ -117,31 +114,31 @@ class CRUDBase:
getattr(instance, column.name)
except Exception:
pass
# 转换为字典并写入缓存
if use_cache:
instance_dict = _model_to_dict(instance)
cache = await get_cache()
await cache.set(cache_key, instance_dict)
return instance
async def get_by(
self,
use_cache: bool = True,
**filters: Any,
) -> Optional[T]:
) -> T | None:
"""根据条件获取单条记录
Args:
use_cache: 是否使用缓存
**filters: 过滤条件
Returns:
模型实例或None
"""
cache_key = f"{self.model_name}:filter:{str(sorted(filters.items()))}"
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
# 尝试从缓存获取 (缓存的是字典)
if use_cache:
cache = await get_cache()
@@ -150,17 +147,17 @@ class CRUDBase:
logger.debug(f"缓存命中: {cache_key}")
# 从字典恢复对象
return _dict_to_model(self.model, cached_dict)
# 从数据库查询
async with get_db_session() as session:
stmt = select(self.model)
for key, value in filters.items():
if hasattr(self.model, key):
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
if instance is not None:
# 触发所有列的加载,避免 detached 后的延迟加载问题
# 遍历所有列属性以确保它们被加载到内存中
@@ -169,13 +166,13 @@ class CRUDBase:
getattr(instance, column.name)
except Exception:
pass # 忽略访问错误
# 转换为字典并写入缓存
if use_cache:
instance_dict = _model_to_dict(instance)
cache = await get_cache()
await cache.set(cache_key, instance_dict)
return instance
async def get_multi(
@@ -186,18 +183,18 @@ class CRUDBase:
**filters: Any,
) -> list[T]:
"""获取多条记录
Args:
skip: 跳过的记录数
limit: 返回的最大记录数
use_cache: 是否使用缓存
**filters: 过滤条件
Returns:
模型实例列表
"""
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{str(sorted(filters.items()))}"
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
# 尝试从缓存获取 (缓存的是字典列表)
if use_cache:
cache = await get_cache()
@@ -206,11 +203,11 @@ class CRUDBase:
logger.debug(f"缓存命中: {cache_key}")
# 从字典列表恢复对象列表
return [_dict_to_model(self.model, d) for d in cached_dicts]
# 从数据库查询
async with get_db_session() as session:
stmt = select(self.model)
# 应用过滤条件
for key, value in filters.items():
if hasattr(self.model, key):
@@ -218,13 +215,13 @@ class CRUDBase:
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
# 应用分页
stmt = stmt.offset(skip).limit(limit)
result = await session.execute(stmt)
instances = list(result.scalars().all())
# 触发所有实例的列加载,避免 detached 后的延迟加载问题
for instance in instances:
for column in self.model.__table__.columns:
@@ -232,13 +229,13 @@ class CRUDBase:
getattr(instance, column.name)
except Exception:
pass # 忽略访问错误
# 转换为字典列表并写入缓存
if use_cache:
instances_dicts = [_model_to_dict(inst) for inst in instances]
cache = await get_cache()
await cache.set(cache_key, instances_dicts)
return instances
async def create(
@@ -247,11 +244,11 @@ class CRUDBase:
use_batch: bool = False,
) -> T:
"""创建新记录
Args:
obj_in: 创建数据
use_batch: 是否使用批处理
Returns:
创建的模型实例
"""
@@ -266,7 +263,7 @@ class CRUDBase:
)
future = await scheduler.add_operation(operation)
await future
# 批处理返回成功,创建实例
instance = self.model(**obj_in)
return instance
@@ -284,14 +281,14 @@ class CRUDBase:
id: int,
obj_in: dict[str, Any],
use_batch: bool = False,
) -> Optional[T]:
) -> T | None:
"""更新记录
Args:
id: 记录ID
obj_in: 更新数据
use_batch: 是否使用批处理
Returns:
更新后的模型实例或None
"""
@@ -299,7 +296,7 @@ class CRUDBase:
instance = await self.get(id, use_cache=False)
if instance is None:
return None
if use_batch:
# 使用批处理
scheduler = await get_batch_scheduler()
@@ -312,7 +309,7 @@ class CRUDBase:
)
future = await scheduler.add_operation(operation)
await future
# 更新实例属性
for key, value in obj_in.items():
if hasattr(instance, key):
@@ -324,7 +321,7 @@ class CRUDBase:
stmt = select(self.model).where(self.model.id == id)
result = await session.execute(stmt)
db_instance = result.scalar_one_or_none()
if db_instance:
for key, value in obj_in.items():
if hasattr(db_instance, key):
@@ -332,12 +329,12 @@ class CRUDBase:
await session.flush()
await session.refresh(db_instance)
instance = db_instance
# 清除缓存
cache_key = f"{self.model_name}:id:{id}"
cache = await get_cache()
await cache.delete(cache_key)
return instance
async def delete(
@@ -346,11 +343,11 @@ class CRUDBase:
use_batch: bool = False,
) -> bool:
"""删除记录
Args:
id: 记录ID
use_batch: 是否使用批处理
Returns:
是否成功删除
"""
@@ -372,13 +369,13 @@ class CRUDBase:
stmt = delete(self.model).where(self.model.id == id)
result = await session.execute(stmt)
success = result.rowcount > 0
# 清除缓存
if success:
cache_key = f"{self.model_name}:id:{id}"
cache = await get_cache()
await cache.delete(cache_key)
return success
async def count(
@@ -386,16 +383,16 @@ class CRUDBase:
**filters: Any,
) -> int:
"""统计记录数
Args:
**filters: 过滤条件
Returns:
记录数量
"""
async with get_db_session() as session:
stmt = select(func.count(self.model.id))
# 应用过滤条件
for key, value in filters.items():
if hasattr(self.model, key):
@@ -403,7 +400,7 @@ class CRUDBase:
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
return result.scalar()
@@ -412,10 +409,10 @@ class CRUDBase:
**filters: Any,
) -> bool:
"""检查记录是否存在
Args:
**filters: 过滤条件
Returns:
是否存在
"""
@@ -424,15 +421,15 @@ class CRUDBase:
async def get_or_create(
self,
defaults: Optional[dict[str, Any]] = None,
defaults: dict[str, Any] | None = None,
**filters: Any,
) -> tuple[T, bool]:
"""获取或创建记录
Args:
defaults: 创建时的默认值
**filters: 查找条件
Returns:
(实例, 是否新创建)
"""
@@ -440,12 +437,12 @@ class CRUDBase:
instance = await self.get_by(use_cache=False, **filters)
if instance is not None:
return instance, False
# 创建新记录
create_data = {**filters}
if defaults:
create_data.update(defaults)
instance = await self.create(create_data)
return instance, True
@@ -454,10 +451,10 @@ class CRUDBase:
objs_in: list[dict[str, Any]],
) -> list[T]:
"""批量创建记录
Args:
objs_in: 创建数据列表
Returns:
创建的模型实例列表
"""
@@ -465,10 +462,10 @@ class CRUDBase:
instances = [self.model(**obj_data) for obj_data in objs_in]
session.add_all(instances)
await session.flush()
for instance in instances:
await session.refresh(instance)
return instances
async def bulk_update(
@@ -476,10 +473,10 @@ class CRUDBase:
updates: list[tuple[int, dict[str, Any]]],
) -> int:
"""批量更新记录
Args:
updates: (id, update_data)元组列表
Returns:
更新的记录数
"""
@@ -493,10 +490,10 @@ class CRUDBase:
)
result = await session.execute(stmt)
count += result.rowcount
# 清除缓存
cache_key = f"{self.model_name}:id:{id}"
cache = await get_cache()
await cache.delete(cache_key)
return count