feat(database): complete API layer, Utils layer, and compatibility layer refactor (Stages 4-6)

Stage 4: API layer refactor
===========================
New files:
- api/crud.py (430 lines): CRUDBase generic class providing 12 CRUD methods
  * get, get_by, get_multi, create, update, delete
  * count, exists, get_or_create, bulk_create, bulk_update
  * Cache integration: reads are cached automatically, writes invalidate the cache
  * Batching integration: optional use_batch parameter transparently uses AdaptiveBatchScheduler

- api/query.py (461 lines): advanced query builder
  * QueryBuilder: chainable calls with MongoDB-style operators
    - Operators: __gt, __lt, __gte, __lte, __ne, __in, __nin, __like, __isnull
    - Methods: filter, filter_or, order_by, limit, offset, no_cache
    - Execution: all, first, count, exists, paginate
  * AggregateQuery: aggregate queries
    - sum, avg, max, min, group_by_count

- api/specialized.py (461 lines): business-specific APIs
  * ActionRecords: store_action_info, get_recent_actions
  * Messages: get_chat_history, get_message_count, save_message
  * PersonInfo: get_or_create_person, update_person_affinity
  * ChatStreams: get_or_create_chat_stream, get_active_streams
  * LLMUsage: record_llm_usage, get_usage_statistics
  * UserRelationships: get_user_relationship, update_relationship_affinity

- Updated api/__init__.py: exports all API interfaces
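The double-underscore operator convention behind QueryBuilder.filter can be sketched as a standalone helper (this mirrors the `rsplit("__", 1)` parsing in api/query.py; the function name here is ours, for illustration only):

```python
def parse_condition(key: str) -> tuple[str, str]:
    """Split a filter kwarg like 'age__gte' into (field_name, operator).

    A key without a '__' suffix is treated as plain equality ('eq'),
    matching the behavior of QueryBuilder.filter.
    """
    if "__" in key:
        field_name, operator = key.rsplit("__", 1)
    else:
        field_name, operator = key, "eq"
    return field_name, operator

print(parse_condition("age__gte"))  # ('age', 'gte')
print(parse_condition("user_id"))   # ('user_id', 'eq')
```

Because only the last `__` segment is taken as the operator, field names that themselves contain double underscores still resolve correctly.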

Stage 5: Utils layer implementation
===================================
New files:
- utils/decorators.py (320 lines): database operation decorators
  * @retry: automatic retry of failed operations with exponential backoff
  * @timeout: timeout control
  * @cached: automatic caching of function results
  * @measure_time: performance measurement with slow-query logging
  * @transactional: transaction management with automatic commit/rollback
  * @db_operation: composite decorator

- utils/monitoring.py (330 lines): performance monitoring system
  * DatabaseMonitor: singleton monitor
  * OperationMetrics: per-operation metrics (count, time, errors)
  * DatabaseMetrics: global metrics
    - connection pool statistics
    - cache hit rate
    - batching statistics
    - preloading statistics
  * Convenience functions: get_monitor, record_operation, print_stats

- Updated utils/__init__.py: exports decorators and monitoring functions
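A minimal sketch of what a retry-with-exponential-backoff decorator might look like for async database operations (the actual signature and defaults in utils/decorators.py may differ):

```python
import asyncio
import functools


def retry(max_attempts: int = 3, base_delay: float = 0.1):
    """Retry a failed async operation with exponential backoff.

    Delays grow as base_delay * 2**attempt (0.1s, 0.2s, 0.4s, ...);
    the last failure is re-raised unchanged.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts - 1:
                        raise
                    await asyncio.sleep(base_delay * (2 ** attempt))
        return wrapper
    return decorator
```

Applied as `@retry(max_attempts=3)` on an async CRUD call, transient failures (e.g. a briefly locked SQLite database) are absorbed without changing the call site.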

Stage 6: Compatibility layer implementation
===========================================
New directory: compatibility/
- adapter.py (370 lines): backward-compatibility adapter
  * Fully compatible with the old API signatures: db_query, db_save, db_get, store_action_info
  * Supports MongoDB-style query operators
  * Uses the new architecture internally (QueryBuilder + CRUDBase)
  * Keeps the returned dict format unchanged
  * MODEL_MAPPING: 25 model mappings

- __init__.py: exports the compatibility API

Updated database/__init__.py:
- exports the core layer (engine, session, models, migration)
- exports the optimization layer (cache, preloader, batch_scheduler)
- exports the API layer (CRUD, Query, business APIs)
- exports the Utils layer (decorators, monitoring)
- exports the compatibility layer (db_query, db_save, etc.)
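One way such an adapter could translate a MongoDB-style filter document into the new QueryBuilder keyword style (the function name and operator table below are assumptions for illustration, not the adapter's actual code):

```python
# Assumed mapping from MongoDB operators to the new '__' suffixes.
_OP_MAP = {
    "$gt": "gt", "$lt": "lt", "$gte": "gte", "$lte": "lte",
    "$ne": "ne", "$in": "in", "$nin": "nin",
}


def mongo_filter_to_kwargs(filter_doc: dict) -> dict:
    """Convert {'age': {'$gte': 18}} into {'age__gte': 18}."""
    kwargs = {}
    for field, cond in filter_doc.items():
        if isinstance(cond, dict):
            for op, value in cond.items():
                suffix = _OP_MAP.get(op)
                if suffix is None:
                    raise ValueError(f"unsupported operator: {op}")
                kwargs[f"{field}__{suffix}"] = value
        else:
            # A bare value means plain equality.
            kwargs[field] = cond
    return kwargs
```

The resulting dict can be splatted straight into `QueryBuilder(...).filter(**kwargs)`, which is how the compatibility layer can keep old call sites working while routing everything through the new architecture.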

Core features
=============
- Type safety: Generic[T] provides full type inference
- Transparent caching: automatic caching, invisible to callers
- Transparent batching: optional batching that automatically optimizes high-frequency writes
- Chainable queries: fluent API design
- Business encapsulation: common operations wrapped in convenience functions
- Backward compatibility: the compatibility layer lets existing code migrate seamlessly
- Performance monitoring: complete metric collection and reporting

Statistics
==========
- New files: 7
- Lines of code: ~2050
- API functions: 14 business APIs + 6 decorators
- Compatibility functions: 5 (db_query, db_save, db_get, etc.)

Next steps
==========
- Update import statements in 28 files (migrating from sqlalchemy_database_api)
- Move the old files into the old/ directory
- Write tests for Stages 4-6
- Run integration tests to verify compatibility
Windpicker-owo
2025-11-01 13:27:33 +08:00
parent aae84ec454
commit 61de975d73
10 changed files with 2563 additions and 5 deletions


@@ -0,0 +1,458 @@
"""Advanced query API

Provides complex query operations:
- MongoDB-style query operators
- aggregate queries
- sorting and pagination
- relational queries
"""
from typing import Any, Generic, Optional, Type, TypeVar

from sqlalchemy import and_, asc, desc, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache, get_preloader
from src.common.logger import get_logger

logger = get_logger("database.query")

T = TypeVar("T", bound="Base")
class QueryBuilder(Generic[T]):
    """Query builder

    Supports chained calls for building complex queries.
    """

    def __init__(self, model: Type[T]):
        """Initialize the query builder

        Args:
            model: SQLAlchemy model class
        """
        self.model = model
        self.model_name = model.__tablename__
        self._stmt = select(model)
        self._use_cache = True
        self._cache_key_parts: list[str] = [self.model_name]
    def filter(self, **conditions: Any) -> "QueryBuilder":
        """Add filter conditions

        Supported operators:
        - equality: field=value
        - greater than: field__gt=value
        - less than: field__lt=value
        - greater than or equal: field__gte=value
        - less than or equal: field__lte=value
        - not equal: field__ne=value
        - membership: field__in=[values]
        - non-membership: field__nin=[values]
        - pattern match: field__like='%pattern%'
        - null check: field__isnull=True

        Args:
            **conditions: filter conditions

        Returns:
            self, for chaining
        """
        for key, value in conditions.items():
            # Parse the field name and operator
            if "__" in key:
                field_name, operator = key.rsplit("__", 1)
            else:
                field_name, operator = key, "eq"
            if not hasattr(self.model, field_name):
                logger.warning(f"Model {self.model_name} has no field {field_name}")
                continue
            field = getattr(self.model, field_name)
            # Apply the operator
            if operator == "eq":
                self._stmt = self._stmt.where(field == value)
            elif operator == "gt":
                self._stmt = self._stmt.where(field > value)
            elif operator == "lt":
                self._stmt = self._stmt.where(field < value)
            elif operator == "gte":
                self._stmt = self._stmt.where(field >= value)
            elif operator == "lte":
                self._stmt = self._stmt.where(field <= value)
            elif operator == "ne":
                self._stmt = self._stmt.where(field != value)
            elif operator == "in":
                self._stmt = self._stmt.where(field.in_(value))
            elif operator == "nin":
                self._stmt = self._stmt.where(~field.in_(value))
            elif operator == "like":
                self._stmt = self._stmt.where(field.like(value))
            elif operator == "isnull":
                if value:
                    self._stmt = self._stmt.where(field.is_(None))
                else:
                    self._stmt = self._stmt.where(field.isnot(None))
            else:
                logger.warning(f"Unknown operator: {operator}")
        # Update the cache key
        self._cache_key_parts.append(f"filter:{str(sorted(conditions.items()))}")
        return self
    def filter_or(self, **conditions: Any) -> "QueryBuilder":
        """Add OR filter conditions

        Args:
            **conditions: OR conditions

        Returns:
            self, for chaining
        """
        or_conditions = []
        for key, value in conditions.items():
            if hasattr(self.model, key):
                field = getattr(self.model, key)
                or_conditions.append(field == value)
        if or_conditions:
            self._stmt = self._stmt.where(or_(*or_conditions))
            self._cache_key_parts.append(f"or:{str(sorted(conditions.items()))}")
        return self
    def order_by(self, *fields: str) -> "QueryBuilder":
        """Add ordering

        Args:
            *fields: field names to sort by; a '-' prefix means descending

        Returns:
            self, for chaining
        """
        for field_name in fields:
            if field_name.startswith("-"):
                field_name = field_name[1:]
                if hasattr(self.model, field_name):
                    self._stmt = self._stmt.order_by(desc(getattr(self.model, field_name)))
            else:
                if hasattr(self.model, field_name):
                    self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name)))
        self._cache_key_parts.append(f"order:{','.join(fields)}")
        return self
    def limit(self, limit: int) -> "QueryBuilder":
        """Limit the number of results

        Args:
            limit: maximum number of rows

        Returns:
            self, for chaining
        """
        self._stmt = self._stmt.limit(limit)
        self._cache_key_parts.append(f"limit:{limit}")
        return self

    def offset(self, offset: int) -> "QueryBuilder":
        """Skip a number of rows

        Args:
            offset: number of rows to skip

        Returns:
            self, for chaining
        """
        self._stmt = self._stmt.offset(offset)
        self._cache_key_parts.append(f"offset:{offset}")
        return self
    def no_cache(self) -> "QueryBuilder":
        """Disable caching

        Returns:
            self, for chaining
        """
        self._use_cache = False
        return self
    async def all(self) -> list[T]:
        """Fetch all results

        Returns:
            list of model instances
        """
        cache_key = ":".join(self._cache_key_parts) + ":all"
        # Try the cache first
        if self._use_cache:
            cache = await get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return cached
        # Query the database
        async with get_db_session() as session:
            result = await session.execute(self._stmt)
            instances = list(result.scalars().all())
        # Write to the cache
        if self._use_cache:
            cache = await get_cache()
            await cache.set(cache_key, instances)
        return instances
    async def first(self) -> Optional[T]:
        """Fetch the first result

        Returns:
            a model instance, or None
        """
        cache_key = ":".join(self._cache_key_parts) + ":first"
        # Try the cache first
        if self._use_cache:
            cache = await get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return cached
        # Query the database
        async with get_db_session() as session:
            result = await session.execute(self._stmt)
            instance = result.scalars().first()
        # Write to the cache
        if instance is not None and self._use_cache:
            cache = await get_cache()
            await cache.set(cache_key, instance)
        return instance
    async def count(self) -> int:
        """Count matching rows

        Returns:
            number of records
        """
        cache_key = ":".join(self._cache_key_parts) + ":count"
        # Try the cache first
        if self._use_cache:
            cache = await get_cache()
            cached = await cache.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return cached
        # Build the count query
        count_stmt = select(func.count()).select_from(self._stmt.subquery())
        # Query the database
        async with get_db_session() as session:
            result = await session.execute(count_stmt)
            count = result.scalar() or 0
        # Write to the cache
        if self._use_cache:
            cache = await get_cache()
            await cache.set(cache_key, count)
        return count

    async def exists(self) -> bool:
        """Check whether any matching record exists

        Returns:
            True if at least one record matches
        """
        count = await self.count()
        return count > 0
    async def paginate(
        self,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[T], int]:
        """Paginated query

        Args:
            page: page number, starting from 1
            page_size: rows per page

        Returns:
            (list of results, total count)
        """
        # Compute the offset
        offset = (page - 1) * page_size
        # Get the total count
        total = await self.count()
        # Fetch the current page
        self._stmt = self._stmt.offset(offset).limit(page_size)
        self._cache_key_parts.append(f"page:{page}:{page_size}")
        items = await self.all()
        return items, total
class AggregateQuery:
    """Aggregate query

    Provides aggregate operations such as sum, avg, max, and min.
    """

    def __init__(self, model: Type[T]):
        """Initialize the aggregate query

        Args:
            model: SQLAlchemy model class
        """
        self.model = model
        self.model_name = model.__tablename__
        self._conditions = []

    def filter(self, **conditions: Any) -> "AggregateQuery":
        """Add filter conditions

        Args:
            **conditions: filter conditions

        Returns:
            self, for chaining
        """
        for key, value in conditions.items():
            if hasattr(self.model, key):
                field = getattr(self.model, key)
                self._conditions.append(field == value)
        return self
    async def sum(self, field: str) -> float:
        """Sum a column

        Args:
            field: field name

        Returns:
            the sum
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")
        async with get_db_session() as session:
            stmt = select(func.sum(getattr(self.model, field)))
            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))
            result = await session.execute(stmt)
            return result.scalar() or 0

    async def avg(self, field: str) -> float:
        """Average a column

        Args:
            field: field name

        Returns:
            the average
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")
        async with get_db_session() as session:
            stmt = select(func.avg(getattr(self.model, field)))
            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))
            result = await session.execute(stmt)
            return result.scalar() or 0
    async def max(self, field: str) -> Any:
        """Find the maximum value of a column

        Args:
            field: field name

        Returns:
            the maximum value
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")
        async with get_db_session() as session:
            stmt = select(func.max(getattr(self.model, field)))
            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))
            result = await session.execute(stmt)
            return result.scalar()

    async def min(self, field: str) -> Any:
        """Find the minimum value of a column

        Args:
            field: field name

        Returns:
            the minimum value
        """
        if not hasattr(self.model, field):
            raise ValueError(f"Field {field} does not exist")
        async with get_db_session() as session:
            stmt = select(func.min(getattr(self.model, field)))
            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))
            result = await session.execute(stmt)
            return result.scalar()
    async def group_by_count(
        self,
        *fields: str,
    ) -> list[tuple[Any, ...]]:
        """Group and count

        Args:
            *fields: fields to group by

        Returns:
            [(group value 1, group value 2, ..., count), ...]
        """
        if not fields:
            raise ValueError("At least one grouping field is required")
        group_columns = []
        for field_name in fields:
            if hasattr(self.model, field_name):
                group_columns.append(getattr(self.model, field_name))
        if not group_columns:
            return []
        async with get_db_session() as session:
            stmt = select(*group_columns, func.count(self.model.id))
            if self._conditions:
                stmt = stmt.where(and_(*self._conditions))
            stmt = stmt.group_by(*group_columns)
            result = await session.execute(stmt)
            return [tuple(row) for row in result.all()]