rufffffff

This commit is contained in:
明天好像没什么
2025-11-01 21:10:01 +08:00
parent 08a9a2c2e8
commit cb97b2d8d3
50 changed files with 742 additions and 759 deletions

View File

@@ -10,8 +10,9 @@
import asyncio
import time
from collections import defaultdict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Optional
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -25,7 +26,7 @@ logger = get_logger("preloader")
@dataclass
class AccessPattern:
"""访问模式统计
Attributes:
key: 数据键
access_count: 访问次数
@@ -42,7 +43,7 @@ class AccessPattern:
class DataPreloader:
"""数据预加载器
通过分析访问模式,预测并预加载可能需要的数据
"""
@@ -53,7 +54,7 @@ class DataPreloader:
max_patterns: int = 1000,
):
"""初始化预加载器
Args:
decay_factor: 时间衰减因子0-1越小衰减越快
preload_threshold: 预加载阈值score超过此值时预加载
@@ -62,7 +63,7 @@ class DataPreloader:
self.decay_factor = decay_factor
self.preload_threshold = preload_threshold
self.max_patterns = max_patterns
# 访问模式跟踪
self._patterns: dict[str, AccessPattern] = {}
# 关联关系key -> [related_keys]
@@ -73,9 +74,9 @@ class DataPreloader:
self._total_accesses = 0
self._preload_count = 0
self._preload_hits = 0
self._lock = asyncio.Lock()
logger.info(
f"数据预加载器初始化: 衰减因子={decay_factor}, "
f"预加载阈值={preload_threshold}"
@@ -84,10 +85,10 @@ class DataPreloader:
async def record_access(
self,
key: str,
related_keys: Optional[list[str]] = None,
related_keys: list[str] | None = None,
) -> None:
"""记录数据访问
Args:
key: 被访问的数据键
related_keys: 关联访问的数据键列表
@@ -95,7 +96,7 @@ class DataPreloader:
async with self._lock:
self._total_accesses += 1
now = time.time()
# 更新或创建访问模式
if key in self._patterns:
pattern = self._patterns[key]
@@ -108,15 +109,15 @@ class DataPreloader:
last_access=now,
)
self._patterns[key] = pattern
# 更新热度评分(时间衰减)
pattern.score = self._calculate_score(pattern)
# 记录关联关系
if related_keys:
self._associations[key].update(related_keys)
pattern.related_keys = list(self._associations[key])
# 如果模式过多,删除评分最低的
if len(self._patterns) > self.max_patterns:
min_key = min(self._patterns, key=lambda k: self._patterns[k].score)
@@ -126,10 +127,10 @@ class DataPreloader:
async def should_preload(self, key: str) -> bool:
"""判断是否应该预加载某个数据
Args:
key: 数据键
Returns:
是否应该预加载
"""
@@ -137,18 +138,18 @@ class DataPreloader:
pattern = self._patterns.get(key)
if pattern is None:
return False
# 更新评分
pattern.score = self._calculate_score(pattern)
return pattern.score >= self.preload_threshold
async def get_preload_keys(self, limit: int = 100) -> list[str]:
"""获取应该预加载的数据键列表
Args:
limit: 最大返回数量
Returns:
按评分排序的数据键列表
"""
@@ -156,14 +157,14 @@ class DataPreloader:
# 更新所有评分
for pattern in self._patterns.values():
pattern.score = self._calculate_score(pattern)
# 按评分排序
sorted_patterns = sorted(
self._patterns.values(),
key=lambda p: p.score,
reverse=True,
)
# 返回超过阈值的键
return [
p.key for p in sorted_patterns[:limit]
@@ -172,10 +173,10 @@ class DataPreloader:
async def get_related_keys(self, key: str) -> list[str]:
"""获取关联数据键
Args:
key: 数据键
Returns:
关联数据键列表
"""
@@ -188,27 +189,27 @@ class DataPreloader:
loader: Callable[[], Awaitable[Any]],
) -> None:
"""预加载数据
Args:
key: 数据键
loader: 异步加载函数
"""
try:
cache = await get_cache()
# 检查缓存中是否已存在
if await cache.l1_cache.get(key) is not None:
return
# 加载数据
logger.debug(f"预加载数据: {key}")
data = await loader()
if data is not None:
# 写入缓存
await cache.set(key, data)
self._preload_count += 1
# 预加载关联数据
related_keys = await self.get_related_keys(key)
for related_key in related_keys[:5]: # 最多预加载5个关联项
@@ -216,7 +217,7 @@ class DataPreloader:
# 这里需要调用者提供关联数据的加载函数
# 暂时只记录,不实际加载
logger.debug(f"发现关联数据: {related_key}")
except Exception as e:
logger.error(f"预加载数据失败 {key}: {e}", exc_info=True)
@@ -226,13 +227,13 @@ class DataPreloader:
loaders: dict[str, Callable[[], Awaitable[Any]]],
) -> None:
"""批量启动预加载任务
Args:
session: 数据库会话
loaders: 数据键到加载函数的映射
"""
preload_keys = await self.get_preload_keys()
for key in preload_keys:
if key in loaders:
loader = loaders[key]
@@ -242,9 +243,9 @@ class DataPreloader:
async def record_hit(self, key: str) -> None:
"""记录预加载命中
当缓存命中的数据是预加载的,调用此方法统计
Args:
key: 数据键
"""
@@ -259,7 +260,7 @@ class DataPreloader:
if self._preload_count > 0
else 0.0
)
return {
"total_accesses": self._total_accesses,
"tracked_patterns": len(self._patterns),
@@ -278,7 +279,7 @@ class DataPreloader:
self._total_accesses = 0
self._preload_count = 0
self._preload_hits = 0
# 取消所有预加载任务
for task in self._preload_tasks:
task.cancel()
@@ -286,38 +287,38 @@ class DataPreloader:
def _calculate_score(self, pattern: AccessPattern) -> float:
    """Compute the time-decayed heat score of an access pattern.

    The score is the access frequency damped by how long ago the key
    was last touched:

        score = access_count * decay_factor ** hours_since_last_access

    Args:
        pattern: Access pattern to score; its ``access_count`` and
            ``last_access`` (a ``time.time()`` timestamp) are read.

    Returns:
        The heat score. It never exceeds ``access_count`` as long as
        ``decay_factor`` is within 0-1.
    """
    # time.time() is wall-clock time and may step backwards (e.g. after a
    # clock adjustment). Clamp the elapsed time to zero so a last_access
    # that appears to be "in the future" neither inflates the score above
    # access_count nor raises ZeroDivisionError when decay_factor is 0.
    elapsed_seconds = max(0.0, time.time() - pattern.last_access)

    # Decay is applied per hour elapsed since the last access.
    hours_passed = elapsed_seconds / 3600

    return pattern.access_count * self.decay_factor ** hours_passed
class CommonDataPreloader:
"""常见数据预加载器
针对特定的数据类型提供预加载策略
"""
def __init__(self, preloader: DataPreloader):
"""初始化
Args:
preloader: 基础预加载器
"""
@@ -330,16 +331,16 @@ class CommonDataPreloader:
platform: str,
) -> None:
"""预加载用户相关数据
包括:个人信息、权限、关系等
Args:
session: 数据库会话
user_id: 用户ID
platform: 平台
"""
from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships
# 预加载个人信息
await self._preload_model(
session,
@@ -347,7 +348,7 @@ class CommonDataPreloader:
PersonInfo,
{"platform": platform, "user_id": user_id},
)
# 预加载用户权限
await self._preload_model(
session,
@@ -355,7 +356,7 @@ class CommonDataPreloader:
UserPermissions,
{"platform": platform, "user_id": user_id},
)
# 预加载用户关系
await self._preload_model(
session,
@@ -371,16 +372,16 @@ class CommonDataPreloader:
limit: int = 50,
) -> None:
"""预加载聊天上下文
包括:最近消息、聊天流信息等
Args:
session: 数据库会话
stream_id: 聊天流ID
limit: 消息数量限制
"""
from src.common.database.core.models import ChatStreams, Messages
from src.common.database.core.models import ChatStreams
# 预加载聊天流信息
await self._preload_model(
session,
@@ -388,7 +389,7 @@ class CommonDataPreloader:
ChatStreams,
{"stream_id": stream_id},
)
# 预加载最近消息(这个比较复杂,暂时跳过)
# TODO: 实现消息列表的预加载
@@ -400,7 +401,7 @@ class CommonDataPreloader:
filters: dict[str, Any],
) -> None:
"""预加载模型数据
Args:
session: 数据库会话
cache_key: 缓存键
@@ -413,31 +414,31 @@ class CommonDataPreloader:
stmt = stmt.where(getattr(model_class, key) == value)
result = await session.execute(stmt)
return result.scalar_one_or_none()
await self.preloader.preload_data(cache_key, loader)
# 全局预加载器实例
_global_preloader: Optional[DataPreloader] = None
_global_preloader: DataPreloader | None = None
_preloader_lock = asyncio.Lock()
async def get_preloader() -> DataPreloader:
    """Return the module-level DataPreloader, creating it on first use."""
    global _global_preloader

    # Fast path: an instance already exists, so the lock is not needed.
    if _global_preloader is not None:
        return _global_preloader

    async with _preloader_lock:
        # Re-check while holding the lock: another coroutine may have
        # created the instance while this one was waiting to acquire it.
        if _global_preloader is None:
            _global_preloader = DataPreloader()

    return _global_preloader
async def close_preloader() -> None:
    """Clear the module-level preloader, if one exists, and drop the reference."""
    global _global_preloader

    # Nothing to close when no instance has been created.
    if _global_preloader is None:
        return

    await _global_preloader.clear()
    # Reset so that the next get_preloader() call builds a fresh instance.
    _global_preloader = None