rufffffff

明天好像没什么
2025-11-01 21:10:01 +08:00
parent 08a9a2c2e8
commit cb97b2d8d3
50 changed files with 742 additions and 759 deletions


@@ -10,12 +10,12 @@
import asyncio
import time
from collections import defaultdict, deque
+from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, TypeVar

from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from src.common.database.core.session import get_db_session
from src.common.logger import get_logger
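
The import churn above is ruff's pyupgrade rules at work: `Optional[X]` becomes `X | None` (PEP 604), and `Callable` now comes from `collections.abc` instead of `typing`. A minimal sketch of the rewrite pattern this commit applies across the file (hypothetical function, assumes Python 3.10+):

    from collections.abc import Callable

    # before: def schedule(cb: Optional[Callable], timeout: Optional[float]) -> None
    def schedule(cb: Callable | None = None, timeout: float | None = None) -> None:
        ...
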
@@ -36,22 +36,22 @@ class Priority(IntEnum):
@dataclass
class BatchOperation:
    """A batched database operation."""

    operation_type: str  # 'select', 'insert', 'update', 'delete'
    model_class: type
    conditions: dict[str, Any] = field(default_factory=dict)
-    data: Optional[dict[str, Any]] = None
-    callback: Optional[Callable] = None
-    future: Optional[asyncio.Future] = None
+    data: dict[str, Any] | None = None
+    callback: Callable | None = None
+    future: asyncio.Future | None = None
    timestamp: float = field(default_factory=time.time)
    priority: Priority = Priority.NORMAL
-    timeout: Optional[float] = None  # timeout in seconds
+    timeout: float | None = None  # timeout in seconds


@dataclass
class BatchStats:
    """Batch-processing statistics."""

    total_operations: int = 0
    batched_operations: int = 0
    cache_hits: int = 0
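
For reference, a queued read built against this dataclass looks roughly like the following sketch, where `ChatRecord` is a hypothetical model class, not one from this commit:

    op = BatchOperation(
        operation_type="select",
        model_class=ChatRecord,             # hypothetical SQLAlchemy model
        conditions={"user_id": [1, 2, 3]},  # list values become IN (...) filters
        priority=Priority.HIGH,             # member name assumed
        timeout=2.0,                        # discard if queued longer than 2 s
    )
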
@@ -60,7 +60,7 @@ class BatchStats:
    avg_wait_time: float = 0.0
    timeout_count: int = 0
    error_count: int = 0

    # adaptive statistics
    last_batch_duration: float = 0.0
    last_batch_size: int = 0
@@ -69,7 +69,7 @@ class BatchStats:
class AdaptiveBatchScheduler:
    """Adaptive batch scheduler.

    Features:
    - dynamic batch size: adjusts automatically with load
    - priority queues: higher-priority operations execute first
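
The "priority queue" here is plain `IntEnum` ordering: collection can walk the `Priority` members from highest value to lowest and drain each deque in turn, as in this sketch (assumes it runs inside the collection method; the enum's member set is not shown in this diff):

    operations: list[BatchOperation] = []
    for priority in sorted(Priority, reverse=True):  # highest priority first
        queue = self.operation_queues[priority]
        while queue and len(operations) < self.current_batch_size:
            operations.append(queue.popleft())
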
@@ -87,7 +87,7 @@ class AdaptiveBatchScheduler:
        cache_ttl: float = 5.0,
    ):
        """Initialize the scheduler.

        Args:
            min_batch_size: minimum batch size
            max_batch_size: maximum batch size
@@ -104,23 +104,23 @@ class AdaptiveBatchScheduler:
        self.current_wait_time = base_wait_time
        self.max_queue_size = max_queue_size
        self.cache_ttl = cache_ttl

        # operation queues, bucketed by priority
        self.operation_queues: dict[Priority, deque[BatchOperation]] = {
            priority: deque() for priority in Priority
        }

        # scheduling control
-        self._scheduler_task: Optional[asyncio.Task] = None
+        self._scheduler_task: asyncio.Task | None = None
        self._is_running = False
        self._lock = asyncio.Lock()

        # statistics
        self.stats = BatchStats()

        # simple result cache
        self._result_cache: dict[str, tuple[Any, float]] = {}

        logger.info(
            f"Adaptive batch scheduler initialized: "
            f"batch size {min_batch_size}-{max_batch_size}, "
@@ -132,7 +132,7 @@ class AdaptiveBatchScheduler:
        if self._is_running:
            logger.warning("Scheduler is already running")
            return

        self._is_running = True
        self._scheduler_task = asyncio.create_task(self._scheduler_loop())
        logger.info("Batch scheduler started")
@@ -141,16 +141,16 @@ class AdaptiveBatchScheduler:
"""停止调度器"""
if not self._is_running:
return
self._is_running = False
if self._scheduler_task:
self._scheduler_task.cancel()
try:
await self._scheduler_task
except asyncio.CancelledError:
pass
# 处理剩余操作
await self._flush_all_queues()
logger.info("批量调度器已停止")
@@ -160,10 +160,10 @@ class AdaptiveBatchScheduler:
        operation: BatchOperation,
    ) -> asyncio.Future:
        """Add an operation to the queue.

        Args:
            operation: the batch operation

        Returns:
            A Future that resolves to the operation's result.
        """
@@ -175,11 +175,11 @@ class AdaptiveBatchScheduler:
            future = asyncio.get_event_loop().create_future()
            future.set_result(cached_result)
            return future

        # create the future
        future = asyncio.get_event_loop().create_future()
        operation.future = future

        async with self._lock:
            # check whether the queue is full
            total_queued = sum(len(q) for q in self.operation_queues.values())
@@ -191,7 +191,7 @@ class AdaptiveBatchScheduler:
            # append to the matching priority queue
            self.operation_queues[operation.priority].append(operation)
            self.stats.total_operations += 1

        return future

    async def _scheduler_loop(self) -> None:
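
From the caller's side the contract is two awaits: one to enqueue (or hit the result cache), one for the batch to deliver, as in this sketch (reusing the `op` from the BatchOperation example above):

    scheduler = await get_batch_scheduler()
    future = await scheduler.add_operation(op)  # queued, or resolved from cache
    rows = await future                         # completes when its batch executes
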
@@ -217,10 +217,10 @@ class AdaptiveBatchScheduler:
            for _ in range(count):
                if queue:
                    operations.append(queue.popleft())

        if not operations:
            return

        # execute the batch
        await self._execute_operations(operations)
@@ -231,10 +231,10 @@ class AdaptiveBatchScheduler:
"""执行批量操作"""
if not operations:
return
start_time = time.time()
batch_size = len(operations)
try:
# 检查超时
valid_operations = []
@@ -246,41 +246,41 @@ class AdaptiveBatchScheduler:
                    self.stats.timeout_count += 1
                else:
                    valid_operations.append(op)

            if not valid_operations:
                return

            # group by operation type
            op_groups = defaultdict(list)
            for op in valid_operations:
                key = f"{op.operation_type}_{op.model_class.__name__}"
                op_groups[key].append(op)

            # execute each group
-            for group_key, ops in op_groups.items():
+            for ops in op_groups.values():
                await self._execute_group(ops)

            # update statistics
            duration = time.time() - start_time
            self.stats.batched_operations += batch_size
            self.stats.total_execution_time += duration
            self.stats.last_batch_duration = duration
            self.stats.last_batch_size = batch_size

            if self.stats.batched_operations > 0:
                self.stats.avg_batch_size = (
-                    self.stats.batched_operations /
+                    self.stats.batched_operations /
                    (self.stats.total_execution_time / duration)
                )

            logger.debug(
                f"Batch complete: {batch_size} operations in {duration*1000:.2f}ms"
            )

        except Exception as e:
            logger.error(f"Batch execution failed: {e}", exc_info=True)
            self.stats.error_count += 1

            # propagate the exception to every pending future
            for op in operations:
                if op.future and not op.future.done():
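
Because the failure path sets the exception on every pending future, one backend error surfaces as a raised exception at each caller's await, roughly:

    try:
        rows = await future
    except Exception as exc:
        # raised here because the scheduler set the exception on the future
        handle_failure(exc)  # hypothetical handler
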
@@ -290,9 +290,9 @@ class AdaptiveBatchScheduler:
"""执行同类操作组"""
if not operations:
return
op_type = operations[0].operation_type
try:
if op_type == "select":
await self._execute_select_batch(operations)
@@ -304,7 +304,7 @@ class AdaptiveBatchScheduler:
                await self._execute_delete_batch(operations)
            else:
                raise ValueError(f"Unknown operation type: {op_type}")

        except Exception as e:
            logger.error(f"Executing {op_type} group failed: {e}", exc_info=True)
            for op in operations:
@@ -323,30 +323,30 @@ class AdaptiveBatchScheduler:
                    stmt = select(op.model_class)
                    for key, value in op.conditions.items():
                        attr = getattr(op.model_class, key)
-                        if isinstance(value, (list, tuple, set)):
+                        if isinstance(value, list | tuple | set):
                            stmt = stmt.where(attr.in_(value))
                        else:
                            stmt = stmt.where(attr == value)

                    # run the query
                    result = await session.execute(stmt)
                    data = result.scalars().all()

                    # deliver the result
                    if op.future and not op.future.done():
                        op.future.set_result(data)

                    # cache the result
                    cache_key = self._generate_cache_key(op)
                    self._set_cache(cache_key, data)

                    # invoke the callback
                    if op.callback:
                        try:
                            op.callback(data)
                        except Exception as e:
                            logger.warning(f"Callback failed: {e}")

                except Exception as e:
                    logger.error(f"Query failed: {e}", exc_info=True)
                    if op.future and not op.future.done():
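
List-valued conditions are the batching win on the read path: one IN query replaces N point lookups. Hand-written, the generated statement is equivalent to this sketch (model name hypothetical):

    stmt = select(ChatRecord).where(ChatRecord.user_id.in_([1, 2, 3]))
    result = await session.execute(stmt)
    rows = result.scalars().all()  # one round trip instead of three
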
@@ -363,23 +363,23 @@ class AdaptiveBatchScheduler:
                all_data = [op.data for op in operations if op.data]
                if not all_data:
                    return

                # bulk insert
                stmt = insert(operations[0].model_class).values(all_data)
-                result = await session.execute(stmt)
+                await session.execute(stmt)
                await session.commit()

                # deliver the results
                for op in operations:
                    if op.future and not op.future.done():
                        op.future.set_result(True)
                    if op.callback:
                        try:
                            op.callback(True)
                        except Exception as e:
                            logger.warning(f"Callback failed: {e}")

            except Exception as e:
                logger.error(f"Bulk insert failed: {e}", exc_info=True)
                await session.rollback()
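
Passing the whole list of dicts to `.values()` makes SQLAlchemy render a single multi-row INSERT, which is why one execute plus one commit covers the entire group. Sketch, model name hypothetical:

    stmt = insert(ChatRecord).values([
        {"user_id": 1, "content": "hi"},
        {"user_id": 2, "content": "hello"},
    ])  # renders: INSERT INTO ... VALUES (...), (...)
    await session.execute(stmt)
    await session.commit()
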
@@ -402,28 +402,28 @@ class AdaptiveBatchScheduler:
                    for key, value in op.conditions.items():
                        attr = getattr(op.model_class, key)
                        stmt = stmt.where(attr == value)

                    if op.data:
                        stmt = stmt.values(**op.data)

                    # execute the update without committing yet
                    result = await session.execute(stmt)
                    results.append((op, result.rowcount))

                # commit once after every operation succeeds
                await session.commit()

                # deliver each operation's result
                for op, rowcount in results:
                    if op.future and not op.future.done():
                        op.future.set_result(rowcount)
                    if op.callback:
                        try:
                            op.callback(rowcount)
                        except Exception as e:
                            logger.warning(f"Callback failed: {e}")

            except Exception as e:
                logger.error(f"Bulk update failed: {e}", exc_info=True)
                await session.rollback()
@@ -447,25 +447,25 @@ class AdaptiveBatchScheduler:
                    for key, value in op.conditions.items():
                        attr = getattr(op.model_class, key)
                        stmt = stmt.where(attr == value)

                    # execute the delete without committing yet
                    result = await session.execute(stmt)
                    results.append((op, result.rowcount))

                # commit once after every operation succeeds
                await session.commit()

                # deliver each operation's result
                for op, rowcount in results:
                    if op.future and not op.future.done():
                        op.future.set_result(rowcount)
                    if op.callback:
                        try:
                            op.callback(rowcount)
                        except Exception as e:
                            logger.warning(f"Callback failed: {e}")

            except Exception as e:
                logger.error(f"Bulk delete failed: {e}", exc_info=True)
                await session.rollback()
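
The update and delete batches share one transactional shape: execute every statement, commit once, and roll back the lot on any failure, so a half-applied batch can never leak out. Distilled:

    try:
        for stmt in statements:       # no intermediate commits
            await session.execute(stmt)
        await session.commit()        # all-or-nothing
    except Exception:
        await session.rollback()      # one failure undoes the whole batch
        raise
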
@@ -479,7 +479,7 @@ class AdaptiveBatchScheduler:
        # compute the congestion score
        total_queued = sum(len(q) for q in self.operation_queues.values())
        self.stats.congestion_score = min(1.0, total_queued / self.max_queue_size)

        # adjust batch size based on congestion
        if self.stats.congestion_score > 0.7:
            # high congestion: grow the batch size
@@ -493,7 +493,7 @@ class AdaptiveBatchScheduler:
                self.min_batch_size,
                int(self.current_batch_size * 0.9),
            )

        # adjust wait time based on how long batches take
        if self.stats.last_batch_duration > 0:
            if self.stats.last_batch_duration > self.current_wait_time * 2:
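
Worked through with concrete numbers (the 1.2 growth factor is an assumption; only the 0.9 shrink appears in this hunk): with max_queue_size=1000 and 800 operations queued, the congestion score is min(1.0, 800/1000) = 0.8, which clears the 0.7 threshold, so a current batch size of 50 would grow to int(50 * 1.2) = 60, capped at max_batch_size; in the quiet branch it shrinks to max(min_batch_size, int(50 * 0.9)) = 45.
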
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
        ]
        return "|".join(key_parts)

-    def _get_from_cache(self, cache_key: str) -> Optional[Any]:
+    def _get_from_cache(self, cache_key: str) -> Any | None:
        """Fetch a result from the cache."""
        if cache_key in self._result_cache:
            result, timestamp = self._result_cache[cache_key]
@@ -551,27 +551,27 @@ class AdaptiveBatchScheduler:
# global scheduler instance
-_global_scheduler: Optional[AdaptiveBatchScheduler] = None
+_global_scheduler: AdaptiveBatchScheduler | None = None
_scheduler_lock = asyncio.Lock()


async def get_batch_scheduler() -> AdaptiveBatchScheduler:
    """Return the global batch scheduler (singleton)."""
    global _global_scheduler

    if _global_scheduler is None:
        async with _scheduler_lock:
            if _global_scheduler is None:
                _global_scheduler = AdaptiveBatchScheduler()
                await _global_scheduler.start()

    return _global_scheduler


async def close_batch_scheduler() -> None:
    """Shut down the global batch scheduler."""
    global _global_scheduler

    if _global_scheduler is not None:
        await _global_scheduler.stop()
        _global_scheduler = None
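
End to end, typical lifecycle usage of the module looks like this sketch, reusing the `op` built in the BatchOperation example above:

    async def main() -> None:
        scheduler = await get_batch_scheduler()  # lazily created and started
        future = await scheduler.add_operation(op)
        result = await future
        await close_batch_scheduler()            # cancels the loop, flushes queues

    asyncio.run(main())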