style: ruff自动格式化修复 - 修复180个空白行和格式问题
This commit is contained in:
@@ -6,11 +6,9 @@
|
||||
- 智能预加载:关联数据自动预加载
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy import delete, func, select, update
|
||||
|
||||
from src.common.database.core.models import Base
|
||||
from src.common.database.core.session import get_db_session
|
||||
@@ -19,7 +17,6 @@ from src.common.database.optimization import (
|
||||
Priority,
|
||||
get_batch_scheduler,
|
||||
get_cache,
|
||||
get_preloader,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -30,10 +27,10 @@ T = TypeVar("T", bound=Base)
|
||||
|
||||
def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
"""将 SQLAlchemy 模型实例转换为字典
|
||||
|
||||
|
||||
Args:
|
||||
instance: SQLAlchemy 模型实例
|
||||
|
||||
|
||||
Returns:
|
||||
字典表示,包含所有列的值
|
||||
"""
|
||||
@@ -47,13 +44,13 @@ def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _dict_to_model(model_class: Type[T], data: dict[str, Any]) -> T:
|
||||
def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
|
||||
"""从字典创建 SQLAlchemy 模型实例 (detached状态)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy 模型类
|
||||
data: 字典数据
|
||||
|
||||
|
||||
Returns:
|
||||
模型实例 (detached, 所有字段已加载)
|
||||
"""
|
||||
@@ -66,13 +63,13 @@ def _dict_to_model(model_class: Type[T], data: dict[str, Any]) -> T:
|
||||
|
||||
class CRUDBase:
|
||||
"""基础CRUD操作类
|
||||
|
||||
|
||||
提供通用的增删改查操作,自动集成缓存和批处理
|
||||
"""
|
||||
|
||||
def __init__(self, model: Type[T]):
|
||||
def __init__(self, model: type[T]):
|
||||
"""初始化CRUD操作
|
||||
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy模型类
|
||||
"""
|
||||
@@ -83,18 +80,18 @@ class CRUDBase:
|
||||
self,
|
||||
id: int,
|
||||
use_cache: bool = True,
|
||||
) -> Optional[T]:
|
||||
) -> T | None:
|
||||
"""根据ID获取单条记录
|
||||
|
||||
|
||||
Args:
|
||||
id: 记录ID
|
||||
use_cache: 是否使用缓存
|
||||
|
||||
|
||||
Returns:
|
||||
模型实例或None
|
||||
"""
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -103,13 +100,13 @@ class CRUDBase:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典恢复对象
|
||||
return _dict_to_model(self.model, cached_dict)
|
||||
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if instance is not None:
|
||||
# 预加载所有字段
|
||||
for column in self.model.__table__.columns:
|
||||
@@ -117,31 +114,31 @@ class CRUDBase:
|
||||
getattr(instance, column.name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 转换为字典并写入缓存
|
||||
if use_cache:
|
||||
instance_dict = _model_to_dict(instance)
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
|
||||
|
||||
return instance
|
||||
|
||||
async def get_by(
|
||||
self,
|
||||
use_cache: bool = True,
|
||||
**filters: Any,
|
||||
) -> Optional[T]:
|
||||
) -> T | None:
|
||||
"""根据条件获取单条记录
|
||||
|
||||
|
||||
Args:
|
||||
use_cache: 是否使用缓存
|
||||
**filters: 过滤条件
|
||||
|
||||
|
||||
Returns:
|
||||
模型实例或None
|
||||
"""
|
||||
cache_key = f"{self.model_name}:filter:{str(sorted(filters.items()))}"
|
||||
|
||||
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -150,17 +147,17 @@ class CRUDBase:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典恢复对象
|
||||
return _dict_to_model(self.model, cached_dict)
|
||||
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if instance is not None:
|
||||
# 触发所有列的加载,避免 detached 后的延迟加载问题
|
||||
# 遍历所有列属性以确保它们被加载到内存中
|
||||
@@ -169,13 +166,13 @@ class CRUDBase:
|
||||
getattr(instance, column.name)
|
||||
except Exception:
|
||||
pass # 忽略访问错误
|
||||
|
||||
|
||||
# 转换为字典并写入缓存
|
||||
if use_cache:
|
||||
instance_dict = _model_to_dict(instance)
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
|
||||
|
||||
return instance
|
||||
|
||||
async def get_multi(
|
||||
@@ -186,18 +183,18 @@ class CRUDBase:
|
||||
**filters: Any,
|
||||
) -> list[T]:
|
||||
"""获取多条记录
|
||||
|
||||
|
||||
Args:
|
||||
skip: 跳过的记录数
|
||||
limit: 返回的最大记录数
|
||||
use_cache: 是否使用缓存
|
||||
**filters: 过滤条件
|
||||
|
||||
|
||||
Returns:
|
||||
模型实例列表
|
||||
"""
|
||||
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{str(sorted(filters.items()))}"
|
||||
|
||||
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -206,11 +203,11 @@ class CRUDBase:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典列表恢复对象列表
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts]
|
||||
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
@@ -218,13 +215,13 @@ class CRUDBase:
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
|
||||
# 应用分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
|
||||
|
||||
# 触发所有实例的列加载,避免 detached 后的延迟加载问题
|
||||
for instance in instances:
|
||||
for column in self.model.__table__.columns:
|
||||
@@ -232,13 +229,13 @@ class CRUDBase:
|
||||
getattr(instance, column.name)
|
||||
except Exception:
|
||||
pass # 忽略访问错误
|
||||
|
||||
|
||||
# 转换为字典列表并写入缓存
|
||||
if use_cache:
|
||||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
|
||||
|
||||
return instances
|
||||
|
||||
async def create(
|
||||
@@ -247,11 +244,11 @@ class CRUDBase:
|
||||
use_batch: bool = False,
|
||||
) -> T:
|
||||
"""创建新记录
|
||||
|
||||
|
||||
Args:
|
||||
obj_in: 创建数据
|
||||
use_batch: 是否使用批处理
|
||||
|
||||
|
||||
Returns:
|
||||
创建的模型实例
|
||||
"""
|
||||
@@ -266,7 +263,7 @@ class CRUDBase:
|
||||
)
|
||||
future = await scheduler.add_operation(operation)
|
||||
await future
|
||||
|
||||
|
||||
# 批处理返回成功,创建实例
|
||||
instance = self.model(**obj_in)
|
||||
return instance
|
||||
@@ -284,14 +281,14 @@ class CRUDBase:
|
||||
id: int,
|
||||
obj_in: dict[str, Any],
|
||||
use_batch: bool = False,
|
||||
) -> Optional[T]:
|
||||
) -> T | None:
|
||||
"""更新记录
|
||||
|
||||
|
||||
Args:
|
||||
id: 记录ID
|
||||
obj_in: 更新数据
|
||||
use_batch: 是否使用批处理
|
||||
|
||||
|
||||
Returns:
|
||||
更新后的模型实例或None
|
||||
"""
|
||||
@@ -299,7 +296,7 @@ class CRUDBase:
|
||||
instance = await self.get(id, use_cache=False)
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
|
||||
if use_batch:
|
||||
# 使用批处理
|
||||
scheduler = await get_batch_scheduler()
|
||||
@@ -312,7 +309,7 @@ class CRUDBase:
|
||||
)
|
||||
future = await scheduler.add_operation(operation)
|
||||
await future
|
||||
|
||||
|
||||
# 更新实例属性
|
||||
for key, value in obj_in.items():
|
||||
if hasattr(instance, key):
|
||||
@@ -324,7 +321,7 @@ class CRUDBase:
|
||||
stmt = select(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
db_instance = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if db_instance:
|
||||
for key, value in obj_in.items():
|
||||
if hasattr(db_instance, key):
|
||||
@@ -332,12 +329,12 @@ class CRUDBase:
|
||||
await session.flush()
|
||||
await session.refresh(db_instance)
|
||||
instance = db_instance
|
||||
|
||||
|
||||
# 清除缓存
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
cache = await get_cache()
|
||||
await cache.delete(cache_key)
|
||||
|
||||
|
||||
return instance
|
||||
|
||||
async def delete(
|
||||
@@ -346,11 +343,11 @@ class CRUDBase:
|
||||
use_batch: bool = False,
|
||||
) -> bool:
|
||||
"""删除记录
|
||||
|
||||
|
||||
Args:
|
||||
id: 记录ID
|
||||
use_batch: 是否使用批处理
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
@@ -372,13 +369,13 @@ class CRUDBase:
|
||||
stmt = delete(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
success = result.rowcount > 0
|
||||
|
||||
|
||||
# 清除缓存
|
||||
if success:
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
cache = await get_cache()
|
||||
await cache.delete(cache_key)
|
||||
|
||||
|
||||
return success
|
||||
|
||||
async def count(
|
||||
@@ -386,16 +383,16 @@ class CRUDBase:
|
||||
**filters: Any,
|
||||
) -> int:
|
||||
"""统计记录数
|
||||
|
||||
|
||||
Args:
|
||||
**filters: 过滤条件
|
||||
|
||||
|
||||
Returns:
|
||||
记录数量
|
||||
"""
|
||||
async with get_db_session() as session:
|
||||
stmt = select(func.count(self.model.id))
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
@@ -403,7 +400,7 @@ class CRUDBase:
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar()
|
||||
|
||||
@@ -412,10 +409,10 @@ class CRUDBase:
|
||||
**filters: Any,
|
||||
) -> bool:
|
||||
"""检查记录是否存在
|
||||
|
||||
|
||||
Args:
|
||||
**filters: 过滤条件
|
||||
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
@@ -424,15 +421,15 @@ class CRUDBase:
|
||||
|
||||
async def get_or_create(
|
||||
self,
|
||||
defaults: Optional[dict[str, Any]] = None,
|
||||
defaults: dict[str, Any] | None = None,
|
||||
**filters: Any,
|
||||
) -> tuple[T, bool]:
|
||||
"""获取或创建记录
|
||||
|
||||
|
||||
Args:
|
||||
defaults: 创建时的默认值
|
||||
**filters: 查找条件
|
||||
|
||||
|
||||
Returns:
|
||||
(实例, 是否新创建)
|
||||
"""
|
||||
@@ -440,12 +437,12 @@ class CRUDBase:
|
||||
instance = await self.get_by(use_cache=False, **filters)
|
||||
if instance is not None:
|
||||
return instance, False
|
||||
|
||||
|
||||
# 创建新记录
|
||||
create_data = {**filters}
|
||||
if defaults:
|
||||
create_data.update(defaults)
|
||||
|
||||
|
||||
instance = await self.create(create_data)
|
||||
return instance, True
|
||||
|
||||
@@ -454,10 +451,10 @@ class CRUDBase:
|
||||
objs_in: list[dict[str, Any]],
|
||||
) -> list[T]:
|
||||
"""批量创建记录
|
||||
|
||||
|
||||
Args:
|
||||
objs_in: 创建数据列表
|
||||
|
||||
|
||||
Returns:
|
||||
创建的模型实例列表
|
||||
"""
|
||||
@@ -465,10 +462,10 @@ class CRUDBase:
|
||||
instances = [self.model(**obj_data) for obj_data in objs_in]
|
||||
session.add_all(instances)
|
||||
await session.flush()
|
||||
|
||||
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
|
||||
|
||||
return instances
|
||||
|
||||
async def bulk_update(
|
||||
@@ -476,10 +473,10 @@ class CRUDBase:
|
||||
updates: list[tuple[int, dict[str, Any]]],
|
||||
) -> int:
|
||||
"""批量更新记录
|
||||
|
||||
|
||||
Args:
|
||||
updates: (id, update_data)元组列表
|
||||
|
||||
|
||||
Returns:
|
||||
更新的记录数
|
||||
"""
|
||||
@@ -493,10 +490,10 @@ class CRUDBase:
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
count += result.rowcount
|
||||
|
||||
|
||||
# 清除缓存
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
cache = await get_cache()
|
||||
await cache.delete(cache_key)
|
||||
|
||||
|
||||
return count
|
||||
|
||||
@@ -7,19 +7,16 @@
|
||||
- 关联查询
|
||||
"""
|
||||
|
||||
from typing import Any, Generic, Optional, Sequence, Type, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import and_, asc, desc, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.engine import Row
|
||||
|
||||
from src.common.database.core.models import Base
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.database.optimization import get_cache, get_preloader
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入 CRUD 辅助函数以避免重复定义
|
||||
from src.common.database.api.crud import _dict_to_model, _model_to_dict
|
||||
from src.common.database.core.models import Base
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.database.optimization import get_cache
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.query")
|
||||
|
||||
@@ -28,13 +25,13 @@ T = TypeVar("T", bound="Base")
|
||||
|
||||
class QueryBuilder(Generic[T]):
|
||||
"""查询构建器
|
||||
|
||||
|
||||
支持链式调用,构建复杂查询
|
||||
"""
|
||||
|
||||
def __init__(self, model: Type[T]):
|
||||
def __init__(self, model: type[T]):
|
||||
"""初始化查询构建器
|
||||
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy模型类
|
||||
"""
|
||||
@@ -46,7 +43,7 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
def filter(self, **conditions: Any) -> "QueryBuilder":
|
||||
"""添加过滤条件
|
||||
|
||||
|
||||
支持的操作符:
|
||||
- 直接相等: field=value
|
||||
- 大于: field__gt=value
|
||||
@@ -58,10 +55,10 @@ class QueryBuilder(Generic[T]):
|
||||
- 不包含: field__nin=[values]
|
||||
- 模糊匹配: field__like='%pattern%'
|
||||
- 为空: field__isnull=True
|
||||
|
||||
|
||||
Args:
|
||||
**conditions: 过滤条件
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -71,13 +68,13 @@ class QueryBuilder(Generic[T]):
|
||||
field_name, operator = key.rsplit("__", 1)
|
||||
else:
|
||||
field_name, operator = key, "eq"
|
||||
|
||||
|
||||
if not hasattr(self.model, field_name):
|
||||
logger.warning(f"模型 {self.model_name} 没有字段 {field_name}")
|
||||
continue
|
||||
|
||||
|
||||
field = getattr(self.model, field_name)
|
||||
|
||||
|
||||
# 应用操作符
|
||||
if operator == "eq":
|
||||
self._stmt = self._stmt.where(field == value)
|
||||
@@ -104,17 +101,17 @@ class QueryBuilder(Generic[T]):
|
||||
self._stmt = self._stmt.where(field.isnot(None))
|
||||
else:
|
||||
logger.warning(f"未知操作符: {operator}")
|
||||
|
||||
|
||||
# 更新缓存键
|
||||
self._cache_key_parts.append(f"filter:{str(sorted(conditions.items()))}")
|
||||
self._cache_key_parts.append(f"filter:{sorted(conditions.items())!s}")
|
||||
return self
|
||||
|
||||
def filter_or(self, **conditions: Any) -> "QueryBuilder":
|
||||
"""添加OR过滤条件
|
||||
|
||||
|
||||
Args:
|
||||
**conditions: OR条件
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -123,19 +120,19 @@ class QueryBuilder(Generic[T]):
|
||||
if hasattr(self.model, key):
|
||||
field = getattr(self.model, key)
|
||||
or_conditions.append(field == value)
|
||||
|
||||
|
||||
if or_conditions:
|
||||
self._stmt = self._stmt.where(or_(*or_conditions))
|
||||
self._cache_key_parts.append(f"or:{str(sorted(conditions.items()))}")
|
||||
|
||||
self._cache_key_parts.append(f"or:{sorted(conditions.items())!s}")
|
||||
|
||||
return self
|
||||
|
||||
def order_by(self, *fields: str) -> "QueryBuilder":
|
||||
"""添加排序
|
||||
|
||||
|
||||
Args:
|
||||
*fields: 排序字段,'-'前缀表示降序
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -147,16 +144,16 @@ class QueryBuilder(Generic[T]):
|
||||
else:
|
||||
if hasattr(self.model, field_name):
|
||||
self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name)))
|
||||
|
||||
|
||||
self._cache_key_parts.append(f"order:{','.join(fields)}")
|
||||
return self
|
||||
|
||||
def limit(self, limit: int) -> "QueryBuilder":
|
||||
"""限制结果数量
|
||||
|
||||
|
||||
Args:
|
||||
limit: 最大数量
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -166,10 +163,10 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
def offset(self, offset: int) -> "QueryBuilder":
|
||||
"""跳过指定数量
|
||||
|
||||
|
||||
Args:
|
||||
offset: 跳过数量
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -179,7 +176,7 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
def no_cache(self) -> "QueryBuilder":
|
||||
"""禁用缓存
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -188,12 +185,12 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
async def all(self) -> list[T]:
|
||||
"""获取所有结果
|
||||
|
||||
|
||||
Returns:
|
||||
模型实例列表
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":all"
|
||||
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -202,12 +199,12 @@ class QueryBuilder(Generic[T]):
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典列表恢复对象列表
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts]
|
||||
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(self._stmt)
|
||||
instances = list(result.scalars().all())
|
||||
|
||||
|
||||
# 预加载所有列以避免detached对象的lazy loading问题
|
||||
for instance in instances:
|
||||
for column in self.model.__table__.columns:
|
||||
@@ -215,23 +212,23 @@ class QueryBuilder(Generic[T]):
|
||||
getattr(instance, column.name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 转换为字典列表并写入缓存
|
||||
if self._use_cache:
|
||||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
|
||||
|
||||
return instances
|
||||
|
||||
async def first(self) -> Optional[T]:
|
||||
async def first(self) -> T | None:
|
||||
"""获取第一个结果
|
||||
|
||||
|
||||
Returns:
|
||||
模型实例或None
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":first"
|
||||
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -240,12 +237,12 @@ class QueryBuilder(Generic[T]):
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典恢复对象
|
||||
return _dict_to_model(self.model, cached_dict)
|
||||
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(self._stmt)
|
||||
instance = result.scalars().first()
|
||||
|
||||
|
||||
# 预加载所有列以避免detached对象的lazy loading问题
|
||||
if instance is not None:
|
||||
for column in self.model.__table__.columns:
|
||||
@@ -253,23 +250,23 @@ class QueryBuilder(Generic[T]):
|
||||
getattr(instance, column.name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 转换为字典并写入缓存
|
||||
if instance is not None and self._use_cache:
|
||||
instance_dict = _model_to_dict(instance)
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
|
||||
|
||||
return instance
|
||||
|
||||
async def count(self) -> int:
|
||||
"""统计数量
|
||||
|
||||
|
||||
Returns:
|
||||
记录数量
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":count"
|
||||
|
||||
|
||||
# 尝试从缓存获取
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -277,25 +274,25 @@ class QueryBuilder(Generic[T]):
|
||||
if cached is not None:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
return cached
|
||||
|
||||
|
||||
# 构建count查询
|
||||
count_stmt = select(func.count()).select_from(self._stmt.subquery())
|
||||
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(count_stmt)
|
||||
count = result.scalar() or 0
|
||||
|
||||
|
||||
# 写入缓存
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, count)
|
||||
|
||||
|
||||
return count
|
||||
|
||||
async def exists(self) -> bool:
|
||||
"""检查是否存在
|
||||
|
||||
|
||||
Returns:
|
||||
是否存在记录
|
||||
"""
|
||||
@@ -308,38 +305,38 @@ class QueryBuilder(Generic[T]):
|
||||
page_size: int = 20,
|
||||
) -> tuple[list[T], int]:
|
||||
"""分页查询
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码(从1开始)
|
||||
page_size: 每页数量
|
||||
|
||||
|
||||
Returns:
|
||||
(结果列表, 总数量)
|
||||
"""
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
|
||||
# 获取总数
|
||||
total = await self.count()
|
||||
|
||||
|
||||
# 获取当前页数据
|
||||
self._stmt = self._stmt.offset(offset).limit(page_size)
|
||||
self._cache_key_parts.append(f"page:{page}:{page_size}")
|
||||
|
||||
|
||||
items = await self.all()
|
||||
|
||||
|
||||
return items, total
|
||||
|
||||
|
||||
class AggregateQuery:
|
||||
"""聚合查询
|
||||
|
||||
|
||||
提供聚合操作如sum、avg、max、min等
|
||||
"""
|
||||
|
||||
def __init__(self, model: Type[T]):
|
||||
def __init__(self, model: type[T]):
|
||||
"""初始化聚合查询
|
||||
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy模型类
|
||||
"""
|
||||
@@ -349,10 +346,10 @@ class AggregateQuery:
|
||||
|
||||
def filter(self, **conditions: Any) -> "AggregateQuery":
|
||||
"""添加过滤条件
|
||||
|
||||
|
||||
Args:
|
||||
**conditions: 过滤条件
|
||||
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
@@ -364,85 +361,85 @@ class AggregateQuery:
|
||||
|
||||
async def sum(self, field: str) -> float:
|
||||
"""求和
|
||||
|
||||
|
||||
Args:
|
||||
field: 字段名
|
||||
|
||||
|
||||
Returns:
|
||||
总和
|
||||
"""
|
||||
if not hasattr(self.model, field):
|
||||
raise ValueError(f"字段 {field} 不存在")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(func.sum(getattr(self.model, field)))
|
||||
|
||||
|
||||
if self._conditions:
|
||||
stmt = stmt.where(and_(*self._conditions))
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def avg(self, field: str) -> float:
|
||||
"""求平均值
|
||||
|
||||
|
||||
Args:
|
||||
field: 字段名
|
||||
|
||||
|
||||
Returns:
|
||||
平均值
|
||||
"""
|
||||
if not hasattr(self.model, field):
|
||||
raise ValueError(f"字段 {field} 不存在")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(func.avg(getattr(self.model, field)))
|
||||
|
||||
|
||||
if self._conditions:
|
||||
stmt = stmt.where(and_(*self._conditions))
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def max(self, field: str) -> Any:
|
||||
"""求最大值
|
||||
|
||||
|
||||
Args:
|
||||
field: 字段名
|
||||
|
||||
|
||||
Returns:
|
||||
最大值
|
||||
"""
|
||||
if not hasattr(self.model, field):
|
||||
raise ValueError(f"字段 {field} 不存在")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(func.max(getattr(self.model, field)))
|
||||
|
||||
|
||||
if self._conditions:
|
||||
stmt = stmt.where(and_(*self._conditions))
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar()
|
||||
|
||||
async def min(self, field: str) -> Any:
|
||||
"""求最小值
|
||||
|
||||
|
||||
Args:
|
||||
field: 字段名
|
||||
|
||||
|
||||
Returns:
|
||||
最小值
|
||||
"""
|
||||
if not hasattr(self.model, field):
|
||||
raise ValueError(f"字段 {field} 不存在")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(func.min(getattr(self.model, field)))
|
||||
|
||||
|
||||
if self._conditions:
|
||||
stmt = stmt.where(and_(*self._conditions))
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar()
|
||||
|
||||
@@ -451,31 +448,31 @@ class AggregateQuery:
|
||||
*fields: str,
|
||||
) -> list[tuple[Any, ...]]:
|
||||
"""分组统计
|
||||
|
||||
|
||||
Args:
|
||||
*fields: 分组字段
|
||||
|
||||
|
||||
Returns:
|
||||
[(分组值1, 分组值2, ..., 数量), ...]
|
||||
"""
|
||||
if not fields:
|
||||
raise ValueError("至少需要一个分组字段")
|
||||
|
||||
|
||||
group_columns = []
|
||||
for field_name in fields:
|
||||
if hasattr(self.model, field_name):
|
||||
group_columns.append(getattr(self.model, field_name))
|
||||
|
||||
|
||||
if not group_columns:
|
||||
return []
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(*group_columns, func.count(self.model.id))
|
||||
|
||||
|
||||
if self._conditions:
|
||||
stmt = stmt.where(and_(*self._conditions))
|
||||
|
||||
|
||||
stmt = stmt.group_by(*group_columns)
|
||||
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return [tuple(row) for row in result.all()]
|
||||
|
||||
@@ -9,9 +9,9 @@ import orjson
|
||||
from json_repair import repair_json
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import PersonInfo
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -637,18 +637,18 @@ class PersonInfoManager:
|
||||
crud = CRUDBase(PersonInfo)
|
||||
record = await crud.get_by(person_id=p_id)
|
||||
query_time = time.time()
|
||||
|
||||
|
||||
if record:
|
||||
# 更新记录
|
||||
await crud.update(record.id, {f_name: val_to_set})
|
||||
save_time = time.time()
|
||||
total_time = save_time - start_time
|
||||
|
||||
|
||||
if total_time > 0.5:
|
||||
logger.warning(
|
||||
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
|
||||
)
|
||||
|
||||
|
||||
# 使缓存失效
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
@@ -657,7 +657,7 @@ class PersonInfoManager:
|
||||
await cache.delete(generate_cache_key("person_value", p_id, f_name))
|
||||
await cache.delete(generate_cache_key("person_values", p_id))
|
||||
await cache.delete(generate_cache_key("person_has_field", p_id, f_name))
|
||||
|
||||
|
||||
return True, False
|
||||
else:
|
||||
total_time = time.time() - start_time
|
||||
@@ -669,7 +669,7 @@ class PersonInfoManager:
|
||||
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
|
||||
raise
|
||||
|
||||
found, needs_creation = await _db_update_async(person_id, field_name, processed_value)
|
||||
_found, needs_creation = await _db_update_async(person_id, field_name, processed_value)
|
||||
|
||||
if needs_creation:
|
||||
logger.info(f"{person_id} 不存在,将新建。")
|
||||
@@ -872,7 +872,7 @@ class PersonInfoManager:
|
||||
record = await crud.get_by(person_id=p_id)
|
||||
if record:
|
||||
await crud.delete(record.id)
|
||||
|
||||
|
||||
# 注意: 删除操作很少发生,缓存会在TTL过期后自动清除
|
||||
# 无法从person_id反向得到platform和user_id,因此无法精确清除缓存
|
||||
# 删除后的查询仍会返回正确结果(None/False)
|
||||
@@ -992,7 +992,7 @@ class PersonInfoManager:
|
||||
try:
|
||||
value = getattr(record, f_name, None)
|
||||
if value is not None and way(value):
|
||||
person_id_value = getattr(record, 'person_id', None)
|
||||
person_id_value = getattr(record, "person_id", None)
|
||||
if person_id_value:
|
||||
found_results[person_id_value] = value
|
||||
except Exception as e:
|
||||
@@ -1024,7 +1024,7 @@ class PersonInfoManager:
|
||||
"""原子性的获取或创建操作"""
|
||||
# 使用CRUD进行获取或创建
|
||||
crud = CRUDBase(PersonInfo)
|
||||
|
||||
|
||||
# 首先尝试获取现有记录
|
||||
record = await crud.get_by(person_id=p_id)
|
||||
if record:
|
||||
@@ -1070,7 +1070,7 @@ class PersonInfoManager:
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data)
|
||||
_record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data)
|
||||
|
||||
if was_created:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
|
||||
|
||||
@@ -186,7 +186,7 @@ class RelationshipFetcher:
|
||||
# 查询用户关系数据
|
||||
user_id = str(await person_info_manager.get_value(person_id, "user_id"))
|
||||
platform = str(await person_info_manager.get_value(person_id, "platform"))
|
||||
|
||||
|
||||
# 使用优化后的API(带缓存)
|
||||
relationship = await get_user_relationship(
|
||||
platform=platform,
|
||||
@@ -261,7 +261,7 @@ class RelationshipFetcher:
|
||||
# 使用优化后的API(带缓存)
|
||||
# 从stream_id解析platform,或使用默认值
|
||||
platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"
|
||||
|
||||
|
||||
stream, _ = await get_or_create_chat_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
|
||||
Reference in New Issue
Block a user