rufffffff

This commit is contained in:
明天好像没什么
2025-11-01 21:10:01 +08:00
committed by Windpicker-owo
parent 05daf869d1
commit ff6dc542e1
50 changed files with 742 additions and 759 deletions

View File

@@ -11,17 +11,17 @@ from .batch_scheduler import (
AdaptiveBatchScheduler,
BatchOperation,
BatchStats,
Priority,
close_batch_scheduler,
get_batch_scheduler,
Priority,
)
from .cache_manager import (
CacheEntry,
CacheStats,
close_cache,
get_cache,
LRUCache,
MultiLevelCache,
close_cache,
get_cache,
)
from .connection_pool import (
ConnectionPoolManager,
@@ -31,36 +31,36 @@ from .connection_pool import (
)
from .preloader import (
AccessPattern,
close_preloader,
CommonDataPreloader,
DataPreloader,
close_preloader,
get_preloader,
)
__all__ = [
# Connection Pool
"ConnectionPoolManager",
"get_connection_pool_manager",
"start_connection_pool",
"stop_connection_pool",
# Cache
"MultiLevelCache",
"LRUCache",
"CacheEntry",
"CacheStats",
"get_cache",
"close_cache",
# Preloader
"DataPreloader",
"CommonDataPreloader",
"AccessPattern",
"get_preloader",
"close_preloader",
# Batch Scheduler
"AdaptiveBatchScheduler",
"BatchOperation",
"BatchStats",
"CacheEntry",
"CacheStats",
"CommonDataPreloader",
# Connection Pool
"ConnectionPoolManager",
# Preloader
"DataPreloader",
"LRUCache",
# Cache
"MultiLevelCache",
"Priority",
"get_batch_scheduler",
"close_batch_scheduler",
"close_cache",
"close_preloader",
"get_batch_scheduler",
"get_cache",
"get_connection_pool_manager",
"get_preloader",
"start_connection_pool",
"stop_connection_pool",
]

View File

@@ -10,12 +10,12 @@
import asyncio
import time
from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Callable, Optional, TypeVar
from typing import Any, TypeVar
from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.core.session import get_db_session
from src.common.logger import get_logger
@@ -36,22 +36,22 @@ class Priority(IntEnum):
@dataclass
class BatchOperation:
"""批量操作"""
operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: type
conditions: dict[str, Any] = field(default_factory=dict)
data: Optional[dict[str, Any]] = None
callback: Optional[Callable] = None
future: Optional[asyncio.Future] = None
data: dict[str, Any] | None = None
callback: Callable | None = None
future: asyncio.Future | None = None
timestamp: float = field(default_factory=time.time)
priority: Priority = Priority.NORMAL
timeout: Optional[float] = None # 超时时间(秒)
timeout: float | None = None # 超时时间(秒)
@dataclass
class BatchStats:
"""批处理统计"""
total_operations: int = 0
batched_operations: int = 0
cache_hits: int = 0
@@ -60,7 +60,7 @@ class BatchStats:
avg_wait_time: float = 0.0
timeout_count: int = 0
error_count: int = 0
# 自适应统计
last_batch_duration: float = 0.0
last_batch_size: int = 0
@@ -69,7 +69,7 @@ class BatchStats:
class AdaptiveBatchScheduler:
"""自适应批量调度器
特性:
- 动态批次大小:根据负载自动调整
- 优先级队列:高优先级操作优先执行
@@ -87,7 +87,7 @@ class AdaptiveBatchScheduler:
cache_ttl: float = 5.0,
):
"""初始化调度器
Args:
min_batch_size: 最小批次大小
max_batch_size: 最大批次大小
@@ -104,23 +104,23 @@ class AdaptiveBatchScheduler:
self.current_wait_time = base_wait_time
self.max_queue_size = max_queue_size
self.cache_ttl = cache_ttl
# 操作队列,按优先级分类
self.operation_queues: dict[Priority, deque[BatchOperation]] = {
priority: deque() for priority in Priority
}
# 调度控制
self._scheduler_task: Optional[asyncio.Task] = None
self._scheduler_task: asyncio.Task | None = None
self._is_running = False
self._lock = asyncio.Lock()
# 统计信息
self.stats = BatchStats()
# 简单的结果缓存
self._result_cache: dict[str, tuple[Any, float]] = {}
logger.info(
f"自适应批量调度器初始化: "
f"批次大小{min_batch_size}-{max_batch_size}, "
@@ -132,7 +132,7 @@ class AdaptiveBatchScheduler:
if self._is_running:
logger.warning("调度器已在运行")
return
self._is_running = True
self._scheduler_task = asyncio.create_task(self._scheduler_loop())
logger.info("批量调度器已启动")
@@ -141,16 +141,16 @@ class AdaptiveBatchScheduler:
"""停止调度器"""
if not self._is_running:
return
self._is_running = False
if self._scheduler_task:
self._scheduler_task.cancel()
try:
await self._scheduler_task
except asyncio.CancelledError:
pass
# 处理剩余操作
await self._flush_all_queues()
logger.info("批量调度器已停止")
@@ -160,10 +160,10 @@ class AdaptiveBatchScheduler:
operation: BatchOperation,
) -> asyncio.Future:
"""添加操作到队列
Args:
operation: 批量操作
Returns:
Future对象可用于获取结果
"""
@@ -175,11 +175,11 @@ class AdaptiveBatchScheduler:
future = asyncio.get_event_loop().create_future()
future.set_result(cached_result)
return future
# 创建future
future = asyncio.get_event_loop().create_future()
operation.future = future
async with self._lock:
# 检查队列是否已满
total_queued = sum(len(q) for q in self.operation_queues.values())
@@ -191,7 +191,7 @@ class AdaptiveBatchScheduler:
# 添加到优先级队列
self.operation_queues[operation.priority].append(operation)
self.stats.total_operations += 1
return future
async def _scheduler_loop(self) -> None:
@@ -217,10 +217,10 @@ class AdaptiveBatchScheduler:
for _ in range(count):
if queue:
operations.append(queue.popleft())
if not operations:
return
# 执行批量操作
await self._execute_operations(operations)
@@ -231,10 +231,10 @@ class AdaptiveBatchScheduler:
"""执行批量操作"""
if not operations:
return
start_time = time.time()
batch_size = len(operations)
try:
# 检查超时
valid_operations = []
@@ -246,41 +246,41 @@ class AdaptiveBatchScheduler:
self.stats.timeout_count += 1
else:
valid_operations.append(op)
if not valid_operations:
return
# 按操作类型分组
op_groups = defaultdict(list)
for op in valid_operations:
key = f"{op.operation_type}_{op.model_class.__name__}"
op_groups[key].append(op)
# 执行各组操作
for group_key, ops in op_groups.items():
for ops in op_groups.values():
await self._execute_group(ops)
# 更新统计
duration = time.time() - start_time
self.stats.batched_operations += batch_size
self.stats.total_execution_time += duration
self.stats.last_batch_duration = duration
self.stats.last_batch_size = batch_size
if self.stats.batched_operations > 0:
self.stats.avg_batch_size = (
self.stats.batched_operations /
self.stats.batched_operations /
(self.stats.total_execution_time / duration)
)
logger.debug(
f"批量执行完成: {batch_size}个操作, 耗时{duration*1000:.2f}ms"
)
except Exception as e:
logger.error(f"批量操作执行失败: {e}", exc_info=True)
self.stats.error_count += 1
# 设置所有future的异常
for op in operations:
if op.future and not op.future.done():
@@ -290,9 +290,9 @@ class AdaptiveBatchScheduler:
"""执行同类操作组"""
if not operations:
return
op_type = operations[0].operation_type
try:
if op_type == "select":
await self._execute_select_batch(operations)
@@ -304,7 +304,7 @@ class AdaptiveBatchScheduler:
await self._execute_delete_batch(operations)
else:
raise ValueError(f"未知操作类型: {op_type}")
except Exception as e:
logger.error(f"执行{op_type}操作组失败: {e}", exc_info=True)
for op in operations:
@@ -323,30 +323,30 @@ class AdaptiveBatchScheduler:
stmt = select(op.model_class)
for key, value in op.conditions.items():
attr = getattr(op.model_class, key)
if isinstance(value, (list, tuple, set)):
if isinstance(value, list | tuple | set):
stmt = stmt.where(attr.in_(value))
else:
stmt = stmt.where(attr == value)
# 执行查询
result = await session.execute(stmt)
data = result.scalars().all()
# 设置结果
if op.future and not op.future.done():
op.future.set_result(data)
# 缓存结果
cache_key = self._generate_cache_key(op)
self._set_cache(cache_key, data)
# 执行回调
if op.callback:
try:
op.callback(data)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"查询失败: {e}", exc_info=True)
if op.future and not op.future.done():
@@ -363,23 +363,23 @@ class AdaptiveBatchScheduler:
all_data = [op.data for op in operations if op.data]
if not all_data:
return
# 批量插入
stmt = insert(operations[0].model_class).values(all_data)
result = await session.execute(stmt)
await session.execute(stmt)
await session.commit()
# 设置结果
for op in operations:
if op.future and not op.future.done():
op.future.set_result(True)
if op.callback:
try:
op.callback(True)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"批量插入失败: {e}", exc_info=True)
await session.rollback()
@@ -402,28 +402,28 @@ class AdaptiveBatchScheduler:
for key, value in op.conditions.items():
attr = getattr(op.model_class, key)
stmt = stmt.where(attr == value)
if op.data:
stmt = stmt.values(**op.data)
# 执行更新但不commit
result = await session.execute(stmt)
results.append((op, result.rowcount))
# 所有操作成功后一次性commit
await session.commit()
# 设置所有操作的结果
for op, rowcount in results:
if op.future and not op.future.done():
op.future.set_result(rowcount)
if op.callback:
try:
op.callback(rowcount)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"批量更新失败: {e}", exc_info=True)
await session.rollback()
@@ -447,25 +447,25 @@ class AdaptiveBatchScheduler:
for key, value in op.conditions.items():
attr = getattr(op.model_class, key)
stmt = stmt.where(attr == value)
# 执行删除但不commit
result = await session.execute(stmt)
results.append((op, result.rowcount))
# 所有操作成功后一次性commit
await session.commit()
# 设置所有操作的结果
for op, rowcount in results:
if op.future and not op.future.done():
op.future.set_result(rowcount)
if op.callback:
try:
op.callback(rowcount)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"批量删除失败: {e}", exc_info=True)
await session.rollback()
@@ -479,7 +479,7 @@ class AdaptiveBatchScheduler:
# 计算拥塞评分
total_queued = sum(len(q) for q in self.operation_queues.values())
self.stats.congestion_score = min(1.0, total_queued / self.max_queue_size)
# 根据拥塞情况调整批次大小
if self.stats.congestion_score > 0.7:
# 高拥塞,增加批次大小
@@ -493,7 +493,7 @@ class AdaptiveBatchScheduler:
self.min_batch_size,
int(self.current_batch_size * 0.9),
)
# 根据批次执行时间调整等待时间
if self.stats.last_batch_duration > 0:
if self.stats.last_batch_duration > self.current_wait_time * 2:
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
]
return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
def _get_from_cache(self, cache_key: str) -> Any | None:
"""从缓存获取结果"""
if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key]
@@ -551,27 +551,27 @@ class AdaptiveBatchScheduler:
# 全局调度器实例
_global_scheduler: Optional[AdaptiveBatchScheduler] = None
_global_scheduler: AdaptiveBatchScheduler | None = None
_scheduler_lock = asyncio.Lock()
async def get_batch_scheduler() -> AdaptiveBatchScheduler:
"""获取全局批量调度器(单例)"""
global _global_scheduler
if _global_scheduler is None:
async with _scheduler_lock:
if _global_scheduler is None:
_global_scheduler = AdaptiveBatchScheduler()
await _global_scheduler.start()
return _global_scheduler
async def close_batch_scheduler() -> None:
"""关闭全局批量调度器"""
global _global_scheduler
if _global_scheduler is not None:
await _global_scheduler.stop()
_global_scheduler = None

View File

@@ -11,8 +11,9 @@
import asyncio
import time
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar
from typing import Any, Generic, TypeVar
from src.common.logger import get_logger
@@ -24,7 +25,7 @@ T = TypeVar("T")
@dataclass
class CacheEntry(Generic[T]):
"""缓存条目
Attributes:
value: 缓存的值
created_at: 创建时间戳
@@ -42,7 +43,7 @@ class CacheEntry(Generic[T]):
@dataclass
class CacheStats:
"""缓存统计信息
Attributes:
hits: 命中次数
misses: 未命中次数
@@ -70,7 +71,7 @@ class CacheStats:
class LRUCache(Generic[T]):
"""LRU缓存实现
使用OrderedDict实现O(1)的get/set操作
"""
@@ -81,7 +82,7 @@ class LRUCache(Generic[T]):
name: str = "cache",
):
"""初始化LRU缓存
Args:
max_size: 最大缓存条目数
ttl: 过期时间(秒)
@@ -94,18 +95,18 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock()
self._stats = CacheStats()
async def get(self, key: str) -> Optional[T]:
async def get(self, key: str) -> T | None:
"""获取缓存值
Args:
key: 缓存键
Returns:
缓存值如果不存在或已过期返回None
"""
async with self._lock:
entry = self._cache.get(key)
if entry is None:
self._stats.misses += 1
return None
@@ -125,20 +126,20 @@ class LRUCache(Generic[T]):
entry.last_accessed = now
entry.access_count += 1
self._stats.hits += 1
# 移到末尾(最近使用)
self._cache.move_to_end(key)
return entry.value
async def set(
self,
key: str,
value: T,
size: Optional[int] = None,
size: int | None = None,
) -> None:
"""设置缓存值
Args:
key: 缓存键
value: 缓存值
@@ -146,16 +147,16 @@ class LRUCache(Generic[T]):
"""
async with self._lock:
now = time.time()
# 如果键已存在,更新值
if key in self._cache:
old_entry = self._cache[key]
self._stats.total_size -= old_entry.size
# 估算大小
if size is None:
size = self._estimate_size(value)
# 创建新条目
entry = CacheEntry(
value=value,
@@ -164,7 +165,7 @@ class LRUCache(Generic[T]):
access_count=0,
size=size,
)
# 如果缓存已满,淘汰最久未使用的条目
while len(self._cache) >= self.max_size:
oldest_key, oldest_entry = self._cache.popitem(last=False)
@@ -175,7 +176,7 @@ class LRUCache(Generic[T]):
f"[{self.name}] 淘汰缓存条目: {oldest_key} "
f"(访问{oldest_entry.access_count}次)"
)
# 添加新条目
self._cache[key] = entry
self._stats.item_count += 1
@@ -183,10 +184,10 @@ class LRUCache(Generic[T]):
async def delete(self, key: str) -> bool:
"""删除缓存条目
Args:
key: 缓存键
Returns:
是否成功删除
"""
@@ -217,7 +218,7 @@ class LRUCache(Generic[T]):
def _estimate_size(self, value: Any) -> int:
"""估算数据大小(字节)
这是一个简单的估算,实际大小可能不同
"""
import sys
@@ -230,11 +231,11 @@ class LRUCache(Generic[T]):
class MultiLevelCache:
"""多级缓存管理器
实现两级缓存架构:
- L1: 高速缓存小容量短TTL
- L2: 扩展缓存大容量长TTL
查询时先查L1未命中再查L2未命中再从数据源加载
"""
@@ -246,7 +247,7 @@ class MultiLevelCache:
l2_ttl: float = 300,
):
"""初始化多级缓存
Args:
l1_max_size: L1缓存最大条目数
l1_ttl: L1缓存TTL
@@ -255,8 +256,8 @@ class MultiLevelCache:
"""
self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
self._cleanup_task: Optional[asyncio.Task] = None
self._cleanup_task: asyncio.Task | None = None
logger.info(
f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) "
f"L2({l2_max_size}项/{l2_ttl}s)"
@@ -265,16 +266,16 @@ class MultiLevelCache:
async def get(
self,
key: str,
loader: Optional[Callable[[], Any]] = None,
) -> Optional[Any]:
loader: Callable[[], Any] | None = None,
) -> Any | None:
"""从缓存获取数据
查询顺序L1 -> L2 -> loader
Args:
key: 缓存键
loader: 数据加载函数,当缓存未命中时调用
Returns:
缓存值或加载的值如果都不存在返回None
"""
@@ -307,12 +308,12 @@ class MultiLevelCache:
self,
key: str,
value: Any,
size: Optional[int] = None,
size: int | None = None,
) -> None:
"""设置缓存值
同时写入L1和L2
Args:
key: 缓存键
value: 缓存值
@@ -323,9 +324,9 @@ class MultiLevelCache:
async def delete(self, key: str) -> None:
"""删除缓存条目
同时从L1和L2删除
Args:
key: 缓存键
"""
@@ -347,7 +348,7 @@ class MultiLevelCache:
async def start_cleanup_task(self, interval: float = 60) -> None:
"""启动定期清理任务
Args:
interval: 清理间隔(秒)
"""
@@ -387,27 +388,27 @@ class MultiLevelCache:
# 全局缓存实例
_global_cache: Optional[MultiLevelCache] = None
_global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock()
async def get_cache() -> MultiLevelCache:
"""获取全局缓存实例(单例)"""
global _global_cache
if _global_cache is None:
async with _cache_lock:
if _global_cache is None:
_global_cache = MultiLevelCache()
await _global_cache.start_cleanup_task()
return _global_cache
async def close_cache() -> None:
"""关闭全局缓存"""
global _global_cache
if _global_cache is not None:
await _global_cache.stop_cleanup_task()
await _global_cache.clear()

View File

@@ -150,7 +150,7 @@ class ConnectionPoolManager:
logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")
yield connection_info.session
# 🔧 修复:正常退出时提交事务
# 这对SQLite至关重要因为SQLite没有autocommit
if connection_info and connection_info.session:
@@ -249,7 +249,7 @@ class ConnectionPoolManager:
"""获取连接池统计信息"""
total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
return {
**self._stats,
"active_connections": len(self._connections),

View File

@@ -10,8 +10,9 @@
import asyncio
import time
from collections import defaultdict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Optional
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -25,7 +26,7 @@ logger = get_logger("preloader")
@dataclass
class AccessPattern:
"""访问模式统计
Attributes:
key: 数据键
access_count: 访问次数
@@ -42,7 +43,7 @@ class AccessPattern:
class DataPreloader:
"""数据预加载器
通过分析访问模式,预测并预加载可能需要的数据
"""
@@ -53,7 +54,7 @@ class DataPreloader:
max_patterns: int = 1000,
):
"""初始化预加载器
Args:
decay_factor: 时间衰减因子0-1越小衰减越快
preload_threshold: 预加载阈值score超过此值时预加载
@@ -62,7 +63,7 @@ class DataPreloader:
self.decay_factor = decay_factor
self.preload_threshold = preload_threshold
self.max_patterns = max_patterns
# 访问模式跟踪
self._patterns: dict[str, AccessPattern] = {}
# 关联关系key -> [related_keys]
@@ -73,9 +74,9 @@ class DataPreloader:
self._total_accesses = 0
self._preload_count = 0
self._preload_hits = 0
self._lock = asyncio.Lock()
logger.info(
f"数据预加载器初始化: 衰减因子={decay_factor}, "
f"预加载阈值={preload_threshold}"
@@ -84,10 +85,10 @@ class DataPreloader:
async def record_access(
self,
key: str,
related_keys: Optional[list[str]] = None,
related_keys: list[str] | None = None,
) -> None:
"""记录数据访问
Args:
key: 被访问的数据键
related_keys: 关联访问的数据键列表
@@ -95,7 +96,7 @@ class DataPreloader:
async with self._lock:
self._total_accesses += 1
now = time.time()
# 更新或创建访问模式
if key in self._patterns:
pattern = self._patterns[key]
@@ -108,15 +109,15 @@ class DataPreloader:
last_access=now,
)
self._patterns[key] = pattern
# 更新热度评分(时间衰减)
pattern.score = self._calculate_score(pattern)
# 记录关联关系
if related_keys:
self._associations[key].update(related_keys)
pattern.related_keys = list(self._associations[key])
# 如果模式过多,删除评分最低的
if len(self._patterns) > self.max_patterns:
min_key = min(self._patterns, key=lambda k: self._patterns[k].score)
@@ -126,10 +127,10 @@ class DataPreloader:
async def should_preload(self, key: str) -> bool:
"""判断是否应该预加载某个数据
Args:
key: 数据键
Returns:
是否应该预加载
"""
@@ -137,18 +138,18 @@ class DataPreloader:
pattern = self._patterns.get(key)
if pattern is None:
return False
# 更新评分
pattern.score = self._calculate_score(pattern)
return pattern.score >= self.preload_threshold
async def get_preload_keys(self, limit: int = 100) -> list[str]:
"""获取应该预加载的数据键列表
Args:
limit: 最大返回数量
Returns:
按评分排序的数据键列表
"""
@@ -156,14 +157,14 @@ class DataPreloader:
# 更新所有评分
for pattern in self._patterns.values():
pattern.score = self._calculate_score(pattern)
# 按评分排序
sorted_patterns = sorted(
self._patterns.values(),
key=lambda p: p.score,
reverse=True,
)
# 返回超过阈值的键
return [
p.key for p in sorted_patterns[:limit]
@@ -172,10 +173,10 @@ class DataPreloader:
async def get_related_keys(self, key: str) -> list[str]:
"""获取关联数据键
Args:
key: 数据键
Returns:
关联数据键列表
"""
@@ -188,27 +189,27 @@ class DataPreloader:
loader: Callable[[], Awaitable[Any]],
) -> None:
"""预加载数据
Args:
key: 数据键
loader: 异步加载函数
"""
try:
cache = await get_cache()
# 检查缓存中是否已存在
if await cache.l1_cache.get(key) is not None:
return
# 加载数据
logger.debug(f"预加载数据: {key}")
data = await loader()
if data is not None:
# 写入缓存
await cache.set(key, data)
self._preload_count += 1
# 预加载关联数据
related_keys = await self.get_related_keys(key)
for related_key in related_keys[:5]: # 最多预加载5个关联项
@@ -216,7 +217,7 @@ class DataPreloader:
# 这里需要调用者提供关联数据的加载函数
# 暂时只记录,不实际加载
logger.debug(f"发现关联数据: {related_key}")
except Exception as e:
logger.error(f"预加载数据失败 {key}: {e}", exc_info=True)
@@ -226,13 +227,13 @@ class DataPreloader:
loaders: dict[str, Callable[[], Awaitable[Any]]],
) -> None:
"""批量启动预加载任务
Args:
session: 数据库会话
loaders: 数据键到加载函数的映射
"""
preload_keys = await self.get_preload_keys()
for key in preload_keys:
if key in loaders:
loader = loaders[key]
@@ -242,9 +243,9 @@ class DataPreloader:
async def record_hit(self, key: str) -> None:
"""记录预加载命中
当缓存命中的数据是预加载的,调用此方法统计
Args:
key: 数据键
"""
@@ -259,7 +260,7 @@ class DataPreloader:
if self._preload_count > 0
else 0.0
)
return {
"total_accesses": self._total_accesses,
"tracked_patterns": len(self._patterns),
@@ -278,7 +279,7 @@ class DataPreloader:
self._total_accesses = 0
self._preload_count = 0
self._preload_hits = 0
# 取消所有预加载任务
for task in self._preload_tasks:
task.cancel()
@@ -286,38 +287,38 @@ class DataPreloader:
def _calculate_score(self, pattern: AccessPattern) -> float:
"""计算热度评分
使用时间衰减的访问频率:
score = access_count * decay_factor^(time_since_last_access)
Args:
pattern: 访问模式
Returns:
热度评分
"""
now = time.time()
time_diff = now - pattern.last_access
# 时间衰减(以小时为单位)
hours_passed = time_diff / 3600
decay = self.decay_factor ** hours_passed
# 评分 = 访问次数 * 时间衰减
score = pattern.access_count * decay
return score
class CommonDataPreloader:
"""常见数据预加载器
针对特定的数据类型提供预加载策略
"""
def __init__(self, preloader: DataPreloader):
"""初始化
Args:
preloader: 基础预加载器
"""
@@ -330,16 +331,16 @@ class CommonDataPreloader:
platform: str,
) -> None:
"""预加载用户相关数据
包括:个人信息、权限、关系等
Args:
session: 数据库会话
user_id: 用户ID
platform: 平台
"""
from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships
# 预加载个人信息
await self._preload_model(
session,
@@ -347,7 +348,7 @@ class CommonDataPreloader:
PersonInfo,
{"platform": platform, "user_id": user_id},
)
# 预加载用户权限
await self._preload_model(
session,
@@ -355,7 +356,7 @@ class CommonDataPreloader:
UserPermissions,
{"platform": platform, "user_id": user_id},
)
# 预加载用户关系
await self._preload_model(
session,
@@ -371,16 +372,16 @@ class CommonDataPreloader:
limit: int = 50,
) -> None:
"""预加载聊天上下文
包括:最近消息、聊天流信息等
Args:
session: 数据库会话
stream_id: 聊天流ID
limit: 消息数量限制
"""
from src.common.database.core.models import ChatStreams, Messages
from src.common.database.core.models import ChatStreams
# 预加载聊天流信息
await self._preload_model(
session,
@@ -388,7 +389,7 @@ class CommonDataPreloader:
ChatStreams,
{"stream_id": stream_id},
)
# 预加载最近消息(这个比较复杂,暂时跳过)
# TODO: 实现消息列表的预加载
@@ -400,7 +401,7 @@ class CommonDataPreloader:
filters: dict[str, Any],
) -> None:
"""预加载模型数据
Args:
session: 数据库会话
cache_key: 缓存键
@@ -413,31 +414,31 @@ class CommonDataPreloader:
stmt = stmt.where(getattr(model_class, key) == value)
result = await session.execute(stmt)
return result.scalar_one_or_none()
await self.preloader.preload_data(cache_key, loader)
# 全局预加载器实例
_global_preloader: Optional[DataPreloader] = None
_global_preloader: DataPreloader | None = None
_preloader_lock = asyncio.Lock()
async def get_preloader() -> DataPreloader:
"""获取全局预加载器实例(单例)"""
global _global_preloader
if _global_preloader is None:
async with _preloader_lock:
if _global_preloader is None:
_global_preloader = DataPreloader()
return _global_preloader
async def close_preloader() -> None:
"""关闭全局预加载器"""
global _global_preloader
if _global_preloader is not None:
await _global_preloader.clear()
_global_preloader = None