This commit is contained in:
雅诺狐
2025-12-08 17:44:00 +08:00
47 changed files with 525 additions and 897 deletions

View File

@@ -19,6 +19,7 @@ from src.common.database.optimization import (
Priority,
get_batch_scheduler,
get_cache,
record_preload_access,
)
from src.common.logger import get_logger
@@ -143,6 +144,16 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:id:{id}"
if use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
stmt = select(self.model).where(self.model.id == id)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# Try to fetch from cache (cached values are dicts)
if use_cache:
cache = await get_cache()
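The loaders above cache plain dicts produced by _model_to_dict, which is imported from the CRUD helpers but not shown in this diff. For reference, a minimal sketch of such a helper, assuming standard SQLAlchemy declarative models (the project's actual implementation may differ):

from typing import Any

from sqlalchemy import inspect


def _model_to_dict(instance: Any) -> dict[str, Any]:
    # Map every mapped column attribute to its current value on the instance.
    mapper = inspect(instance).mapper
    return {attr.key: getattr(instance, attr.key) for attr in mapper.column_attrs}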
@@ -187,6 +198,21 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"
filters_copy = dict(filters)
if use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
stmt = select(self.model)
for key, value in filters_copy.items():
if hasattr(self.model, key):
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
instance = result.scalar_one_or_none()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# Try to fetch from cache (cached values are dicts)
if use_cache:
cache = await get_cache()
@@ -239,6 +265,29 @@ class CRUDBase(Generic[T]):
"""
cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"
filters_copy = dict(filters)
if use_cache:
async def _preload_loader() -> list[dict[str, Any]]:
async with get_db_session() as session:
stmt = select(self.model)
# Apply filter conditions
for key, value in filters_copy.items():
if hasattr(self.model, key):
if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
# Apply pagination
stmt = stmt.offset(skip).limit(limit)
result = await session.execute(stmt)
instances = list(result.scalars().all())
return [_model_to_dict(inst) for inst in instances]
await record_preload_access(cache_key, loader=_preload_loader)
# Try to fetch from cache (cached values are lists of dicts)
if use_cache:
cache = await get_cache()
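For illustration, a call such as the following (the user_crud instance and status column are hypothetical) would take the .in_() branch above and register a paginated preload loader for the resulting cache key:

async def list_active_or_pending_users():
    # Hypothetical CRUD instance; a list value maps to WHERE status IN (...).
    return await user_crud.get_multi(
        skip=0,
        limit=50,
        status=["active", "pending"],
    )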

View File

@@ -16,7 +16,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
# Import CRUD helper functions to avoid duplicating definitions
from src.common.database.api.crud import _dict_to_model, _model_to_dict
from src.common.database.core.session import get_db_session
from src.common.database.optimization import get_cache
from src.common.database.optimization import get_cache, record_preload_access
from src.common.logger import get_logger
logger = get_logger("database.query")
@@ -272,6 +272,16 @@ class QueryBuilder(Generic[T]):
List of model instances or list of dicts
"""
cache_key = ":".join(self._cache_key_parts) + ":all"
stmt = self._stmt
if self._use_cache:
async def _preload_loader() -> list[dict[str, Any]]:
async with get_db_session() as session:
result = await session.execute(stmt)
instances = list(result.scalars().all())
return [_model_to_dict(inst) for inst in instances]
await record_preload_access(cache_key, loader=_preload_loader)
# Try to fetch from cache (cached values are lists of dicts)
if self._use_cache:
@@ -310,6 +320,16 @@ class QueryBuilder(Generic[T]):
Model instance or None
"""
cache_key = ":".join(self._cache_key_parts) + ":first"
stmt = self._stmt
if self._use_cache:
async def _preload_loader() -> dict[str, Any] | None:
async with get_db_session() as session:
result = await session.execute(stmt)
instance = result.scalars().first()
return _model_to_dict(instance) if instance is not None else None
await record_preload_access(cache_key, loader=_preload_loader)
# Try to fetch from cache (cached values are dicts)
if self._use_cache:
@@ -348,6 +368,15 @@ class QueryBuilder(Generic[T]):
Number of records
"""
cache_key = ":".join(self._cache_key_parts) + ":count"
count_stmt = select(func.count()).select_from(self._stmt.subquery())
if self._use_cache:
async def _preload_loader() -> int:
async with get_db_session() as session:
result = await session.execute(count_stmt)
return result.scalar() or 0
await record_preload_access(cache_key, loader=_preload_loader)
# Try to fetch from cache
if self._use_cache:
@@ -357,8 +386,6 @@ class QueryBuilder(Generic[T]):
return cached
# Build the count query
count_stmt = select(func.count()).select_from(self._stmt.subquery())
# Query the database
async with get_db_session() as session:
result = await session.execute(count_stmt)
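Wrapping the builder's statement in a subquery preserves any joins and filters when counting, since only the aggregate is selected. A standalone sketch of the same pattern (the User model and filter are illustrative):

from sqlalchemy import func, select

from src.common.database.core.session import get_db_session


async def count_active_users() -> int:
    # Count matching rows without materializing them.
    stmt = select(User).where(User.is_active.is_(True))
    count_stmt = select(func.count()).select_from(stmt.subquery())
    async with get_db_session() as session:
        return (await session.execute(count_stmt)).scalar() or 0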

View File

@@ -79,7 +79,7 @@ async def get_engine() -> AsyncEngine:
elif db_type == "postgresql":
await _enable_postgresql_optimizations(_engine)
logger.info(f"{db_type.upper()} 数据库引擎初始化成功")
logger.info(f"{db_type.upper()} 数据库引擎初始化成功")
return _engine
except Exception as e:
@@ -116,7 +116,7 @@ def _build_sqlite_config(config) -> tuple[str, dict]:
},
}
logger.info(f"SQLite配置: {db_path}")
logger.debug(f"SQLite配置: {db_path}")
return url, engine_kwargs
@@ -167,7 +167,7 @@ def _build_postgresql_config(config) -> tuple[str, dict]:
if connect_args:
engine_kwargs["connect_args"] = connect_args
logger.info(
logger.debug(
f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
)
return url, engine_kwargs
@@ -184,7 +184,7 @@ async def close_engine():
logger.info("正在关闭数据库引擎...")
await _engine.dispose()
_engine = None
logger.info("数据库引擎已关闭")
logger.info("数据库引擎已关闭")
async def _enable_sqlite_optimizations(engine: AsyncEngine):
@@ -214,8 +214,6 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
# Use in-memory temporary storage
await conn.execute(text("PRAGMA temp_store = MEMORY"))
logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
except Exception as e:
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
@@ -241,8 +239,6 @@ async def _enable_postgresql_optimizations(engine: AsyncEngine):
# Enable auto_explain (optional, for debugging)
# await conn.execute(text("SET auto_explain.log_min_duration = '1000'"))
logger.info("✅ PostgreSQL性能优化已启用")
except Exception as e:
logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置")

View File

@@ -31,6 +31,7 @@ from .preloader import (
DataPreloader,
close_preloader,
get_preloader,
record_preload_access,
)
from .redis_cache import RedisCache, close_redis_cache, get_redis_cache
@@ -62,5 +63,6 @@ __all__ = [
"get_cache",
"get_cache_backend_type",
"get_preloader",
"record_preload_access",
"get_redis_cache"
]

View File

@@ -13,6 +13,7 @@ from collections import defaultdict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from collections import OrderedDict
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -22,6 +23,15 @@ from src.common.logger import get_logger
logger = get_logger("preloader")
# Preload registry (for background refresh of hot data)
_preload_loader_registry: OrderedDict[str, Callable[[], Awaitable[Any]]] = OrderedDict()
_registry_lock = asyncio.Lock()
_preload_task: asyncio.Task | None = None
_preload_task_lock = asyncio.Lock()
_PRELOAD_REGISTRY_LIMIT = 1024
# Default background preload polling interval (seconds)
_DEFAULT_PRELOAD_INTERVAL = 60
@dataclass
class AccessPattern:
@@ -223,16 +233,19 @@ class DataPreloader:
async def start_preload_batch(
self,
session: AsyncSession,
loaders: dict[str, Callable[[], Awaitable[Any]]],
limit: int = 100,
) -> None:
"""批量启动预加载任务
Args:
session: 数据库会话
loaders: 数据键到加载函数的映射
limit: 参与预加载的热点键数量上限
"""
preload_keys = await self.get_preload_keys()
if not loaders:
return
preload_keys = await self.get_preload_keys(limit=limit)
for key in preload_keys:
if key in loaders:
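A caller would typically collect loaders for known-hot keys and hand them over in one batch; for example (the keys and loader callables are illustrative):

# Hypothetical call site: refresh the hottest registered keys in one pass.
loaders = {
    "User:id:1": load_user_1,  # async callables returning cacheable dicts
    "User:id:2": load_user_2,
}
async with get_db_session() as session:
    await preloader.start_preload_batch(session, loaders, limit=50)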
@@ -418,6 +431,91 @@ class CommonDataPreloader:
await self.preloader.preload_data(cache_key, loader)
# Preload background task and registry management
async def _get_preload_interval() -> float:
"""获取后台预加载轮询间隔"""
try:
from src.config.config import global_config
if global_config and getattr(global_config, "database", None):
interval = getattr(global_config.database, "preload_interval", None)
if interval:
return max(5.0, float(interval))
except Exception:
# Config may not be loaded yet or may lack this field; use the default
pass
return float(_DEFAULT_PRELOAD_INTERVAL)
async def _register_preload_loader(
cache_key: str,
loader: Callable[[], Awaitable[Any]],
) -> None:
"""注册用于热点预加载的加载函数"""
async with _registry_lock:
# move_to_end keeps most-recently-registered order, so older entries are evicted first
_preload_loader_registry[cache_key] = loader
_preload_loader_registry.move_to_end(cache_key)
# Bound the registry size to prevent unbounded growth
while len(_preload_loader_registry) > _PRELOAD_REGISTRY_LIMIT:
_preload_loader_registry.popitem(last=False)
async def _snapshot_loaders() -> dict[str, Callable[[], Awaitable[Any]]]:
"""获取当前注册的预加载loader快照"""
async with _registry_lock:
return dict(_preload_loader_registry)
async def _preload_worker() -> None:
"""后台周期性预加载任务"""
while True:
try:
interval = await _get_preload_interval()
loaders = await _snapshot_loaders()
if loaders:
preloader = await get_preloader()
await preloader.start_preload_batch(loaders)
await asyncio.sleep(interval)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"预加载后台任务异常: {e}")
# Back off so a tight retry loop does not drive up CPU usage
await asyncio.sleep(5)
async def _ensure_preload_worker() -> None:
"""确保后台预加载任务已启动"""
global _preload_task
async with _preload_task_lock:
if _preload_task is None or _preload_task.done():
_preload_task = asyncio.create_task(_preload_worker())
async def record_preload_access(
cache_key: str,
*,
related_keys: list[str] | None = None,
loader: Callable[[], Awaitable[Any]] | None = None,
) -> None:
"""记录访问并注册预加载loader
这个入口为上层APICRUD/Query提供记录访问模式、建立关联关系、
以及注册用于后续后台预加载的加载函数。
"""
preloader = await get_preloader()
await preloader.record_access(cache_key, related_keys)
if loader is not None:
await _register_preload_loader(cache_key, loader)
await _ensure_preload_worker()
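Putting the pieces together, an upper-layer read path registers its loader roughly like this (the User model and _model_to_dict helper are stand-ins for the project's own):

from typing import Any

from sqlalchemy import select

from src.common.database.core.session import get_db_session


async def get_user_cached(user_id: int) -> dict[str, Any] | None:
    cache_key = f"User:id:{user_id}"

    async def _loader() -> dict[str, Any] | None:
        async with get_db_session() as session:
            result = await session.execute(select(User).where(User.id == user_id))
            instance = result.scalar_one_or_none()
            return _model_to_dict(instance) if instance is not None else None

    # Record the access pattern and hand the loader to the background worker.
    await record_preload_access(cache_key, loader=_loader)
    return await _loader()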
# Global preloader instance
_global_preloader: DataPreloader | None = None
_preloader_lock = asyncio.Lock()
@@ -438,7 +536,22 @@ async def get_preloader() -> DataPreloader:
async def close_preloader() -> None:
"""关闭全局预加载器"""
global _global_preloader
global _preload_task
# Stop the background task
if _preload_task is not None:
_preload_task.cancel()
try:
await _preload_task
except asyncio.CancelledError:
pass
_preload_task = None
# Clear the loader registry
async with _registry_lock:
_preload_loader_registry.clear()
# Release the preloader instance
if _global_preloader is not None:
await _global_preloader.clear()
_global_preloader = None
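Since the worker is cancelled before the registry and preloader are cleared, a host application only needs a single teardown hook; for example (the shutdown hook name is illustrative):

async def on_app_shutdown() -> None:
    # Cancels the preload worker, clears the registry, and drops the preloader.
    await close_preloader()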