rufffffff
This commit is contained in:
@@ -10,8 +10,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -25,7 +26,7 @@ logger = get_logger("preloader")
|
||||
@dataclass
|
||||
class AccessPattern:
|
||||
"""访问模式统计
|
||||
|
||||
|
||||
Attributes:
|
||||
key: 数据键
|
||||
access_count: 访问次数
|
||||
@@ -42,7 +43,7 @@ class AccessPattern:
|
||||
|
||||
class DataPreloader:
|
||||
"""数据预加载器
|
||||
|
||||
|
||||
通过分析访问模式,预测并预加载可能需要的数据
|
||||
"""
|
||||
|
||||
@@ -53,7 +54,7 @@ class DataPreloader:
|
||||
max_patterns: int = 1000,
|
||||
):
|
||||
"""初始化预加载器
|
||||
|
||||
|
||||
Args:
|
||||
decay_factor: 时间衰减因子(0-1),越小衰减越快
|
||||
preload_threshold: 预加载阈值,score超过此值时预加载
|
||||
@@ -62,7 +63,7 @@ class DataPreloader:
|
||||
self.decay_factor = decay_factor
|
||||
self.preload_threshold = preload_threshold
|
||||
self.max_patterns = max_patterns
|
||||
|
||||
|
||||
# 访问模式跟踪
|
||||
self._patterns: dict[str, AccessPattern] = {}
|
||||
# 关联关系:key -> [related_keys]
|
||||
@@ -73,9 +74,9 @@ class DataPreloader:
|
||||
self._total_accesses = 0
|
||||
self._preload_count = 0
|
||||
self._preload_hits = 0
|
||||
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
logger.info(
|
||||
f"数据预加载器初始化: 衰减因子={decay_factor}, "
|
||||
f"预加载阈值={preload_threshold}"
|
||||
@@ -84,10 +85,10 @@ class DataPreloader:
|
||||
async def record_access(
|
||||
self,
|
||||
key: str,
|
||||
related_keys: Optional[list[str]] = None,
|
||||
related_keys: list[str] | None = None,
|
||||
) -> None:
|
||||
"""记录数据访问
|
||||
|
||||
|
||||
Args:
|
||||
key: 被访问的数据键
|
||||
related_keys: 关联访问的数据键列表
|
||||
@@ -95,7 +96,7 @@ class DataPreloader:
|
||||
async with self._lock:
|
||||
self._total_accesses += 1
|
||||
now = time.time()
|
||||
|
||||
|
||||
# 更新或创建访问模式
|
||||
if key in self._patterns:
|
||||
pattern = self._patterns[key]
|
||||
@@ -108,15 +109,15 @@ class DataPreloader:
|
||||
last_access=now,
|
||||
)
|
||||
self._patterns[key] = pattern
|
||||
|
||||
|
||||
# 更新热度评分(时间衰减)
|
||||
pattern.score = self._calculate_score(pattern)
|
||||
|
||||
|
||||
# 记录关联关系
|
||||
if related_keys:
|
||||
self._associations[key].update(related_keys)
|
||||
pattern.related_keys = list(self._associations[key])
|
||||
|
||||
|
||||
# 如果模式过多,删除评分最低的
|
||||
if len(self._patterns) > self.max_patterns:
|
||||
min_key = min(self._patterns, key=lambda k: self._patterns[k].score)
|
||||
@@ -126,10 +127,10 @@ class DataPreloader:
|
||||
|
||||
async def should_preload(self, key: str) -> bool:
|
||||
"""判断是否应该预加载某个数据
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
|
||||
|
||||
Returns:
|
||||
是否应该预加载
|
||||
"""
|
||||
@@ -137,18 +138,18 @@ class DataPreloader:
|
||||
pattern = self._patterns.get(key)
|
||||
if pattern is None:
|
||||
return False
|
||||
|
||||
|
||||
# 更新评分
|
||||
pattern.score = self._calculate_score(pattern)
|
||||
|
||||
|
||||
return pattern.score >= self.preload_threshold
|
||||
|
||||
async def get_preload_keys(self, limit: int = 100) -> list[str]:
|
||||
"""获取应该预加载的数据键列表
|
||||
|
||||
|
||||
Args:
|
||||
limit: 最大返回数量
|
||||
|
||||
|
||||
Returns:
|
||||
按评分排序的数据键列表
|
||||
"""
|
||||
@@ -156,14 +157,14 @@ class DataPreloader:
|
||||
# 更新所有评分
|
||||
for pattern in self._patterns.values():
|
||||
pattern.score = self._calculate_score(pattern)
|
||||
|
||||
|
||||
# 按评分排序
|
||||
sorted_patterns = sorted(
|
||||
self._patterns.values(),
|
||||
key=lambda p: p.score,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
# 返回超过阈值的键
|
||||
return [
|
||||
p.key for p in sorted_patterns[:limit]
|
||||
@@ -172,10 +173,10 @@ class DataPreloader:
|
||||
|
||||
async def get_related_keys(self, key: str) -> list[str]:
|
||||
"""获取关联数据键
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
|
||||
|
||||
Returns:
|
||||
关联数据键列表
|
||||
"""
|
||||
@@ -188,27 +189,27 @@ class DataPreloader:
|
||||
loader: Callable[[], Awaitable[Any]],
|
||||
) -> None:
|
||||
"""预加载数据
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
loader: 异步加载函数
|
||||
"""
|
||||
try:
|
||||
cache = await get_cache()
|
||||
|
||||
|
||||
# 检查缓存中是否已存在
|
||||
if await cache.l1_cache.get(key) is not None:
|
||||
return
|
||||
|
||||
|
||||
# 加载数据
|
||||
logger.debug(f"预加载数据: {key}")
|
||||
data = await loader()
|
||||
|
||||
|
||||
if data is not None:
|
||||
# 写入缓存
|
||||
await cache.set(key, data)
|
||||
self._preload_count += 1
|
||||
|
||||
|
||||
# 预加载关联数据
|
||||
related_keys = await self.get_related_keys(key)
|
||||
for related_key in related_keys[:5]: # 最多预加载5个关联项
|
||||
@@ -216,7 +217,7 @@ class DataPreloader:
|
||||
# 这里需要调用者提供关联数据的加载函数
|
||||
# 暂时只记录,不实际加载
|
||||
logger.debug(f"发现关联数据: {related_key}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预加载数据失败 {key}: {e}", exc_info=True)
|
||||
|
||||
@@ -226,13 +227,13 @@ class DataPreloader:
|
||||
loaders: dict[str, Callable[[], Awaitable[Any]]],
|
||||
) -> None:
|
||||
"""批量启动预加载任务
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
loaders: 数据键到加载函数的映射
|
||||
"""
|
||||
preload_keys = await self.get_preload_keys()
|
||||
|
||||
|
||||
for key in preload_keys:
|
||||
if key in loaders:
|
||||
loader = loaders[key]
|
||||
@@ -242,9 +243,9 @@ class DataPreloader:
|
||||
|
||||
async def record_hit(self, key: str) -> None:
|
||||
"""记录预加载命中
|
||||
|
||||
|
||||
当缓存命中的数据是预加载的,调用此方法统计
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
"""
|
||||
@@ -259,7 +260,7 @@ class DataPreloader:
|
||||
if self._preload_count > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"total_accesses": self._total_accesses,
|
||||
"tracked_patterns": len(self._patterns),
|
||||
@@ -278,7 +279,7 @@ class DataPreloader:
|
||||
self._total_accesses = 0
|
||||
self._preload_count = 0
|
||||
self._preload_hits = 0
|
||||
|
||||
|
||||
# 取消所有预加载任务
|
||||
for task in self._preload_tasks:
|
||||
task.cancel()
|
||||
@@ -286,38 +287,38 @@ class DataPreloader:
|
||||
|
||||
def _calculate_score(self, pattern: AccessPattern) -> float:
|
||||
"""计算热度评分
|
||||
|
||||
|
||||
使用时间衰减的访问频率:
|
||||
score = access_count * decay_factor^(time_since_last_access)
|
||||
|
||||
|
||||
Args:
|
||||
pattern: 访问模式
|
||||
|
||||
|
||||
Returns:
|
||||
热度评分
|
||||
"""
|
||||
now = time.time()
|
||||
time_diff = now - pattern.last_access
|
||||
|
||||
|
||||
# 时间衰减(以小时为单位)
|
||||
hours_passed = time_diff / 3600
|
||||
decay = self.decay_factor ** hours_passed
|
||||
|
||||
|
||||
# 评分 = 访问次数 * 时间衰减
|
||||
score = pattern.access_count * decay
|
||||
|
||||
|
||||
return score
|
||||
|
||||
|
||||
class CommonDataPreloader:
|
||||
"""常见数据预加载器
|
||||
|
||||
|
||||
针对特定的数据类型提供预加载策略
|
||||
"""
|
||||
|
||||
def __init__(self, preloader: DataPreloader):
|
||||
"""初始化
|
||||
|
||||
|
||||
Args:
|
||||
preloader: 基础预加载器
|
||||
"""
|
||||
@@ -330,16 +331,16 @@ class CommonDataPreloader:
|
||||
platform: str,
|
||||
) -> None:
|
||||
"""预加载用户相关数据
|
||||
|
||||
|
||||
包括:个人信息、权限、关系等
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
user_id: 用户ID
|
||||
platform: 平台
|
||||
"""
|
||||
from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships
|
||||
|
||||
|
||||
# 预加载个人信息
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -347,7 +348,7 @@ class CommonDataPreloader:
|
||||
PersonInfo,
|
||||
{"platform": platform, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
# 预加载用户权限
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -355,7 +356,7 @@ class CommonDataPreloader:
|
||||
UserPermissions,
|
||||
{"platform": platform, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
# 预加载用户关系
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -371,16 +372,16 @@ class CommonDataPreloader:
|
||||
limit: int = 50,
|
||||
) -> None:
|
||||
"""预加载聊天上下文
|
||||
|
||||
|
||||
包括:最近消息、聊天流信息等
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
stream_id: 聊天流ID
|
||||
limit: 消息数量限制
|
||||
"""
|
||||
from src.common.database.core.models import ChatStreams, Messages
|
||||
|
||||
from src.common.database.core.models import ChatStreams
|
||||
|
||||
# 预加载聊天流信息
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -388,7 +389,7 @@ class CommonDataPreloader:
|
||||
ChatStreams,
|
||||
{"stream_id": stream_id},
|
||||
)
|
||||
|
||||
|
||||
# 预加载最近消息(这个比较复杂,暂时跳过)
|
||||
# TODO: 实现消息列表的预加载
|
||||
|
||||
@@ -400,7 +401,7 @@ class CommonDataPreloader:
|
||||
filters: dict[str, Any],
|
||||
) -> None:
|
||||
"""预加载模型数据
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
cache_key: 缓存键
|
||||
@@ -413,31 +414,31 @@ class CommonDataPreloader:
|
||||
stmt = stmt.where(getattr(model_class, key) == value)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
await self.preloader.preload_data(cache_key, loader)
|
||||
|
||||
|
||||
# 全局预加载器实例
|
||||
_global_preloader: Optional[DataPreloader] = None
|
||||
_global_preloader: DataPreloader | None = None
|
||||
_preloader_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_preloader() -> DataPreloader:
|
||||
"""获取全局预加载器实例(单例)"""
|
||||
global _global_preloader
|
||||
|
||||
|
||||
if _global_preloader is None:
|
||||
async with _preloader_lock:
|
||||
if _global_preloader is None:
|
||||
_global_preloader = DataPreloader()
|
||||
|
||||
|
||||
return _global_preloader
|
||||
|
||||
|
||||
async def close_preloader() -> None:
|
||||
"""关闭全局预加载器"""
|
||||
global _global_preloader
|
||||
|
||||
|
||||
if _global_preloader is not None:
|
||||
await _global_preloader.clear()
|
||||
_global_preloader = None
|
||||
|
||||
Reference in New Issue
Block a user