""" 数据库批量调度器 实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争 """ import asyncio import time from collections import defaultdict, deque from collections.abc import Callable from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, TypeVar from sqlalchemy import delete, insert, select, update from src.common.database.sqlalchemy_database_api import get_db_session from src.common.logger import get_logger logger = get_logger("db_batch_scheduler") T = TypeVar("T") @dataclass class BatchOperation: """批量操作基础类""" operation_type: str # 'select', 'insert', 'update', 'delete' model_class: Any conditions: dict[str, Any] data: dict[str, Any] | None = None callback: Callable | None = None future: asyncio.Future | None = None timestamp: float = 0.0 def __post_init__(self): if self.timestamp == 0.0: self.timestamp = time.time() @dataclass class BatchResult: """批量操作结果""" success: bool data: Any = None error: str | None = None class DatabaseBatchScheduler: """数据库批量调度器""" def __init__( self, batch_size: int = 50, max_wait_time: float = 0.1, # 100ms max_queue_size: int = 1000, ): self.batch_size = batch_size self.max_wait_time = max_wait_time self.max_queue_size = max_queue_size # 操作队列,按操作类型和模型分类 self.operation_queues: dict[str, deque] = defaultdict(deque) # 调度控制 self._scheduler_task: asyncio.Task | None = None self._is_running = False self._lock = asyncio.Lock() # 统计信息 self.stats = {"total_operations": 0, "batched_operations": 0, "cache_hits": 0, "execution_time": 0.0} # 简单的结果缓存(用于频繁的查询) self._result_cache: dict[str, tuple[Any, float]] = {} self._cache_ttl = 5.0 # 5秒缓存 async def start(self): """启动调度器""" if self._is_running: return self._is_running = True self._scheduler_task = asyncio.create_task(self._scheduler_loop()) logger.info("数据库批量调度器已启动") async def stop(self): """停止调度器""" if not self._is_running: return self._is_running = False if self._scheduler_task: self._scheduler_task.cancel() try: await self._scheduler_task except asyncio.CancelledError: pass # 处理剩余的操作 await self._flush_all_queues() logger.info("数据库批量调度器已停止") def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str: """生成缓存键""" # 简单的缓存键生成,实际可以根据需要优化 key_parts = [operation_type, model_class.__name__, str(sorted(conditions.items()))] return "|".join(key_parts) def _get_from_cache(self, cache_key: str) -> Any | None: """从缓存获取结果""" if cache_key in self._result_cache: result, timestamp = self._result_cache[cache_key] if time.time() - timestamp < self._cache_ttl: self.stats["cache_hits"] += 1 return result else: # 清理过期缓存 del self._result_cache[cache_key] return None def _set_cache(self, cache_key: str, result: Any): """设置缓存""" self._result_cache[cache_key] = (result, time.time()) async def add_operation(self, operation: BatchOperation) -> asyncio.Future: """添加操作到队列""" # 检查是否可以立即返回缓存结果 if operation.operation_type == "select": cache_key = self._generate_cache_key(operation.operation_type, operation.model_class, operation.conditions) cached_result = self._get_from_cache(cache_key) if cached_result is not None: if operation.callback: operation.callback(cached_result) future = asyncio.get_event_loop().create_future() future.set_result(cached_result) return future # 创建future用于返回结果 future = asyncio.get_event_loop().create_future() operation.future = future # 添加到队列 queue_key = f"{operation.operation_type}_{operation.model_class.__name__}" async with self._lock: if len(self.operation_queues[queue_key]) >= self.max_queue_size: # 队列满了,直接执行 await self._execute_operations([operation]) else: 
                self.operation_queues[queue_key].append(operation)

        self.stats["total_operations"] += 1
        return future

    async def _scheduler_loop(self):
        """Main loop: flush all queues every max_wait_time seconds."""
        while self._is_running:
            try:
                await asyncio.sleep(self.max_wait_time)
                await self._flush_all_queues()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Scheduler loop error: {e}", exc_info=True)

    async def _flush_all_queues(self):
        """Drain all queues and execute the pending operations."""
        async with self._lock:
            if not any(self.operation_queues.values()):
                return

            # Copy the queue contents so the lock is not held during execution
            queues_copy = {key: deque(operations) for key, operations in self.operation_queues.items()}

            # Clear the original queues
            for queue in self.operation_queues.values():
                queue.clear()

        # Execute each queue's operations outside the lock
        for operations in queues_copy.values():
            if operations:
                await self._execute_operations(list(operations))

    async def _execute_operations(self, operations: list[BatchOperation]):
        """Execute a batch of operations grouped by operation type."""
        if not operations:
            return

        start_time = time.time()

        try:
            # Group by operation type
            op_groups: dict[str, list[BatchOperation]] = defaultdict(list)
            for op in operations:
                op_groups[op.operation_type].append(op)

            # Create one batched task per operation type; keep the groups so
            # results can be mapped back to the operations that produced them
            tasks = []
            grouped_ops: list[list[BatchOperation]] = []
            for op_type, ops in op_groups.items():
                if op_type == "select":
                    tasks.append(self._execute_select_batch(ops))
                elif op_type == "insert":
                    tasks.append(self._execute_insert_batch(ops))
                elif op_type == "update":
                    tasks.append(self._execute_update_batch(ops))
                elif op_type == "delete":
                    tasks.append(self._execute_delete_batch(ops))
                else:
                    # Fail fast instead of leaving these futures unresolved
                    for operation in ops:
                        if operation.future and not operation.future.done():
                            operation.future.set_exception(ValueError(f"Unknown operation type: {op_type}"))
                    continue
                grouped_ops.append(ops)

            # Run all batches concurrently
            group_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Each batch executor returns one result per operation, in input
            # order; distribute them to the corresponding futures and callbacks
            for ops, group_result in zip(grouped_ops, group_results):
                if isinstance(group_result, Exception):
                    for operation in ops:
                        if operation.future and not operation.future.done():
                            operation.future.set_exception(group_result)
                    continue

                for operation, result in zip(ops, group_result):
                    if operation.callback:
                        try:
                            operation.callback(result)
                        except Exception as e:
                            logger.warning(f"Operation callback failed: {e}")
                    if operation.future and not operation.future.done():
                        operation.future.set_result(result)

                    # Cache select results
                    if operation.operation_type == "select":
                        cache_key = self._generate_cache_key(
                            operation.operation_type, operation.model_class, operation.conditions
                        )
                        self._set_cache(cache_key, result)

            self.stats["batched_operations"] += len(operations)

        except Exception as e:
            logger.error(f"Batch execution failed: {e}", exc_info=True)
            # Propagate the failure to every pending future
            for operation in operations:
                if operation.future and not operation.future.done():
                    operation.future.set_exception(e)
        finally:
            self.stats["execution_time"] += time.time() - start_time

    async def _execute_select_batch(self, operations: list[BatchOperation]) -> list[Any]:
        """Execute select operations; returns one result list per operation, in input order."""
        # Merge queries that filter on the same set of fields
        merged_conditions = self._merge_select_conditions(operations)

        op_results: dict[int, Any] = {}
        async with get_db_session() as session:
            for conditions, ops in merged_conditions.values():
                try:
                    # Build the query; merged values become IN clauses
                    query = select(ops[0].model_class)
                    for field_name, values in conditions.items():
                        model_attr = getattr(ops[0].model_class, field_name)
                        if len(values) == 1:
                            query = query.where(model_attr == values[0])
                        else:
                            query = query.where(model_attr.in_(values))

                    # Execute the merged query
                    result = await session.execute(query)
                    data = result.scalars().all()

                    # Distribute rows back to the individual operations
                    for op in ops:
                        if len(ops) == 1:
                            # Single query in the group: return all rows directly
                            op_results[id(op)] = data
                        else:
                            # Filter the merged rows by each operation's own conditions
                            op_results[id(op)] = [
                                item
                                for item in data
                                if all(
                                    (
                                        getattr(item, k) in v
                                        if isinstance(v, list | tuple | set)
                                        else getattr(item, k) == v
                                    )
                                    for k, v in op.conditions.items()
                                    if hasattr(item, k)
                                )
                            ]
                except Exception as e:
                    logger.error(f"Batch select failed: {e}", exc_info=True)
                    for op in ops:
                        op_results[id(op)] = []

        # Align results with the input order
        return [op_results[id(op)] for op in operations]

    async def _execute_insert_batch(self, operations: list[BatchOperation]) -> list[int]:
        """Execute insert operations as a single bulk insert."""
        async with get_db_session() as session:
            try:
                # Collect all rows to insert
                all_data = [op.data for op in operations if op.data]
                if not all_data:
                    return [0] * len(operations)

                # Bulk insert in one statement
                stmt = insert(operations[0].model_class).values(all_data)
                result = await session.execute(stmt)
                await session.commit()

                return [result.rowcount] * len(operations)
            except Exception as e:
                await session.rollback()
                logger.error(f"Batch insert failed: {e}", exc_info=True)
                return [0] * len(operations)

    async def _execute_update_batch(self, operations: list[BatchOperation]) -> list[int]:
        """Execute update operations; returns affected row counts, in input order."""
        async with get_db_session() as session:
            try:
                results = []
                for op in operations:
                    if not op.data or not op.conditions:
                        results.append(0)
                        continue

                    stmt = update(op.model_class)
                    for field_name, value in op.conditions.items():
                        model_attr = getattr(op.model_class, field_name)
                        if isinstance(value, list | tuple | set):
                            stmt = stmt.where(model_attr.in_(value))
                        else:
                            stmt = stmt.where(model_attr == value)

                    stmt = stmt.values(**op.data)
                    result = await session.execute(stmt)
                    results.append(result.rowcount)

                await session.commit()
                return results
            except Exception as e:
                await session.rollback()
                logger.error(f"Batch update failed: {e}", exc_info=True)
                return [0] * len(operations)

    async def _execute_delete_batch(self, operations: list[BatchOperation]) -> list[int]:
        """Execute delete operations; returns affected row counts, in input order."""
        async with get_db_session() as session:
            try:
                results = []
                for op in operations:
                    if not op.conditions:
                        results.append(0)
                        continue

                    stmt = delete(op.model_class)
                    for field_name, value in op.conditions.items():
                        model_attr = getattr(op.model_class, field_name)
                        if isinstance(value, list | tuple | set):
                            stmt = stmt.where(model_attr.in_(value))
                        else:
                            stmt = stmt.where(model_attr == value)

                    result = await session.execute(stmt)
                    results.append(result.rowcount)

                await session.commit()
                return results
            except Exception as e:
                await session.rollback()
                logger.error(f"Batch delete failed: {e}", exc_info=True)
                return [0] * len(operations)

    def _merge_select_conditions(
        self, operations: list[BatchOperation]
    ) -> dict[tuple, tuple[dict[str, list[Any]], list[BatchOperation]]]:
        """Merge similar select conditions.

        Queries filtering on the same set of fields are merged into a single
        IN query. Returns {field-name tuple: (merged conditions, operations)}.
        """
        merged: dict[tuple, dict[str, list[Any]]] = {}
        grouped_ops: dict[tuple, list[BatchOperation]] = defaultdict(list)

        for op in operations:
            # Key by the sorted set of condition fields
            condition_key = tuple(sorted(op.conditions.keys()))
            merged.setdefault(condition_key, {})

            # Merge values for identical fields
            for field_name, value in op.conditions.items():
                values = merged[condition_key].setdefault(field_name, [])
                if isinstance(value, list | tuple | set):
                    values.extend(value)
                else:
                    values.append(value)

            # Remember which operations contributed to this group
            grouped_ops[condition_key].append(op)

        # Deduplicate values and build the final mapping
        final_merged = {}
        for condition_key, conditions in merged.items():
            deduped = {field_name: list(set(values)) for field_name, values in conditions.items()}
            final_merged[condition_key] = (deduped, grouped_ops[condition_key])

        return final_merged

    def get_stats(self) -> dict[str, Any]:
        """Return scheduler statistics."""
        return {
            **self.stats,
            "cache_size": len(self._result_cache),
            "queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
            "is_running": self._is_running,
        }


# Global database batch scheduler instance
db_batch_scheduler = DatabaseBatchScheduler()


@asynccontextmanager
async def get_batch_session():
    """Context manager that yields the batch scheduler, starting it if needed."""
    if not db_batch_scheduler._is_running:
        await db_batch_scheduler.start()
    try:
        yield db_batch_scheduler
    finally:
        pass


# Convenience functions
async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
    """Queue a select and wait for its (possibly merged) result."""
    operation = BatchOperation(operation_type="select", model_class=model_class, conditions=conditions)
    future = await db_batch_scheduler.add_operation(operation)
    return await future
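
# Illustrative sketch (not executed here): concurrent batch_select calls that
# filter on the same field are merged by the running scheduler into a single
# IN query. ``User`` is a hypothetical model used only for illustration:
#
#     rows_1, rows_2 = await asyncio.gather(
#         batch_select(User, {"id": 1}),
#         batch_select(User, {"id": 2}),
#     )  # executed as one query: SELECT ... WHERE id IN (1, 2)
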
async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
    """Queue an insert and wait for the affected row count."""
    operation = BatchOperation(operation_type="insert", model_class=model_class, conditions={}, data=data)
    future = await db_batch_scheduler.add_operation(operation)
    return await future


async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
    """Queue an update and wait for the affected row count."""
    operation = BatchOperation(operation_type="update", model_class=model_class, conditions=conditions, data=data)
    future = await db_batch_scheduler.add_operation(operation)
    return await future


async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
    """Queue a delete and wait for the affected row count."""
    operation = BatchOperation(operation_type="delete", model_class=model_class, conditions=conditions)
    future = await db_batch_scheduler.add_operation(operation)
    return await future


def get_db_batch_scheduler() -> DatabaseBatchScheduler:
    """Return the global database batch scheduler instance."""
    return db_batch_scheduler
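
# ---------------------------------------------------------------------------
# Usage sketch. Assumptions: ``User`` is a hypothetical SQLAlchemy model and
# an async session backend is configured behind get_db_session; nothing below
# runs on import. The scheduler must be running before the convenience helpers
# are awaited (otherwise their futures never resolve), which get_batch_session
# guarantees:
#
#     async def demo():
#         async with get_batch_session() as scheduler:
#             rows = await batch_select(User, {"id": 1})
#             affected = await batch_update(User, {"id": 1}, {"name": "alice"})
#             print(scheduler.get_stats())
#         await scheduler.stop()
#
#     asyncio.run(demo())
# ---------------------------------------------------------------------------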