启用数据库预加载器,清理日志
This commit is contained in:
@@ -21,6 +21,7 @@ from src.common.database.optimization import (
|
||||
Priority,
|
||||
get_batch_scheduler,
|
||||
get_cache,
|
||||
record_preload_access,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -145,6 +146,16 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
|
||||
if use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -189,6 +200,21 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
|
||||
|
||||
filters_copy = dict(filters)
|
||||
if use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
for key, value in filters_copy.items():
|
||||
if hasattr(self.model, key):
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalar_one_or_none()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
@@ -241,6 +267,29 @@ class CRUDBase(Generic[T]):
|
||||
"""
|
||||
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
|
||||
|
||||
filters_copy = dict(filters)
|
||||
if use_cache:
|
||||
async def _preload_loader() -> list[dict[str, Any]]:
|
||||
async with get_db_session() as session:
|
||||
stmt = select(self.model)
|
||||
|
||||
# 应用过滤条件
|
||||
for key, value in filters_copy.items():
|
||||
if hasattr(self.model, key):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
# 应用分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
return [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
|
||||
from src.common.database.api.crud import _dict_to_model, _model_to_dict
|
||||
from src.common.database.core.models import Base
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.database.optimization import get_cache
|
||||
from src.common.database.optimization import get_cache, record_preload_access
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.query")
|
||||
@@ -273,6 +273,16 @@ class QueryBuilder(Generic[T]):
|
||||
模型实例列表或字典列表
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":all"
|
||||
stmt = self._stmt
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> list[dict[str, Any]]:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(stmt)
|
||||
instances = list(result.scalars().all())
|
||||
return [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典列表)
|
||||
if self._use_cache:
|
||||
@@ -311,6 +321,16 @@ class QueryBuilder(Generic[T]):
|
||||
模型实例或None
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":first"
|
||||
stmt = self._stmt
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> dict[str, Any] | None:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(stmt)
|
||||
instance = result.scalars().first()
|
||||
return _model_to_dict(instance) if instance is not None else None
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取 (缓存的是字典)
|
||||
if self._use_cache:
|
||||
@@ -349,6 +369,15 @@ class QueryBuilder(Generic[T]):
|
||||
记录数量
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":count"
|
||||
count_stmt = select(func.count()).select_from(self._stmt.subquery())
|
||||
|
||||
if self._use_cache:
|
||||
async def _preload_loader() -> int:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(count_stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
await record_preload_access(cache_key, loader=_preload_loader)
|
||||
|
||||
# 尝试从缓存获取
|
||||
if self._use_cache:
|
||||
@@ -358,8 +387,6 @@ class QueryBuilder(Generic[T]):
|
||||
return cached
|
||||
|
||||
# 构建count查询
|
||||
count_stmt = select(func.count()).select_from(self._stmt.subquery())
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(count_stmt)
|
||||
|
||||
@@ -79,7 +79,7 @@ async def get_engine() -> AsyncEngine:
|
||||
elif db_type == "postgresql":
|
||||
await _enable_postgresql_optimizations(_engine)
|
||||
|
||||
logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
|
||||
logger.info(f"{db_type.upper()} 数据库引擎初始化成功")
|
||||
return _engine
|
||||
|
||||
except Exception as e:
|
||||
@@ -116,7 +116,7 @@ def _build_sqlite_config(config) -> tuple[str, dict]:
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"SQLite配置: {db_path}")
|
||||
logger.debug(f"SQLite配置: {db_path}")
|
||||
return url, engine_kwargs
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ def _build_postgresql_config(config) -> tuple[str, dict]:
|
||||
if connect_args:
|
||||
engine_kwargs["connect_args"] = connect_args
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
|
||||
)
|
||||
return url, engine_kwargs
|
||||
@@ -184,7 +184,7 @@ async def close_engine():
|
||||
logger.info("正在关闭数据库引擎...")
|
||||
await _engine.dispose()
|
||||
_engine = None
|
||||
logger.info("✅ 数据库引擎已关闭")
|
||||
logger.info("数据库引擎已关闭")
|
||||
|
||||
|
||||
async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
@@ -214,8 +214,6 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
# 临时存储使用内存
|
||||
await conn.execute(text("PRAGMA temp_store = MEMORY"))
|
||||
|
||||
logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
|
||||
|
||||
@@ -241,8 +239,6 @@ async def _enable_postgresql_optimizations(engine: AsyncEngine):
|
||||
# 启用自动 EXPLAIN(可选,用于调试)
|
||||
# await conn.execute(text("SET auto_explain.log_min_duration = '1000'"))
|
||||
|
||||
logger.info("✅ PostgreSQL性能优化已启用")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置")
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from .preloader import (
|
||||
DataPreloader,
|
||||
close_preloader,
|
||||
get_preloader,
|
||||
record_preload_access,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -51,4 +52,5 @@ __all__ = [
|
||||
"get_batch_scheduler",
|
||||
"get_cache",
|
||||
"get_preloader",
|
||||
"record_preload_access",
|
||||
]
|
||||
|
||||
@@ -13,6 +13,7 @@ from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from collections import OrderedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -22,6 +23,15 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("preloader")
|
||||
|
||||
# 预加载注册表(用于后台刷新热点数据)
|
||||
_preload_loader_registry: OrderedDict[str, Callable[[], Awaitable[Any]]] = OrderedDict()
|
||||
_registry_lock = asyncio.Lock()
|
||||
_preload_task: asyncio.Task | None = None
|
||||
_preload_task_lock = asyncio.Lock()
|
||||
_PRELOAD_REGISTRY_LIMIT = 1024
|
||||
# 默认后台预加载轮询间隔(秒)
|
||||
_DEFAULT_PRELOAD_INTERVAL = 60
|
||||
|
||||
|
||||
@dataclass
|
||||
class AccessPattern:
|
||||
@@ -223,16 +233,19 @@ class DataPreloader:
|
||||
|
||||
async def start_preload_batch(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
loaders: dict[str, Callable[[], Awaitable[Any]]],
|
||||
limit: int = 100,
|
||||
) -> None:
|
||||
"""批量启动预加载任务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
loaders: 数据键到加载函数的映射
|
||||
limit: 参与预加载的热点键数量上限
|
||||
"""
|
||||
preload_keys = await self.get_preload_keys()
|
||||
if not loaders:
|
||||
return
|
||||
|
||||
preload_keys = await self.get_preload_keys(limit=limit)
|
||||
|
||||
for key in preload_keys:
|
||||
if key in loaders:
|
||||
@@ -418,6 +431,91 @@ class CommonDataPreloader:
|
||||
await self.preloader.preload_data(cache_key, loader)
|
||||
|
||||
|
||||
# 预加载后台任务与注册表管理
|
||||
async def _get_preload_interval() -> float:
|
||||
"""获取后台预加载轮询间隔"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
if global_config and getattr(global_config, "database", None):
|
||||
interval = getattr(global_config.database, "preload_interval", None)
|
||||
if interval:
|
||||
return max(5.0, float(interval))
|
||||
except Exception:
|
||||
# 配置可能未加载或不存在该字段,使用默认值
|
||||
pass
|
||||
return float(_DEFAULT_PRELOAD_INTERVAL)
|
||||
|
||||
|
||||
async def _register_preload_loader(
|
||||
cache_key: str,
|
||||
loader: Callable[[], Awaitable[Any]],
|
||||
) -> None:
|
||||
"""注册用于热点预加载的加载函数"""
|
||||
async with _registry_lock:
|
||||
# move_to_end可以保持最近注册的顺序,便于淘汰旧项
|
||||
_preload_loader_registry[cache_key] = loader
|
||||
_preload_loader_registry.move_to_end(cache_key)
|
||||
|
||||
# 控制注册表大小,避免无限增长
|
||||
while len(_preload_loader_registry) > _PRELOAD_REGISTRY_LIMIT:
|
||||
_preload_loader_registry.popitem(last=False)
|
||||
|
||||
|
||||
async def _snapshot_loaders() -> dict[str, Callable[[], Awaitable[Any]]]:
|
||||
"""获取当前注册的预加载loader快照"""
|
||||
async with _registry_lock:
|
||||
return dict(_preload_loader_registry)
|
||||
|
||||
|
||||
async def _preload_worker() -> None:
|
||||
"""后台周期性预加载任务"""
|
||||
while True:
|
||||
try:
|
||||
interval = await _get_preload_interval()
|
||||
loaders = await _snapshot_loaders()
|
||||
|
||||
if loaders:
|
||||
preloader = await get_preloader()
|
||||
await preloader.start_preload_batch(loaders)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"预加载后台任务异常: {e}")
|
||||
# 避免紧急重试导致CPU占用过高
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
async def _ensure_preload_worker() -> None:
|
||||
"""确保后台预加载任务已启动"""
|
||||
global _preload_task
|
||||
|
||||
async with _preload_task_lock:
|
||||
if _preload_task is None or _preload_task.done():
|
||||
_preload_task = asyncio.create_task(_preload_worker())
|
||||
|
||||
|
||||
async def record_preload_access(
|
||||
cache_key: str,
|
||||
*,
|
||||
related_keys: list[str] | None = None,
|
||||
loader: Callable[[], Awaitable[Any]] | None = None,
|
||||
) -> None:
|
||||
"""记录访问并注册预加载loader
|
||||
|
||||
这个入口为上层API(CRUD/Query)提供:记录访问模式、建立关联关系、
|
||||
以及注册用于后续后台预加载的加载函数。
|
||||
"""
|
||||
preloader = await get_preloader()
|
||||
await preloader.record_access(cache_key, related_keys)
|
||||
|
||||
if loader is not None:
|
||||
await _register_preload_loader(cache_key, loader)
|
||||
await _ensure_preload_worker()
|
||||
|
||||
|
||||
# 全局预加载器实例
|
||||
_global_preloader: DataPreloader | None = None
|
||||
_preloader_lock = asyncio.Lock()
|
||||
@@ -438,7 +536,22 @@ async def get_preloader() -> DataPreloader:
|
||||
async def close_preloader() -> None:
|
||||
"""关闭全局预加载器"""
|
||||
global _global_preloader
|
||||
global _preload_task
|
||||
|
||||
# 停止后台任务
|
||||
if _preload_task is not None:
|
||||
_preload_task.cancel()
|
||||
try:
|
||||
await _preload_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
_preload_task = None
|
||||
|
||||
# 清理注册表
|
||||
async with _registry_lock:
|
||||
_preload_loader_registry.clear()
|
||||
|
||||
# 清理预加载器实例
|
||||
if _global_preloader is not None:
|
||||
await _global_preloader.clear()
|
||||
_global_preloader = None
|
||||
|
||||
Reference in New Issue
Block a user