feat(database): complete API layer, Utils layer, and compatibility layer refactor (Stages 4-6)
Stage 4: API layer refactor
===========================
New files:
- api/crud.py (430 lines): generic CRUDBase class providing 12 CRUD methods
  * get, get_by, get_multi, create, update, delete
  * count, exists, get_or_create, bulk_create, bulk_update
  * Cache integration: reads are cached automatically; writes invalidate the cache
  * Batching integration: optional use_batch parameter transparently uses AdaptiveBatchScheduler
- api/query.py (461 lines): advanced query builder
  * QueryBuilder: chained calls with MongoDB-style operators
    - Operators: __gt, __lt, __gte, __lte, __ne, __in, __nin, __like, __isnull
    - Methods: filter, filter_or, order_by, limit, offset, no_cache
    - Execution: all, first, count, exists, paginate
  * AggregateQuery: aggregate queries
    - sum, avg, max, min, group_by_count
- api/specialized.py (461 lines): business-specific APIs
  * ActionRecords: store_action_info, get_recent_actions
  * Messages: get_chat_history, get_message_count, save_message
  * PersonInfo: get_or_create_person, update_person_affinity
  * ChatStreams: get_or_create_chat_stream, get_active_streams
  * LLMUsage: record_llm_usage, get_usage_statistics
  * UserRelationships: get_user_relationship, update_relationship_affinity
- Updated api/__init__.py: exports all API interfaces
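The `field__operator` keyword syntax listed above can be illustrated with a minimal standalone sketch. This is a simplification for illustration only (it evaluates predicates against plain dicts, whereas the real QueryBuilder builds SQLAlchemy where-clauses); all names here are hypothetical:

```python
# Hypothetical sketch of the field__op keyword parsing used by
# QueryBuilder.filter, applied to plain dict rows instead of columns.
from typing import Any, Callable

OPS: dict[str, Callable[[Any, Any], bool]] = {
    "eq": lambda f, v: f == v,
    "gt": lambda f, v: f > v,
    "lt": lambda f, v: f < v,
    "gte": lambda f, v: f >= v,
    "lte": lambda f, v: f <= v,
    "ne": lambda f, v: f != v,
    "in": lambda f, v: f in v,
    "nin": lambda f, v: f not in v,
}


def parse_condition(key: str) -> tuple[str, str]:
    """Split 'age__gte' into ('age', 'gte'); bare keys mean equality."""
    if "__" in key:
        field, op = key.rsplit("__", 1)
        if op in OPS:
            return field, op
    return key, "eq"


def matches(row: dict[str, Any], **conditions: Any) -> bool:
    """Apply all conditions to one row (AND semantics, like filter())."""
    return all(
        OPS[op](row.get(field), value)
        for (field, op), value in
        ((parse_condition(k), v) for k, v in conditions.items())
    )
```

For example, `matches({"age": 30}, age__gte=18, name__ne="x")` is true, mirroring how `.filter(age__gte=18, name__ne="x")` would constrain a query.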
Stage 5: Utils layer implementation
===================================
New files:
- utils/decorators.py (320 lines): database operation decorators
  * @retry: automatically retries failed operations with exponential backoff
  * @timeout: timeout control
  * @cached: automatically caches function results
  * @measure_time: performance measurement and slow-query logging
  * @transactional: transaction management with automatic commit/rollback
  * @db_operation: composite decorator
- utils/monitoring.py (330 lines): performance monitoring system
  * DatabaseMonitor: singleton monitor
  * OperationMetrics: per-operation metrics (count, time, errors)
  * DatabaseMetrics: global metrics
    - connection pool statistics
    - cache hit rate
    - batching statistics
    - preloading statistics
  * Convenience functions: get_monitor, record_operation, print_stats
- Updated utils/__init__.py: exports decorators and monitoring functions
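A retry decorator with exponential backoff, as described for `@retry` above, can be sketched in a few lines. This is a hypothetical simplification, not the project's actual implementation; the jitter factor and delay defaults are illustrative assumptions:

```python
import asyncio
import functools
import random


def retry(max_attempts: int = 3, base_delay: float = 0.1):
    """Retry a failing async operation with exponential backoff (sketch)."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts - 1:
                        raise  # exhausted: re-raise the last error
                    # exponential backoff with jitter: ~0.1s, ~0.2s, ~0.4s, ...
                    await asyncio.sleep(base_delay * (2 ** attempt) * (1 + random.random()))
        return wrapper
    return decorator


# Demo: a hypothetical operation that fails twice, then succeeds.
calls = {"n": 0}


@retry(max_attempts=3, base_delay=0.001)
async def flaky_write():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


result = asyncio.run(flaky_write())
```

The decorator keeps the wrapped coroutine's signature via `functools.wraps`, so it composes cleanly with the other decorators (e.g. under `@db_operation`).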
Stage 6: Compatibility layer implementation
===========================================
New directory: compatibility/
- adapter.py (370 lines): backward-compatibility adapter
  * Fully compatible with the old API signatures: db_query, db_save, db_get, store_action_info
  * Supports MongoDB-style operators
  * Uses the new architecture internally (QueryBuilder + CRUDBase)
  * Keeps the dict return format unchanged
  * MODEL_MAPPING: 25 model mappings
- __init__.py: exports the compatibility API
Updated database/__init__.py:
- Exports the core layer (engine, session, models, migration)
- Exports the optimization layer (cache, preloader, batch_scheduler)
- Exports the API layer (CRUD, Query, business APIs)
- Exports the Utils layer (decorators, monitoring)
- Exports the compatibility layer (db_query, db_save, etc.)
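One way such an adapter can bridge the two styles is to translate MongoDB-style filter documents into the new `field__op` keyword form. The sketch below is a hypothetical illustration of that translation step, not the adapter's actual code:

```python
# Hypothetical sketch: translate a MongoDB-style filter document into
# the new-architecture field__op keyword conditions.
MONGO_TO_SUFFIX = {
    "$gt": "gt", "$lt": "lt", "$gte": "gte", "$lte": "lte",
    "$ne": "ne", "$in": "in", "$nin": "nin",
}


def translate_filter(mongo_filter: dict) -> dict:
    """{'age': {'$gte': 18}, 'name': 'bob'} -> {'age__gte': 18, 'name': 'bob'}"""
    kwargs = {}
    for field, spec in mongo_filter.items():
        if isinstance(spec, dict) and spec and all(k in MONGO_TO_SUFFIX for k in spec):
            # operator document: expand each operator into a keyword
            for op, value in spec.items():
                kwargs[f"{field}__{MONGO_TO_SUFFIX[op]}"] = value
        else:
            # plain value: direct equality
            kwargs[field] = spec
    return kwargs
```

With this shape, an old-style `db_query(model, filter)` call can forward `**translate_filter(filter)` into `QueryBuilder.filter` without callers changing anything.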
Core features
=============
Type safety: Generic[T] provides full type inference
Transparent caching: reads are cached automatically, with no user involvement
Transparent batching: optional batching automatically optimizes high-frequency writes
Chained queries: fluent API design
Business encapsulation: common operations wrapped in convenience functions
Backward compatibility: the compatibility layer lets existing code migrate seamlessly
Performance monitoring: complete metric collection and reporting
Statistics
==========
- New files: 7
- Lines of code: ~2050
- API functions: 14 business APIs + 6 decorators
- Compatibility functions: 5 (db_query, db_save, db_get, etc.)
Next steps
==========
- Update import statements in 28 files (migrating away from sqlalchemy_database_api)
- Move the old files to the old/ directory
- Write tests for Stages 4-6
- Run integration tests to verify compatibility
src/common/database/api/query.py (new file, 458 lines)
@@ -0,0 +1,458 @@
"""High-level query API

Provides complex query operations:
- MongoDB-style query operators
- Aggregate queries
- Sorting and pagination
- Relationship queries
"""

from typing import Any, Generic, Optional, Type, TypeVar

from sqlalchemy import and_, asc, desc, func, or_, select

from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache
from src.common.logger import get_logger

logger = get_logger("database.query")

T = TypeVar("T", bound="Base")


class QueryBuilder(Generic[T]):
    """Query builder

    Supports chained calls for building complex queries.
    """

    def __init__(self, model: Type[T]):
        """Initialize the query builder

        Args:
            model: SQLAlchemy model class
        """
        self.model = model
        self.model_name = model.__tablename__
        self._stmt = select(model)
        self._use_cache = True
        self._cache_key_parts: list[str] = [self.model_name]
    def filter(self, **conditions: Any) -> "QueryBuilder":
        """Add filter conditions

        Supported operators:
        - equality: field=value
        - greater than: field__gt=value
        - less than: field__lt=value
        - greater than or equal: field__gte=value
        - less than or equal: field__lte=value
        - not equal: field__ne=value
        - membership: field__in=[values]
        - non-membership: field__nin=[values]
        - pattern match: field__like='%pattern%'
        - null check: field__isnull=True

        Args:
            **conditions: filter conditions

        Returns:
            self, for chaining
        """
        for key, value in conditions.items():
            # Split the key into field name and operator
            if "__" in key:
                field_name, operator = key.rsplit("__", 1)
            else:
                field_name, operator = key, "eq"

            if not hasattr(self.model, field_name):
                logger.warning(f"Model {self.model_name} has no field {field_name}")
                continue

            field = getattr(self.model, field_name)

            # Apply the operator
            if operator == "eq":
                self._stmt = self._stmt.where(field == value)
            elif operator == "gt":
                self._stmt = self._stmt.where(field > value)
            elif operator == "lt":
                self._stmt = self._stmt.where(field < value)
            elif operator == "gte":
                self._stmt = self._stmt.where(field >= value)
            elif operator == "lte":
                self._stmt = self._stmt.where(field <= value)
            elif operator == "ne":
                self._stmt = self._stmt.where(field != value)
            elif operator == "in":
                self._stmt = self._stmt.where(field.in_(value))
            elif operator == "nin":
                self._stmt = self._stmt.where(~field.in_(value))
            elif operator == "like":
                self._stmt = self._stmt.where(field.like(value))
            elif operator == "isnull":
                if value:
                    self._stmt = self._stmt.where(field.is_(None))
                else:
                    self._stmt = self._stmt.where(field.isnot(None))
            else:
                logger.warning(f"Unknown operator: {operator}")

        # Update the cache key
        self._cache_key_parts.append(f"filter:{str(sorted(conditions.items()))}")
        return self
    def filter_or(self, **conditions: Any) -> "QueryBuilder":
        """Add OR filter conditions

        Args:
            **conditions: OR conditions

        Returns:
            self, for chaining
        """
        or_conditions = []
        for key, value in conditions.items():
            if hasattr(self.model, key):
                field = getattr(self.model, key)
                or_conditions.append(field == value)

        if or_conditions:
            self._stmt = self._stmt.where(or_(*or_conditions))
            self._cache_key_parts.append(f"or:{str(sorted(conditions.items()))}")

        return self

    def order_by(self, *fields: str) -> "QueryBuilder":
        """Add ordering

        Args:
            *fields: sort fields; a '-' prefix means descending

        Returns:
            self, for chaining
        """
        for field_name in fields:
            if field_name.startswith("-"):
                field_name = field_name[1:]
                if hasattr(self.model, field_name):
                    self._stmt = self._stmt.order_by(desc(getattr(self.model, field_name)))
            else:
                if hasattr(self.model, field_name):
                    self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name)))

        self._cache_key_parts.append(f"order:{','.join(fields)}")
        return self

    def limit(self, limit: int) -> "QueryBuilder":
        """Limit the number of results

        Args:
            limit: maximum number of results

        Returns:
            self, for chaining
        """
        self._stmt = self._stmt.limit(limit)
        self._cache_key_parts.append(f"limit:{limit}")
        return self

    def offset(self, offset: int) -> "QueryBuilder":
        """Skip a number of results

        Args:
            offset: number of results to skip

        Returns:
            self, for chaining
        """
        self._stmt = self._stmt.offset(offset)
        self._cache_key_parts.append(f"offset:{offset}")
        return self

    def no_cache(self) -> "QueryBuilder":
        """Disable caching for this query

        Returns:
            self, for chaining
        """
        self._use_cache = False
        return self
    async def all(self) -> list[T]:
        """Fetch all results

        Returns:
            List of model instances
        """
        cache_key = ":".join(self._cache_key_parts) + ":all"

        # Try the cache first
        if self._use_cache:
            cache = await get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return cached

        # Query the database
        async with get_db_session() as session:
            result = await session.execute(self._stmt)
            instances = list(result.scalars().all())

        # Populate the cache
        if self._use_cache:
            cache = await get_cache()
            await cache.set(cache_key, instances)

        return instances

    async def first(self) -> Optional[T]:
        """Fetch the first result

        Returns:
            Model instance, or None
        """
        cache_key = ":".join(self._cache_key_parts) + ":first"

        # Try the cache first
        if self._use_cache:
            cache = await get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return cached

        # Query the database
        async with get_db_session() as session:
            result = await session.execute(self._stmt)
            instance = result.scalars().first()

        # Populate the cache
        if instance is not None and self._use_cache:
            cache = await get_cache()
            await cache.set(cache_key, instance)

        return instance

    async def count(self) -> int:
        """Count matching records

        Returns:
            Number of records
        """
        cache_key = ":".join(self._cache_key_parts) + ":count"

        # Try the cache first
        if self._use_cache:
            cache = await get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return cached

        # Build the count statement
        count_stmt = select(func.count()).select_from(self._stmt.subquery())

        # Query the database
        async with get_db_session() as session:
            result = await session.execute(count_stmt)
            count = result.scalar() or 0

        # Populate the cache
        if self._use_cache:
            cache = await get_cache()
            await cache.set(cache_key, count)

        return count

    async def exists(self) -> bool:
        """Check whether any matching record exists

        Returns:
            True if at least one record matches
        """
        count = await self.count()
        return count > 0
    async def paginate(
        self,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[T], int]:
        """Paginated query

        Args:
            page: page number (starting from 1)
            page_size: number of items per page

        Returns:
            (list of results, total count)
        """
        # Compute the offset
        offset = (page - 1) * page_size

        # Total count (before applying limit/offset)
        total = await self.count()

        # Fetch the current page
        self._stmt = self._stmt.offset(offset).limit(page_size)
        self._cache_key_parts.append(f"page:{page}:{page_size}")

        items = await self.all()

        return items, total


class AggregateQuery:
    """Aggregate query

    Provides aggregate operations such as sum, avg, max, and min.
    """

    def __init__(self, model: Type[T]):
        """Initialize the aggregate query

        Args:
            model: SQLAlchemy model class
        """
        self.model = model
        self.model_name = model.__tablename__
        self._conditions = []

    def filter(self, **conditions: Any) -> "AggregateQuery":
        """Add filter conditions

        Args:
            **conditions: filter conditions (equality only)

        Returns:
            self, for chaining
        """
        for key, value in conditions.items():
            if hasattr(self.model, key):
                field = getattr(self.model, key)
                self._conditions.append(field == value)
        return self

    async def sum(self, field: str) -> float:
        """Sum a field

        Args:
            field: field name

        Returns:
            The sum
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")

        async with get_db_session() as session:
            stmt = select(func.sum(getattr(self.model, field)))

            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))

            result = await session.execute(stmt)
            return result.scalar() or 0
    async def avg(self, field: str) -> float:
        """Average a field

        Args:
            field: field name

        Returns:
            The average
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")

        async with get_db_session() as session:
            stmt = select(func.avg(getattr(self.model, field)))

            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))

            result = await session.execute(stmt)
            return result.scalar() or 0

    async def max(self, field: str) -> Any:
        """Maximum of a field

        Args:
            field: field name

        Returns:
            The maximum value
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")

        async with get_db_session() as session:
            stmt = select(func.max(getattr(self.model, field)))

            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))

            result = await session.execute(stmt)
            return result.scalar()

    async def min(self, field: str) -> Any:
        """Minimum of a field

        Args:
            field: field name

        Returns:
            The minimum value
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")

        async with get_db_session() as session:
            stmt = select(func.min(getattr(self.model, field)))

            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))

            result = await session.execute(stmt)
            return result.scalar()
    async def group_by_count(
        self,
        *fields: str,
    ) -> list[tuple[Any, ...]]:
        """Group and count

        Args:
            *fields: group-by fields

        Returns:
            [(group value 1, group value 2, ..., count), ...]
        """
        if not fields:
            raise ValueError("At least one group-by field is required")

        group_columns = []
        for field_name in fields:
            if hasattr(self.model, field_name):
                group_columns.append(getattr(self.model, field_name))

        if not group_columns:
            return []

        async with get_db_session() as session:
            stmt = select(*group_columns, func.count(self.model.id))

            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))

            stmt = stmt.group_by(*group_columns)

            result = await session.execute(stmt)
            return [tuple(row) for row in result.all()]
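The cache-key scheme that QueryBuilder threads through every chained call can be sketched standalone. This hypothetical mirror shows why two identical chains produce the same key (and therefore hit the same cache entry), while any differing filter, limit, or result shape produces a distinct key:

```python
# Hypothetical standalone mirror of QueryBuilder's _cache_key_parts logic.
class CacheKeyBuilder:
    def __init__(self, table: str):
        # keys always start with the table name
        self.parts = [table]

    def filter(self, **conds):
        # sorted() makes the key independent of keyword order
        self.parts.append(f"filter:{sorted(conds.items())}")
        return self

    def limit(self, n: int):
        self.parts.append(f"limit:{n}")
        return self

    def key(self, suffix: str) -> str:
        # suffix distinguishes result shapes: 'all', 'first', 'count'
        return ":".join(self.parts) + f":{suffix}"
```

For example, `CacheKeyBuilder("messages").filter(user_id=1).limit(10).key("all")` yields `"messages:filter:[('user_id', 1)]:limit:10:all"`, matching the `":".join(self._cache_key_parts) + ":all"` construction in `all()` above.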