feat(database): 优化消息查询和计数逻辑,增加安全限制以防内存暴涨
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
重构后的数据库模块,提供:
|
||||
- 核心层:引擎、会话、模型、迁移
|
||||
- 优化层:缓存、预加载、批处理
|
||||
- 优化层:缓存、批处理
|
||||
- API层:CRUD、查询构建器、业务API
|
||||
- Utils层:装饰器、监控
|
||||
- 兼容层:向后兼容的API
|
||||
@@ -51,11 +51,9 @@ from src.common.database.core import (
|
||||
# ===== 优化层 =====
|
||||
from src.common.database.optimization import (
|
||||
AdaptiveBatchScheduler,
|
||||
DataPreloader,
|
||||
MultiLevelCache,
|
||||
get_batch_scheduler,
|
||||
get_cache,
|
||||
get_preloader,
|
||||
)
|
||||
|
||||
# ===== Utils层 =====
|
||||
@@ -83,7 +81,6 @@ __all__ = [
|
||||
"Base",
|
||||
# API层 - 基础类
|
||||
"CRUDBase",
|
||||
"DataPreloader",
|
||||
# 优化层
|
||||
"MultiLevelCache",
|
||||
"QueryBuilder",
|
||||
@@ -103,7 +100,6 @@ __all__ = [
|
||||
"get_message_count",
|
||||
"get_monitor",
|
||||
"get_or_create_person",
|
||||
"get_preloader",
|
||||
"get_recent_actions",
|
||||
"get_session_factory",
|
||||
"get_usage_statistics",
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
提供通用的数据库CRUD操作,集成优化层功能:
|
||||
- 自动缓存:查询结果自动缓存
|
||||
- 批量处理:写操作自动批处理
|
||||
- 智能预加载:关联数据自动预加载
|
||||
"""
|
||||
|
||||
import operator
|
||||
@@ -19,7 +18,6 @@ from src.common.database.optimization import (
|
||||
Priority,
|
||||
get_batch_scheduler,
|
||||
get_cache,
|
||||
record_preload_access,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -144,16 +142,6 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
|
||||
if use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -198,21 +186,6 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
|
||||
|
||||
filters_copy = dict(filters)
|
||||
if use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
for key, value in filters_copy.items():
|
||||
if hasattr(self.model, key):
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -265,29 +238,6 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
|
||||
|
||||
filters_copy = dict(filters)
|
||||
if use_cache:
|
||||
async def _preload_loader() -> list[dict[str, Any]]:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
|
||||
# 应用过滤条件
|
||||
for key, value in filters_copy.items():
|
||||
if hasattr(self.model, key):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
# 应用分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
return [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
|
||||
# 导入 CRUD 辅助函数以避免重复定义
|
||||
from src.common.database.api.crud import _dict_to_model, _model_to_dict
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.database.optimization import get_cache, record_preload_access
|
||||
from src.common.database.optimization import get_cache
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.query")
|
||||
@@ -272,16 +272,6 @@ class QueryBuilder(Generic[T]):
|
||||
模型实例列表或字典列表
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":all"
|
||||
stmt = self._stmt
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> list[dict[str, Any]]:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
return [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if self._use_cache:
|
||||
@@ -320,16 +310,6 @@ class QueryBuilder(Generic[T]):
|
||||
模型实例或None
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":first"
|
||||
stmt = self._stmt
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalars().first()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if self._use_cache:
|
||||
@@ -370,14 +350,6 @@ class QueryBuilder(Generic[T]):
|
||||
cache_key = ":".join(self._cache_key_parts) + ":count"
|
||||
count_stmt = select(func.count()).select_from(self._stmt.subquery())
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> int:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(count_stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
职责:
|
||||
- 批量调度
|
||||
- 多级缓存(内存缓存 + Redis缓存)
|
||||
- 数据预加载
|
||||
"""
|
||||
|
||||
from .batch_scheduler import (
|
||||
@@ -25,18 +24,9 @@ from .cache_manager import (
|
||||
get_cache,
|
||||
get_cache_backend_type,
|
||||
)
|
||||
from .preloader import (
|
||||
AccessPattern,
|
||||
CommonDataPreloader,
|
||||
DataPreloader,
|
||||
close_preloader,
|
||||
get_preloader,
|
||||
record_preload_access,
|
||||
)
|
||||
from .redis_cache import RedisCache, close_redis_cache, get_redis_cache
|
||||
|
||||
__all__ = [
|
||||
"AccessPattern",
|
||||
# Batch Scheduler
|
||||
"AdaptiveBatchScheduler",
|
||||
"BaseCacheStats",
|
||||
@@ -46,9 +36,6 @@ __all__ = [
|
||||
"CacheBackend",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"CommonDataPreloader",
|
||||
# Preloader
|
||||
"DataPreloader",
|
||||
"LRUCache",
|
||||
# Memory Cache
|
||||
"MultiLevelCache",
|
||||
@@ -57,12 +44,9 @@ __all__ = [
|
||||
"RedisCache",
|
||||
"close_batch_scheduler",
|
||||
"close_cache",
|
||||
"close_preloader",
|
||||
"close_redis_cache",
|
||||
"get_batch_scheduler",
|
||||
"get_cache",
|
||||
"get_cache_backend_type",
|
||||
"get_preloader",
|
||||
"record_preload_access",
|
||||
"get_redis_cache"
|
||||
]
|
||||
|
||||
@@ -64,10 +64,6 @@ class DatabaseMetrics:
|
||||
batch_items_total: int = 0
|
||||
batch_avg_size: float = 0.0
|
||||
|
||||
# 预加载统计
|
||||
preload_operations: int = 0
|
||||
preload_hits: int = 0
|
||||
|
||||
@property
|
||||
def cache_hit_rate(self) -> float:
|
||||
"""缓存命中率"""
|
||||
@@ -152,12 +148,6 @@ class DatabaseMonitor:
|
||||
self._metrics.batch_items_total / self._metrics.batch_operations
|
||||
)
|
||||
|
||||
def record_preload_operation(self, hit: bool = False):
|
||||
"""记录预加载操作"""
|
||||
self._metrics.preload_operations += 1
|
||||
if hit:
|
||||
self._metrics.preload_hits += 1
|
||||
|
||||
def get_metrics(self) -> DatabaseMetrics:
|
||||
"""获取指标"""
|
||||
return self._metrics
|
||||
@@ -196,15 +186,6 @@ class DatabaseMonitor:
|
||||
"total_items": metrics.batch_items_total,
|
||||
"avg_size": f"{metrics.batch_avg_size:.1f}",
|
||||
},
|
||||
"preload": {
|
||||
"operations": metrics.preload_operations,
|
||||
"hits": metrics.preload_hits,
|
||||
"hit_rate": (
|
||||
f"{metrics.preload_hits / metrics.preload_operations:.2%}"
|
||||
if metrics.preload_operations > 0
|
||||
else "N/A"
|
||||
),
|
||||
},
|
||||
"overall": {
|
||||
"error_rate": f"{metrics.error_rate:.2%}",
|
||||
},
|
||||
@@ -261,15 +242,6 @@ class DatabaseMonitor:
|
||||
f"平均大小={batch['avg_size']}"
|
||||
)
|
||||
|
||||
# 预加载统计
|
||||
logger.info("\n预加载:")
|
||||
preload = summary["preload"]
|
||||
logger.info(
|
||||
f" 操作={preload['operations']}, "
|
||||
f"命中={preload['hits']}, "
|
||||
f"命中率={preload['hit_rate']}"
|
||||
)
|
||||
|
||||
# 整体统计
|
||||
logger.info("\n整体:")
|
||||
overall = summary["overall"]
|
||||
|
||||
Reference in New Issue
Block a user