558 lines
16 KiB
Python
558 lines
16 KiB
Python
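"""Smart data preloader (智能数据预加载器).

Preloading strategy implemented in this module:

- Hot-data detection: keys are scored by access count with exponential time decay.
- Related-data tracking: keys accessed together are recorded as associations;
  during preloading, related keys are currently only logged, not loaded.
- Hit-rate statistics: preload hits are counted and reported via ``get_stats``;
  the preload threshold is not (yet) adjusted automatically.
- Asynchronous preloading: batch preloads run as background asyncio tasks so
  callers are not blocked.
"""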
|
||
|
||
import asyncio
|
||
import time
|
||
from collections import OrderedDict, defaultdict
|
||
from collections.abc import Awaitable, Callable
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from src.common.database.optimization.cache_manager import get_cache
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("preloader")

# Default polling interval (seconds) of the background preload worker.
_DEFAULT_PRELOAD_INTERVAL = 60

# Maximum number of registered loaders; the oldest registrations are dropped first.
_PRELOAD_REGISTRY_LIMIT = 1024

# cache_key -> async loader used by the background worker to refresh hot data.
# Kept as an OrderedDict so the most recently registered entries sit at the end.
_preload_loader_registry: OrderedDict[str, Callable[[], Awaitable[Any]]] = OrderedDict()
_registry_lock = asyncio.Lock()

# Handle of the single background preload worker, plus the lock guarding its startup.
_preload_task: asyncio.Task | None = None
_preload_task_lock = asyncio.Lock()
|
||
|
||
|
||
@dataclass
class AccessPattern:
    """Access statistics tracked for a single cache key.

    Attributes:
        key: The cache key being tracked.
        access_count: Number of recorded accesses to ``key``.
        last_access: ``time.time()`` timestamp of the latest recorded access (0 until set).
        score: Hotness score, i.e. the access count after time decay.
        related_keys: Keys recorded as accessed together with ``key``.
    """

    key: str
    access_count: int = field(default=0)
    last_access: float = field(default=0)
    score: float = field(default=0)
    related_keys: list[str] = field(default_factory=list)
|
||
|
||
|
||
class DataPreloader:
    """Access-pattern driven data preloader.

    Records which cache keys are accessed and how recently, scores each key by
    a time-decayed access count, and warms the cache for keys whose score
    reaches ``preload_threshold`` using caller-supplied async loaders.
    """

    def __init__(
        self,
        decay_factor: float = 0.9,
        preload_threshold: float = 0.5,
        max_patterns: int = 1000,
    ):
        """Initialize the preloader.

        Args:
            decay_factor: Per-hour decay base (documented range 0-1); smaller
                values make scores fall off faster.
            preload_threshold: Minimum score for a key to be preloaded.
            max_patterns: Maximum number of tracked access patterns; beyond
                that the coldest patterns are evicted.
        """
        self.decay_factor = decay_factor
        self.preload_threshold = preload_threshold
        self.max_patterns = max_patterns

        # key -> access statistics
        self._patterns: dict[str, AccessPattern] = {}
        # key -> keys recorded as accessed together with it
        self._associations: dict[str, set[str]] = defaultdict(set)
        # preload tasks spawned by start_preload_batch (each removes itself when done)
        self._preload_tasks: set[asyncio.Task] = set()

        # counters reported by get_stats()
        self._total_accesses = 0
        self._preload_count = 0
        self._preload_hits = 0

        self._lock = asyncio.Lock()

        logger.info(
            f"数据预加载器初始化: 衰减因子={decay_factor}, "
            f"预加载阈值={preload_threshold}"
        )

    async def record_access(
        self,
        key: str,
        related_keys: list[str] | None = None,
    ) -> None:
        """Record one access to ``key``.

        Args:
            key: The accessed cache key.
            related_keys: Keys accessed together with ``key``; they are merged
                into its association set.
        """
        async with self._lock:
            self._total_accesses += 1
            now = time.time()

            pattern = self._patterns.get(key)
            if pattern is None:
                pattern = AccessPattern(key=key, access_count=1, last_access=now)
                self._patterns[key] = pattern
            else:
                pattern.access_count += 1
                pattern.last_access = now

            pattern.score = self._calculate_score(pattern)

            if related_keys:
                self._associations[key].update(related_keys)
                pattern.related_keys = list(self._associations[key])

            # Keep the table bounded, but never drop the key we just recorded
            # while colder candidates exist.
            self._evict_excess(protect=key)

    async def should_preload(self, key: str) -> bool:
        """Return whether ``key`` should be preloaded.

        Args:
            key: Cache key.

        Returns:
            True if ``key`` is tracked and its refreshed score reaches
            ``preload_threshold``; False otherwise.
        """
        async with self._lock:
            pattern = self._patterns.get(key)
            if pattern is None:
                return False

            pattern.score = self._calculate_score(pattern)
            return pattern.score >= self.preload_threshold

    async def get_preload_keys(self, limit: int = 100) -> list[str]:
        """Return the hottest keys that qualify for preloading.

        Args:
            limit: Maximum number of keys to return; values <= 0 yield ``[]``.

        Returns:
            Keys ordered by descending refreshed score, restricted to those
            whose score reaches ``preload_threshold``.
        """
        async with self._lock:
            for pattern in self._patterns.values():
                pattern.score = self._calculate_score(pattern)

            ranked = sorted(
                self._patterns.values(),
                key=lambda p: p.score,
                reverse=True,
            )

            # Clamp: a negative limit would otherwise slice off the tail of
            # the ranking instead of returning nothing.
            return [
                p.key
                for p in ranked[: max(limit, 0)]
                if p.score >= self.preload_threshold
            ]

    async def get_related_keys(self, key: str) -> list[str]:
        """Return the keys recorded as related to ``key`` (empty if none).

        Args:
            key: Cache key.

        Returns:
            List of associated keys.
        """
        async with self._lock:
            return list(self._associations.get(key, []))

    async def preload_data(
        self,
        key: str,
        loader: Callable[[], Awaitable[Any]],
    ) -> None:
        """Load ``key`` with ``loader`` and write it to the cache.

        The load is skipped when the cache's L1 layer already holds a non-None
        value for ``key``. A ``None`` result from the loader is not cached.
        Any ``Exception`` is logged and swallowed.

        Args:
            key: Cache key to populate.
            loader: Zero-argument async function producing the value.
        """
        try:
            cache = await get_cache()

            # Already cached: nothing to do.
            if await cache.l1_cache.get(key) is not None:
                return

            logger.debug(f"预加载数据: {key}")
            data = await loader()

            if data is not None:
                await cache.set(key, data)
                # No await between read and write, so this cannot interleave
                # with other coroutines on the same event loop.
                self._preload_count += 1

                # Related keys are only reported for now: no loader is known
                # for them here.
                related_keys = await self.get_related_keys(key)
                for related_key in related_keys[:5]:  # inspect at most 5 related keys
                    if await cache.l1_cache.get(related_key) is None:
                        logger.debug(f"发现关联数据: {related_key}")

        except Exception as e:
            logger.error(f"预加载数据失败 {key}: {e}")

    async def start_preload_batch(
        self,
        loaders: dict[str, Callable[[], Awaitable[Any]]],
        limit: int = 100,
    ) -> None:
        """Spawn background preload tasks for hot keys that have a loader.

        Args:
            loaders: Mapping of cache key to async loader.
            limit: Maximum number of hot keys considered.
        """
        if not loaders:
            return

        preload_keys = await self.get_preload_keys(limit=limit)

        for key in preload_keys:
            if key in loaders:
                loader = loaders[key]
                task = asyncio.create_task(self.preload_data(key, loader))
                # Hold a strong reference until the task finishes.
                self._preload_tasks.add(task)
                task.add_done_callback(self._preload_tasks.discard)

    async def record_hit(self, key: str) -> None:
        """Count one cache hit on preloaded data.

        Args:
            key: Cache key that was hit (currently not used in the count).
        """
        async with self._lock:
            self._preload_hits += 1

    async def get_stats(self) -> dict[str, Any]:
        """Return counters and sizes describing the preloader's state."""
        async with self._lock:
            preload_hit_rate = (
                self._preload_hits / self._preload_count
                if self._preload_count > 0
                else 0.0
            )

            return {
                "total_accesses": self._total_accesses,
                "tracked_patterns": len(self._patterns),
                "associations": len(self._associations),
                "preload_count": self._preload_count,
                "preload_hits": self._preload_hits,
                "preload_hit_rate": preload_hit_rate,
                "active_tasks": len(self._preload_tasks),
            }

    async def clear(self) -> None:
        """Reset all statistics and cancel outstanding preload tasks."""
        async with self._lock:
            self._patterns.clear()
            self._associations.clear()
            self._total_accesses = 0
            self._preload_count = 0
            self._preload_hits = 0

            for task in self._preload_tasks:
                task.cancel()
            self._preload_tasks.clear()

    def _evict_excess(self, protect: str) -> None:
        """Trim ``_patterns`` to ``max_patterns`` by dropping the coldest entries.

        Scores are refreshed first, so eviction reflects current decay rather
        than the score stored at each key's last access. ``protect`` (the key
        just recorded) is evicted only when no other entry remains to drop.
        """
        if len(self._patterns) <= self.max_patterns:
            return

        for pattern in self._patterns.values():
            pattern.score = self._calculate_score(pattern)

        while self._patterns and len(self._patterns) > self.max_patterns:
            candidates = [k for k in self._patterns if k != protect] or list(self._patterns)
            victim = min(candidates, key=lambda k: self._patterns[k].score)
            del self._patterns[victim]
            self._associations.pop(victim, None)

    def _calculate_score(self, pattern: AccessPattern) -> float:
        """Return the time-decayed hotness score of ``pattern``.

        score = access_count * decay_factor ** hours_since_last_access

        Elapsed time is clamped at zero. ``time.time()`` is wall-clock time and
        can step backwards; a negative exponent would inflate the score, or
        raise ZeroDivisionError when ``decay_factor`` is 0.

        Args:
            pattern: Access statistics to score.

        Returns:
            The hotness score.
        """
        elapsed = max(0.0, time.time() - pattern.last_access)
        hours_passed = elapsed / 3600
        return pattern.access_count * self.decay_factor ** hours_passed
|
||
|
||
|
||
class CommonDataPreloader:
    """Preloading recipes for commonly used data types.

    Wraps a DataPreloader and knows which model rows and cache keys to warm
    for a user or a chat stream.
    """

    def __init__(self, preloader: DataPreloader):
        """Wrap ``preloader``.

        Args:
            preloader: The DataPreloader that performs the actual cache writes.
        """
        self.preloader = preloader

    async def preload_user_data(
        self,
        session: AsyncSession,
        user_id: str,
        platform: str,
    ) -> None:
        """Warm the cache for one user: person info, permissions and relationships.

        Args:
            session: Database session used for the lookups.
            user_id: User identifier.
            platform: Platform the user belongs to.
        """
        from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships

        # (cache key, model, column filters), preloaded one after another on the same session
        targets = (
            (f"person:{platform}:{user_id}", PersonInfo, {"platform": platform, "user_id": user_id}),
            (f"permissions:{platform}:{user_id}", UserPermissions, {"platform": platform, "user_id": user_id}),
            (f"relationship:{user_id}", UserRelationships, {"user_id": user_id}),
        )
        for cache_key, model, filters in targets:
            await self._preload_model(session, cache_key, model, filters)

    async def preload_chat_context(
        self,
        session: AsyncSession,
        stream_id: str,
        limit: int = 50,
    ) -> None:
        """Warm the cache for a chat stream's context.

        Only the chat stream row is preloaded at present; ``limit`` is
        accepted for the future message preload but is currently unused.

        Args:
            session: Database session used for the lookup.
            stream_id: Chat stream identifier.
            limit: Intended maximum number of recent messages.
        """
        from src.common.database.core.models import ChatStreams

        await self._preload_model(session, f"stream:{stream_id}", ChatStreams, {"stream_id": stream_id})

        # TODO: preload the recent message list (not implemented yet).

    async def _preload_model(
        self,
        session: AsyncSession,
        cache_key: str,
        model_class: type,
        filters: dict[str, Any],
    ) -> None:
        """Preload the row of ``model_class`` matching ``filters`` under ``cache_key``.

        The query yields at most one row (``scalar_one_or_none``); errors raised
        while querying are handled by ``DataPreloader.preload_data``.

        Args:
            session: Database session used for the query.
            cache_key: Cache key to populate.
            model_class: ORM model class to query.
            filters: Column name -> required value (combined with AND).
        """

        async def fetch_one():
            # Built inside the loader so attribute/query errors surface within
            # preload_data's error handling.
            conditions = [getattr(model_class, column) == value for column, value in filters.items()]
            result = await session.execute(select(model_class).where(*conditions))
            return result.scalar_one_or_none()

        await self.preloader.preload_data(cache_key, fetch_one)
|
||
|
||
|
||
# 预加载后台任务与注册表管理
|
||
async def _get_preload_interval() -> float:
    """Return the background worker's polling interval in seconds.

    Uses ``global_config.database.preload_interval`` when it is set and
    truthy (with a 5-second floor); otherwise falls back to
    ``_DEFAULT_PRELOAD_INTERVAL``.
    """
    try:
        from src.config.config import global_config

        database_cfg = getattr(global_config, "database", None) if global_config else None
        configured = getattr(database_cfg, "preload_interval", None) if database_cfg else None
        if configured:
            return max(5.0, float(configured))
    except Exception:
        # Config may be unavailable, not loaded yet, or lack/invalidate this field.
        pass
    return float(_DEFAULT_PRELOAD_INTERVAL)
|
||
|
||
|
||
async def _register_preload_loader(
    cache_key: str,
    loader: Callable[[], Awaitable[Any]],
) -> None:
    """Register ``loader`` as the background preload function for ``cache_key``.

    The registry is ordered by registration (most recent last) and is capped at
    ``_PRELOAD_REGISTRY_LIMIT`` entries by dropping the oldest ones.
    """
    async with _registry_lock:
        registry = _preload_loader_registry
        registry[cache_key] = loader
        # Re-registering an existing key also moves it to the "newest" end.
        registry.move_to_end(cache_key)

        excess = len(registry) - _PRELOAD_REGISTRY_LIMIT
        for _ in range(excess):
            registry.popitem(last=False)
|
||
|
||
|
||
async def _snapshot_loaders() -> dict[str, Callable[[], Awaitable[Any]]]:
    """Return a plain-dict copy of the loader registry, taken under its lock.

    The copy lets the worker iterate without holding the lock and without
    being affected by concurrent registrations.
    """
    snapshot: dict[str, Callable[[], Awaitable[Any]]] = {}
    async with _registry_lock:
        snapshot.update(_preload_loader_registry)
    return snapshot
|
||
|
||
|
||
async def _preload_worker() -> None:
    """Background loop that periodically schedules preloads for registered keys.

    Each iteration reads the polling interval and snapshots the loader
    registry. If the snapshot is non-empty, it is handed to the global
    preloader's ``start_preload_batch``; the loop then sleeps for the interval.
    Ordinary exceptions are logged and retried after 5 seconds. Cancellation
    is re-raised, so the task ends in the cancelled state.
    """
    while True:
        try:
            interval = await _get_preload_interval()
            loaders = await _snapshot_loaders()

            if loaders:
                preloader = await get_preloader()
                await preloader.start_preload_batch(loaders)

            await asyncio.sleep(interval)
        except asyncio.CancelledError:
            # Propagate instead of swallowing (asyncio guidance). This keeps
            # Task.cancelled() truthful and lets timeouts/TaskGroups work.
            raise
        except Exception as e:
            logger.error(f"预加载后台任务异常: {e}")
            # Back off briefly so a persistent failure does not spin the loop.
            await asyncio.sleep(5)
|
||
|
||
|
||
async def _ensure_preload_worker() -> None:
    """Start the background preload worker unless one is already running."""
    global _preload_task

    async with _preload_task_lock:
        alive = _preload_task is not None and not _preload_task.done()
        if not alive:
            _preload_task = asyncio.create_task(_preload_worker())
|
||
|
||
|
||
async def record_preload_access(
    cache_key: str,
    *,
    related_keys: list[str] | None = None,
    loader: Callable[[], Awaitable[Any]] | None = None,
) -> None:
    """Record an access and optionally register a loader for background preloading.

    Entry point for higher-level APIs (CRUD/Query). It updates the access
    pattern and associations for ``cache_key``. When ``loader`` is given, it
    also registers the loader for the periodic worker and ensures that worker
    is running.
    """
    tracker = await get_preloader()
    await tracker.record_access(cache_key, related_keys)

    if loader is None:
        return

    await _register_preload_loader(cache_key, loader)
    await _ensure_preload_worker()
|
||
|
||
|
||
# Lock serializing creation of the module-level DataPreloader singleton.
_preloader_lock = asyncio.Lock()
# The singleton itself; None until get_preloader() creates it, reset by close_preloader().
_global_preloader: DataPreloader | None = None
|
||
|
||
|
||
async def get_preloader() -> DataPreloader:
    """Return the module-level DataPreloader, creating it on first use."""
    global _global_preloader

    if _global_preloader is not None:
        return _global_preloader

    async with _preloader_lock:
        # Re-check: another coroutine may have created it while we waited.
        if _global_preloader is None:
            _global_preloader = DataPreloader()

    return _global_preloader
|
||
|
||
|
||
async def close_preloader() -> None:
    """Shut down the global preloader.

    Cancels the background worker and waits for it to finish. Then it clears
    the loader registry, clears and drops the global DataPreloader instance,
    and logs the shutdown.
    """
    global _global_preloader
    global _preload_task

    task = _preload_task
    if task is not None:
        task.cancel()
        # asyncio.wait neither raises the task's outcome nor forwards a
        # cancellation of *this* coroutine into the task. A bare
        # `await task` inside `except CancelledError: pass` would also
        # swallow a cancellation aimed at our caller.
        await asyncio.wait({task})
        if not task.cancelled() and task.exception() is not None:
            logger.error(f"预加载后台任务异常退出: {task.exception()}")
        # Only forget the handle if no new worker was started meanwhile.
        if _preload_task is task:
            _preload_task = None

    async with _registry_lock:
        _preload_loader_registry.clear()

    if _global_preloader is not None:
        await _global_preloader.clear()
        _global_preloader = None
        logger.info("全局预加载器已关闭")
|