雅诺狐 committed 2025-10-05 16:35:59 +08:00
52 changed files with 566 additions and 1186 deletions

View File

@@ -5,9 +5,8 @@
import asyncio
import time
import weakref
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Set
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -69,7 +68,7 @@ class ConnectionPoolManager:
self.max_idle = max_idle
# Connection pool
self._connections: Set[ConnectionInfo] = set()
self._connections: set[ConnectionInfo] = set()
self._lock = asyncio.Lock()
# Statistics
@@ -83,7 +82,7 @@ class ConnectionPoolManager:
}
# Background cleanup task
self._cleanup_task: Optional[asyncio.Task] = None
self._cleanup_task: asyncio.Task | None = None
self._should_cleanup = False
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
@@ -144,7 +143,7 @@ class ConnectionPoolManager:
yield connection_info.session
except Exception as e:
except Exception:
# Roll back the connection on error
if connection_info and connection_info.session:
try:
@@ -157,7 +156,7 @@ class ConnectionPoolManager:
if connection_info:
connection_info.mark_released()
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None:
"""获取可复用的连接"""
async with self._lock:
# Clean up expired connections
@@ -231,7 +230,7 @@ class ConnectionPoolManager:
self._connections.clear()
logger.info("所有连接已关闭")
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取连接池统计信息"""
return {
**self._stats,
@@ -244,7 +243,7 @@ class ConnectionPoolManager:
# Global connection pool manager instance
_connection_pool_manager: Optional[ConnectionPoolManager] = None
_connection_pool_manager: ConnectionPoolManager | None = None
def get_connection_pool_manager() -> ConnectionPoolManager:
@@ -266,4 +265,4 @@ async def stop_connection_pool():
global _connection_pool_manager
if _connection_pool_manager:
await _connection_pool_manager.stop()
_connection_pool_manager = None
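
The annotation changes in this file all follow one pattern: typing.Optional/Set/Dict are replaced with PEP 604 unions and built-in generics. A minimal, self-contained sketch of the target style (PoolSketch and its ConnectionInfo stand-in are illustrative only, not the real classes):

import asyncio
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ConnectionInfo:  # stand-in for the real ConnectionInfo, for illustration only
    session: Any = None


class PoolSketch:
    def __init__(self) -> None:
        self._connections: set[ConnectionInfo] = set()  # was Set[ConnectionInfo]
        self._cleanup_task: asyncio.Task | None = None  # was Optional[asyncio.Task]

    def get_stats(self) -> dict[str, Any]:  # was Dict[str, Any]
        return {"pool_size": len(self._connections)}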

View File

@@ -2,15 +2,16 @@ import os
from rich.traceback import install
from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
# Database batch scheduler and connection pool
from src.common.database.db_batch_scheduler import get_db_batch_scheduler
# SQLAlchemy-related imports
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_db_session, get_engine
from src.common.logger import get_logger
# Database batch scheduler and connection pool
from src.common.database.db_batch_scheduler import get_db_batch_scheduler
from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
install(extra_lines=3)
_sql_engine = None
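
The reordered imports above are typically wired into the application lifecycle roughly as sketched below; this assumes start_connection_pool() is a coroutine like stop_connection_pool(), and that the scheduler's start()/stop() are coroutines, as the scheduler module suggests:

import asyncio

from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
from src.common.database.db_batch_scheduler import get_db_batch_scheduler


async def run_app() -> None:
    scheduler = get_db_batch_scheduler()
    await start_connection_pool()  # assumed coroutine, mirroring stop_connection_pool()
    await scheduler.start()
    try:
        ...  # application work goes here
    finally:
        await scheduler.stop()
        await stop_connection_pool()


if __name__ == "__main__":
    asyncio.run(run_app())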

View File

@@ -6,19 +6,19 @@
import asyncio
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from collections.abc import Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, TypeVar
from sqlalchemy import select, delete, insert, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import delete, insert, select, update
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.logger import get_logger
logger = get_logger("db_batch_scheduler")
T = TypeVar('T')
T = TypeVar("T")
@dataclass
@@ -26,10 +26,10 @@ class BatchOperation:
"""批量操作基础类"""
operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: Any
conditions: Dict[str, Any]
data: Optional[Dict[str, Any]] = None
callback: Optional[Callable] = None
future: Optional[asyncio.Future] = None
conditions: dict[str, Any]
data: dict[str, Any] | None = None
callback: Callable | None = None
future: asyncio.Future | None = None
timestamp: float = 0.0
def __post_init__(self):
@@ -42,7 +42,7 @@ class BatchResult:
"""批量操作结果"""
success: bool
data: Any = None
error: Optional[str] = None
error: str | None = None
class DatabaseBatchScheduler:
@@ -57,23 +57,23 @@ class DatabaseBatchScheduler:
self.max_queue_size = max_queue_size
# Operation queues, grouped by operation type and model
self.operation_queues: Dict[str, deque] = defaultdict(deque)
self.operation_queues: dict[str, deque] = defaultdict(deque)
# Scheduling control
self._scheduler_task: Optional[asyncio.Task] = None
self._scheduler_task: asyncio.Task | None = None
self._is_running: bool = False
self._lock = asyncio.Lock()
# Statistics
self.stats = {
'total_operations': 0,
'batched_operations': 0,
'cache_hits': 0,
'execution_time': 0.0
"total_operations": 0,
"batched_operations": 0,
"cache_hits": 0,
"execution_time": 0.0
}
# Simple result cache (for frequent queries)
self._result_cache: Dict[str, Tuple[Any, float]] = {}
self._result_cache: dict[str, tuple[Any, float]] = {}
self._cache_ttl = 5.0  # 5-second cache TTL
async def start(self):
@@ -102,7 +102,7 @@ class DatabaseBatchScheduler:
await self._flush_all_queues()
logger.info("数据库批量调度器已停止")
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
"""生成缓存键"""
# 简单的缓存键生成,实际可以根据需要优化
key_parts = [
@@ -112,12 +112,12 @@ class DatabaseBatchScheduler:
]
return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
def _get_from_cache(self, cache_key: str) -> Any | None:
"""从缓存获取结果"""
if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key]
if time.time() - timestamp < self._cache_ttl:
self.stats['cache_hits'] += 1
self.stats["cache_hits"] += 1
return result
else:
# Clean up expired cache entries
@@ -131,7 +131,7 @@ class DatabaseBatchScheduler:
async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
"""添加操作到队列"""
# 检查是否可以立即返回缓存结果
if operation.operation_type == 'select':
if operation.operation_type == "select":
cache_key = self._generate_cache_key(
operation.operation_type,
operation.model_class,
@@ -158,7 +158,7 @@ class DatabaseBatchScheduler:
await self._execute_operations([operation])
else:
self.operation_queues[queue_key].append(operation)
self.stats['total_operations'] += 1
self.stats["total_operations"] += 1
return future
@@ -193,7 +193,7 @@ class DatabaseBatchScheduler:
if operations:
await self._execute_operations(list(operations))
async def _execute_operations(self, operations: List[BatchOperation]):
async def _execute_operations(self, operations: list[BatchOperation]):
"""执行批量操作"""
if not operations:
return
@@ -209,13 +209,13 @@ class DatabaseBatchScheduler:
# Create a batch execution task for each operation type
tasks = []
for op_type, ops in op_groups.items():
if op_type == 'select':
if op_type == "select":
tasks.append(self._execute_select_batch(ops))
elif op_type == 'insert':
elif op_type == "insert":
tasks.append(self._execute_insert_batch(ops))
elif op_type == 'update':
elif op_type == "update":
tasks.append(self._execute_update_batch(ops))
elif op_type == 'delete':
elif op_type == "delete":
tasks.append(self._execute_delete_batch(ops))
# Execute all operations concurrently
@@ -238,7 +238,7 @@ class DatabaseBatchScheduler:
operation.future.set_result(result)
# Cache query results
if operation.operation_type == 'select':
if operation.operation_type == "select":
cache_key = self._generate_cache_key(
operation.operation_type,
operation.model_class,
@@ -246,7 +246,7 @@ class DatabaseBatchScheduler:
)
self._set_cache(cache_key, result)
self.stats['batched_operations'] += len(operations)
self.stats["batched_operations"] += len(operations)
except Exception as e:
logger.error(f"批量操作执行失败: {e}", exc_info="")
@@ -255,9 +255,9 @@ class DatabaseBatchScheduler:
if operation.future and not operation.future.done():
operation.future.set_exception(e)
finally:
self.stats['execution_time'] += time.time() - start_time
self.stats["execution_time"] += time.time() - start_time
async def _execute_select_batch(self, operations: List[BatchOperation]):
async def _execute_select_batch(self, operations: list[BatchOperation]):
"""批量执行查询操作"""
# 合并相似的查询条件
merged_conditions = self._merge_select_conditions(operations)
@@ -302,7 +302,7 @@ class DatabaseBatchScheduler:
return results if len(results) > 1 else results[0] if results else []
async def _execute_insert_batch(self, operations: List[BatchOperation]):
async def _execute_insert_batch(self, operations: list[BatchOperation]):
"""批量执行插入操作"""
async with get_db_session() as session:
try:
@@ -323,7 +323,7 @@ class DatabaseBatchScheduler:
logger.error(f"批量插入失败: {e}", exc_info=True)
return [0] * len(operations)
async def _execute_update_batch(self, operations: List[BatchOperation]):
async def _execute_update_batch(self, operations: list[BatchOperation]):
"""批量执行更新操作"""
async with get_db_session() as session:
try:
@@ -353,7 +353,7 @@ class DatabaseBatchScheduler:
logger.error(f"批量更新失败: {e}", exc_info=True)
return [0] * len(operations)
async def _execute_delete_batch(self, operations: List[BatchOperation]):
async def _execute_delete_batch(self, operations: list[BatchOperation]):
"""批量执行删除操作"""
async with get_db_session() as session:
try:
@@ -382,7 +382,7 @@ class DatabaseBatchScheduler:
logger.error(f"批量删除失败: {e}", exc_info=True)
return [0] * len(operations)
def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, List[BatchOperation]]:
def _merge_select_conditions(self, operations: list[BatchOperation]) -> dict[tuple, list[BatchOperation]]:
"""合并相似的查询条件"""
merged = {}
@@ -405,15 +405,15 @@ class DatabaseBatchScheduler:
# Record the operation
if condition_key not in merged:
merged[condition_key] = {'_operations': []}
if '_operations' not in merged[condition_key]:
merged[condition_key]['_operations'] = []
merged[condition_key]['_operations'].append(op)
merged[condition_key] = {"_operations": []}
if "_operations" not in merged[condition_key]:
merged[condition_key]["_operations"] = []
merged[condition_key]["_operations"].append(op)
# De-duplicate and build the final conditions
final_merged = {}
for condition_key, conditions in merged.items():
operations = conditions.pop('_operations')
operations = conditions.pop("_operations")
# De-duplicate
for field_name, values in conditions.items():
@@ -423,13 +423,13 @@ class DatabaseBatchScheduler:
return final_merged
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
return {
**self.stats,
'cache_size': len(self._result_cache),
'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
'is_running': self._is_running
"cache_size": len(self._result_cache),
"queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
"is_running": self._is_running
}
@@ -450,20 +450,20 @@ async def get_batch_session():
# Convenience functions
async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
"""批量查询"""
operation = BatchOperation(
operation_type='select',
operation_type="select",
model_class=model_class,
conditions=conditions
)
return await db_batch_scheduler.add_operation(operation)
async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
"""批量插入"""
operation = BatchOperation(
operation_type='insert',
operation_type="insert",
model_class=model_class,
conditions={},
data=data
@@ -471,10 +471,10 @@ async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
return await db_batch_scheduler.add_operation(operation)
async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
"""批量更新"""
operation = BatchOperation(
operation_type='update',
operation_type="update",
model_class=model_class,
conditions=conditions,
data=data
@@ -482,10 +482,10 @@ async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[
return await db_batch_scheduler.add_operation(operation)
async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
"""批量删除"""
operation = BatchOperation(
operation_type='delete',
operation_type="delete",
model_class=model_class,
conditions=conditions
)
@@ -494,4 +494,4 @@ async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
def get_db_batch_scheduler() -> DatabaseBatchScheduler:
"""获取数据库批量调度器实例"""
return db_batch_scheduler
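
A hedged usage sketch of the convenience helpers above. ChatMessages is a hypothetical ORM model with hypothetical fields, used only for illustration, and the double await follows from add_operation() being annotated as returning an asyncio.Future:

from src.common.database.db_batch_scheduler import (
    batch_insert,
    batch_select,
    batch_update,
    get_db_batch_scheduler,
)


async def demo(ChatMessages) -> None:  # ChatMessages: hypothetical model class
    scheduler = get_db_batch_scheduler()
    await scheduler.start()

    await batch_insert(ChatMessages, {"chat_id": "1234", "content": "hello"})
    await batch_update(ChatMessages, {"chat_id": "1234"}, {"processed": True})

    # batch_select forwards to add_operation(), which is annotated as returning an
    # asyncio.Future, so under that reading the result needs a second await.
    pending = await batch_select(ChatMessages, {"chat_id": "1234"})
    rows = await pending  # noqa: F841

    await scheduler.stop()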

View File

@@ -15,8 +15,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
from src.common.logger import get_logger
from src.common.database.connection_pool_manager import get_connection_pool_manager
from src.common.logger import get_logger
logger = get_logger("sqlalchemy_models")

View File

@@ -1,13 +1,13 @@
# Timestamp-based file handler with a simple limit on the number of rotated files
import logging
import tarfile
import threading
import time
import tarfile
from collections.abc import Callable
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Dict
from typing import Any
import orjson
import structlog
@@ -18,15 +18,15 @@ LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
# Global handler instances to avoid repeated creation; None means file logging is disabled
_file_handler: Optional[logging.Handler] = None
_console_handler: Optional[logging.Handler] = None
_file_handler: logging.Handler | None = None
_console_handler: logging.Handler | None = None
# Dynamic logger metadata registry (name -> {alias: str|None, color: str|None})
_LOGGER_META_LOCK = threading.Lock()
_LOGGER_META: Dict[str, Dict[str, Optional[str]]] = {}
_LOGGER_META: dict[str, dict[str, str | None]] = {}
def _normalize_color(color: Optional[str]) -> Optional[str]:
def _normalize_color(color: str | None) -> str | None:
"""接受 ANSI 码 / #RRGGBB / rgb(r,g,b) / 颜色名(直接返回) -> ANSI 码.
不做复杂解析,只支持 #RRGGBB 转 24bit ANSI。
"""
@@ -49,13 +49,13 @@ def _normalize_color(color: Optional[str]) -> Optional[str]:
nums = color[color.find("(") + 1 : -1].split(",")
r, g, b = (int(x) for x in nums[:3])
return f"\033[38;2;{r};{g};{b}m"
except Exception: # noqa: BLE001
except Exception:
return None
# Otherwise return as-is; assumed to be a short ANSI code or a color name (the console renderer does no translation, it only outputs it)
return color
def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Optional[str] = None):
def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None):
"""注册/更新 logger 元数据。"""
if not name:
return
@@ -67,7 +67,7 @@ def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Opti
meta["color"] = _normalize_color(color)
def get_logger_meta(name: str) -> Dict[str, Optional[str]]:
def get_logger_meta(name: str) -> dict[str, str | None]:
with _LOGGER_META_LOCK:
return _LOGGER_META.get(name, {"alias": None, "color": None}).copy()
@@ -170,7 +170,7 @@ class TimestampedFileHandler(logging.Handler):
try:
self._compress_stale_logs()
self._cleanup_old_files()
except Exception as e: # noqa: BLE001
except Exception as e:
print(f"[日志轮转] 轮转过程出错: {e}")
def _compress_stale_logs(self): # sourcery skip: extract-method
@@ -184,12 +184,12 @@ class TimestampedFileHandler(logging.Handler):
continue
# Compress
try:
with tarfile.open(tar_path, "w:gz") as tf: # noqa: SIM117
with tarfile.open(tar_path, "w:gz") as tf:
tf.add(f, arcname=f.name)
f.unlink(missing_ok=True)
except Exception as e: # noqa: BLE001
except Exception as e:
print(f"[日志压缩] 压缩 {f.name} 失败: {e}")
except Exception as e: # noqa: BLE001
except Exception as e:
print(f"[日志压缩] 过程出错: {e}")
def _cleanup_old_files(self):
@@ -206,9 +206,9 @@ class TimestampedFileHandler(logging.Handler):
mtime = datetime.fromtimestamp(f.stat().st_mtime)
if mtime < cutoff:
f.unlink(missing_ok=True)
except Exception as e: # noqa: BLE001
except Exception as e:
print(f"[日志清理] 删除 {f} 失败: {e}")
except Exception as e: # noqa: BLE001
except Exception as e:
print(f"[日志清理] 清理过程出错: {e}")
def emit(self, record):
@@ -850,7 +850,7 @@ class ModuleColoredConsoleRenderer:
if logger_name:
# Get the alias; fall back to the original name if none is set
# If the branch above was not taken, fetch meta again
if 'meta' not in locals():
if "meta" not in locals():
meta = get_logger_meta(logger_name)
display_name = meta.get("alias") or DEFAULT_MODULE_ALIASES.get(logger_name, logger_name)
@@ -1066,7 +1066,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()
binds: dict[str, Callable] = {}
def get_logger(name: str | None, *, color: Optional[str] = None, alias: Optional[str] = None) -> structlog.stdlib.BoundLogger:
def get_logger(name: str | None, *, color: str | None = None, alias: str | None = None) -> structlog.stdlib.BoundLogger:
"""获取/创建 structlog logger。
新增:
@@ -1132,10 +1132,10 @@ def cleanup_old_logs():
tar_path = f.with_suffix(f.suffix + ".tar.gz")
if tar_path.exists():
continue
with tarfile.open(tar_path, "w:gz") as tf: # noqa: SIM117
with tarfile.open(tar_path, "w:gz") as tf:
tf.add(f, arcname=f.name)
f.unlink(missing_ok=True)
except Exception as e: # noqa: BLE001
except Exception as e:
logger = get_logger("logger")
logger.warning(f"周期压缩日志时出错: {e}")
@@ -1152,7 +1152,7 @@ def cleanup_old_logs():
log_file.unlink(missing_ok=True)
deleted_count += 1
deleted_size += size
except Exception as e: # noqa: BLE001
except Exception as e:
logger = get_logger("logger")
logger.warning(f"清理日志文件 {log_file} 时出错: {e}")
if deleted_count:
@@ -1160,7 +1160,7 @@ def cleanup_old_logs():
logger.info(
f"清理 {deleted_count} 个过期日志 (≈{deleted_size / 1024 / 1024:.2f}MB), 保留策略={retention_days}"
)
except Exception as e: # noqa: BLE001
except Exception as e:
logger = get_logger("logger")
logger.error(f"清理旧日志文件时出错: {e}")
@@ -1183,7 +1183,7 @@ def start_log_cleanup_task():
while True:
try:
cleanup_old_logs()
except Exception as e: # noqa: BLE001
except Exception as e:
print(f"[日志任务] 执行清理出错: {e}")
# Wait again until the next midnight
time.sleep(max(1, seconds_until_next_midnight()))
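
The get_logger() signature above takes keyword-only alias and color arguments; per _normalize_color(), the color may be an ANSI escape, #RRGGBB, rgb(r,g,b), or a color name. A minimal usage sketch:

from src.common.logger import get_logger

# alias replaces the module name in console output; #RRGGBB is converted to a 24-bit ANSI color
logger = get_logger("db_batch_scheduler", alias="BatchDB", color="#00afff")
logger.info("scheduler started")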