feat(database): 优化消息查询和计数逻辑,增加安全限制以防内存暴涨

This commit is contained in:
Windpicker-owo
2025-12-09 17:35:23 +08:00
parent 5d6c70d8ad
commit fa9b0b3d7e
8 changed files with 126 additions and 168 deletions

View File

@@ -2,7 +2,7 @@
重构后的数据库模块,提供:
- 核心层:引擎、会话、模型、迁移
- 优化层:缓存、预加载、批处理
- 优化层:缓存、批处理
- API层:CRUD、查询构建器、业务API
- Utils层:装饰器、监控
- 兼容层:向后兼容的API
@@ -51,11 +51,9 @@ from src.common.database.core import (
# ===== 优化层 =====
from src.common.database.optimization import (
AdaptiveBatchScheduler,
DataPreloader,
MultiLevelCache,
get_batch_scheduler,
get_cache,
get_preloader,
)
# ===== Utils层 =====
@@ -83,7 +81,6 @@ __all__ = [
"Base",
# API层 - 基础类
"CRUDBase",
"DataPreloader",
# 优化层
"MultiLevelCache",
"QueryBuilder",
@@ -103,7 +100,6 @@ __all__ = [
"get_message_count",
"get_monitor",
"get_or_create_person",
"get_preloader",
"get_recent_actions",
"get_session_factory",
"get_usage_statistics",

View File

@@ -3,7 +3,6 @@
提供通用的数据库CRUD操作,集成优化层功能:
- 自动缓存:查询结果自动缓存
- 批量处理:写操作自动批处理
- 智能预加载:关联数据自动预加载
"""
import operator
@@ -19,7 +18,6 @@ from src.common.database.optimization import (
Priority,
get_batch_scheduler,
get_cache,
record_preload_access,
)
from src.common.logger import get_logger
@@ -144,16 +142,6 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:id:{id}"
if use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
stmt = select(self.model).where(self.model.id == id)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# 尝试从缓存获取 (缓存的是字典)
if use_cache:
cache = await get_cache()
@@ -198,21 +186,6 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
filters_copy = dict(filters)
if use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
stmt = select(self.model)
for key, value in filters_copy.items():
if hasattr(self.model, key):
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# 尝试从缓存获取 (缓存的是字典)
if use_cache:
cache = await get_cache()
@@ -265,29 +238,6 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
filters_copy = dict(filters)
if use_cache:
async def _preload_loader() -> list[dict[str, Any]]:
async with get_db_session() as session:
stmt = select(self.model)
# 应用过滤条件
for key, value in filters_copy.items():
if hasattr(self.model, key):
if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
# 应用分页
stmt = stmt.offset(skip).limit(limit)
result = await session.execute(stmt)
instances = list(result.scalars().all())
return [_model_to_dict(inst) for inst in instances]
await record_preload_access(cache_key, loader=_preload_loader)
# 尝试从缓存获取 (缓存的是字典列表)
if use_cache:
cache = await get_cache()

View File

@@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
# 导入 CRUD 辅助函数以避免重复定义
from src.common.database.api.crud import _dict_to_model, _model_to_dict
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache, record_preload_access
from src.common.database.optimization import get_cache
from src.common.logger import get_logger
logger = get_logger("database.query")
@@ -272,16 +272,6 @@ class QueryBuilder(Generic[T]):
模型实例列表或字典列表
"""
cache_key = ":".join(self._cache_key_parts) + ":all"
stmt = self._stmt
if self._use_cache:
async def _preload_loader() -> list[dict[str, Any]]:
async with get_db_session() as session:
result = await session.execute(stmt)
instances = list(result.scalars().all())
return [_model_to_dict(inst) for inst in instances]
await record_preload_access(cache_key, loader=_preload_loader)
# 尝试从缓存获取 (缓存的是字典列表)
if self._use_cache:
@@ -320,16 +310,6 @@ class QueryBuilder(Generic[T]):
模型实例或None
"""
cache_key = ":".join(self._cache_key_parts) + ":first"
stmt = self._stmt
if self._use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
result = await session.execute(stmt)
instance = result.scalars().first()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# 尝试从缓存获取 (缓存的是字典)
if self._use_cache:
@@ -370,14 +350,6 @@ class QueryBuilder(Generic[T]):
cache_key = ":".join(self._cache_key_parts) + ":count"
count_stmt = select(func.count()).select_from(self._stmt.subquery())
if self._use_cache:
async def _preload_loader() -> int:
async with get_db_session() as session:
result = await session.execute(count_stmt)
return result.scalar() or 0
await record_preload_access(cache_key, loader=_preload_loader)
# 尝试从缓存获取
if self._use_cache:
cache = await get_cache()

View File

@@ -3,7 +3,6 @@
职责:
- 批量调度
- 多级缓存(内存缓存 + Redis缓存)
- 数据预加载
"""
from .batch_scheduler import (
@@ -25,18 +24,9 @@ from .cache_manager import (
get_cache,
get_cache_backend_type,
)
from .preloader import (
AccessPattern,
CommonDataPreloader,
DataPreloader,
close_preloader,
get_preloader,
record_preload_access,
)
from .redis_cache import RedisCache, close_redis_cache, get_redis_cache
__all__ = [
"AccessPattern",
# Batch Scheduler
"AdaptiveBatchScheduler",
"BaseCacheStats",
@@ -46,9 +36,6 @@ __all__ = [
"CacheBackend",
"CacheEntry",
"CacheStats",
"CommonDataPreloader",
# Preloader
"DataPreloader",
"LRUCache",
# Memory Cache
"MultiLevelCache",
@@ -57,12 +44,9 @@ __all__ = [
"RedisCache",
"close_batch_scheduler",
"close_cache",
"close_preloader",
"close_redis_cache",
"get_batch_scheduler",
"get_cache",
"get_cache_backend_type",
"get_preloader",
"record_preload_access",
"get_redis_cache"
]

View File

@@ -64,10 +64,6 @@ class DatabaseMetrics:
batch_items_total: int = 0
batch_avg_size: float = 0.0
# 预加载统计
preload_operations: int = 0
preload_hits: int = 0
@property
def cache_hit_rate(self) -> float:
"""缓存命中率"""
@@ -152,12 +148,6 @@ class DatabaseMonitor:
self._metrics.batch_items_total / self._metrics.batch_operations
)
def record_preload_operation(self, hit: bool = False):
"""记录预加载操作"""
self._metrics.preload_operations += 1
if hit:
self._metrics.preload_hits += 1
def get_metrics(self) -> DatabaseMetrics:
"""获取指标"""
return self._metrics
@@ -196,15 +186,6 @@ class DatabaseMonitor:
"total_items": metrics.batch_items_total,
"avg_size": f"{metrics.batch_avg_size:.1f}",
},
"preload": {
"operations": metrics.preload_operations,
"hits": metrics.preload_hits,
"hit_rate": (
f"{metrics.preload_hits / metrics.preload_operations:.2%}"
if metrics.preload_operations > 0
else "N/A"
),
},
"overall": {
"error_rate": f"{metrics.error_rate:.2%}",
},
@@ -261,15 +242,6 @@ class DatabaseMonitor:
f"平均大小={batch['avg_size']}"
)
# 预加载统计
logger.info("\n预加载:")
preload = summary["preload"]
logger.info(
f" 操作={preload['operations']}, "
f"命中={preload['hits']}, "
f"命中率={preload['hit_rate']}"
)
# 整体统计
logger.info("\n整体:")
overall = summary["overall"]