refactor(database): 阶段一 - 创建新架构基础
- 创建分层目录结构 (core/api/optimization/config/utils) - 实现核心层: engine.py, session.py - 实现配置层: database_config.py - 实现工具层: exceptions.py - 迁移连接池管理器到优化层 - 添加详细的重构计划文档
This commit is contained in:
1475
docs/database_refactoring_plan.md
Normal file
1475
docs/database_refactoring_plan.md
Normal file
File diff suppressed because it is too large
Load Diff
9
src/common/database/api/__init__.py
Normal file
9
src/common/database/api/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""数据库API层
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- CRUD操作
|
||||||
|
- 查询构建
|
||||||
|
- 特殊业务操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
14
src/common/database/config/__init__.py
Normal file
14
src/common/database/config/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""数据库配置层
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 数据库配置管理
|
||||||
|
- 优化参数配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .database_config import DatabaseConfig, get_database_config, reset_database_config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DatabaseConfig",
|
||||||
|
"get_database_config",
|
||||||
|
"reset_database_config",
|
||||||
|
]
|
||||||
149
src/common/database/config/database_config.py
Normal file
149
src/common/database/config/database_config.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""数据库配置管理
|
||||||
|
|
||||||
|
统一管理数据库连接配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("database_config")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatabaseConfig:
|
||||||
|
"""数据库配置"""
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
db_type: str # "sqlite" 或 "mysql"
|
||||||
|
url: str # 数据库连接URL
|
||||||
|
|
||||||
|
# 引擎配置
|
||||||
|
engine_kwargs: dict[str, Any]
|
||||||
|
|
||||||
|
# SQLite特定配置
|
||||||
|
sqlite_path: Optional[str] = None
|
||||||
|
|
||||||
|
# MySQL特定配置
|
||||||
|
mysql_host: Optional[str] = None
|
||||||
|
mysql_port: Optional[int] = None
|
||||||
|
mysql_user: Optional[str] = None
|
||||||
|
mysql_password: Optional[str] = None
|
||||||
|
mysql_database: Optional[str] = None
|
||||||
|
mysql_charset: str = "utf8mb4"
|
||||||
|
mysql_unix_socket: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
_database_config: Optional[DatabaseConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_database_config() -> DatabaseConfig:
|
||||||
|
"""获取数据库配置
|
||||||
|
|
||||||
|
从全局配置中读取数据库设置并构建配置对象
|
||||||
|
"""
|
||||||
|
global _database_config
|
||||||
|
|
||||||
|
if _database_config is not None:
|
||||||
|
return _database_config
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
config = global_config.database
|
||||||
|
|
||||||
|
# 构建数据库URL
|
||||||
|
if config.database_type == "mysql":
|
||||||
|
# MySQL配置
|
||||||
|
encoded_user = quote_plus(config.mysql_user)
|
||||||
|
encoded_password = quote_plus(config.mysql_password)
|
||||||
|
|
||||||
|
if config.mysql_unix_socket:
|
||||||
|
# Unix socket连接
|
||||||
|
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||||
|
url = (
|
||||||
|
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||||
|
f"@/{config.mysql_database}"
|
||||||
|
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# TCP连接
|
||||||
|
url = (
|
||||||
|
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||||
|
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||||
|
f"?charset={config.mysql_charset}"
|
||||||
|
)
|
||||||
|
|
||||||
|
engine_kwargs = {
|
||||||
|
"echo": False,
|
||||||
|
"future": True,
|
||||||
|
"pool_size": config.connection_pool_size,
|
||||||
|
"max_overflow": config.connection_pool_size * 2,
|
||||||
|
"pool_timeout": config.connection_timeout,
|
||||||
|
"pool_recycle": 3600,
|
||||||
|
"pool_pre_ping": True,
|
||||||
|
"connect_args": {
|
||||||
|
"autocommit": config.mysql_autocommit,
|
||||||
|
"charset": config.mysql_charset,
|
||||||
|
"connect_timeout": config.connection_timeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_database_config = DatabaseConfig(
|
||||||
|
db_type="mysql",
|
||||||
|
url=url,
|
||||||
|
engine_kwargs=engine_kwargs,
|
||||||
|
mysql_host=config.mysql_host,
|
||||||
|
mysql_port=config.mysql_port,
|
||||||
|
mysql_user=config.mysql_user,
|
||||||
|
mysql_password=config.mysql_password,
|
||||||
|
mysql_database=config.mysql_database,
|
||||||
|
mysql_charset=config.mysql_charset,
|
||||||
|
mysql_unix_socket=config.mysql_unix_socket,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"MySQL配置已加载: "
|
||||||
|
f"{config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# SQLite配置
|
||||||
|
if not os.path.isabs(config.sqlite_path):
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||||
|
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||||
|
else:
|
||||||
|
db_path = config.sqlite_path
|
||||||
|
|
||||||
|
# 确保数据库目录存在
|
||||||
|
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||||
|
|
||||||
|
url = f"sqlite+aiosqlite:///{db_path}"
|
||||||
|
|
||||||
|
engine_kwargs = {
|
||||||
|
"echo": False,
|
||||||
|
"future": True,
|
||||||
|
"connect_args": {
|
||||||
|
"check_same_thread": False,
|
||||||
|
"timeout": 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_database_config = DatabaseConfig(
|
||||||
|
db_type="sqlite",
|
||||||
|
url=url,
|
||||||
|
engine_kwargs=engine_kwargs,
|
||||||
|
sqlite_path=db_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"SQLite配置已加载: {db_path}")
|
||||||
|
|
||||||
|
return _database_config
|
||||||
|
|
||||||
|
|
||||||
|
def reset_database_config():
|
||||||
|
"""重置数据库配置(用于测试)"""
|
||||||
|
global _database_config
|
||||||
|
_database_config = None
|
||||||
21
src/common/database/core/__init__.py
Normal file
21
src/common/database/core/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""数据库核心层
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 数据库引擎管理
|
||||||
|
- 会话管理
|
||||||
|
- 模型定义
|
||||||
|
- 数据库迁移
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .engine import close_engine, get_engine, get_engine_info
|
||||||
|
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_engine",
|
||||||
|
"close_engine",
|
||||||
|
"get_engine_info",
|
||||||
|
"get_db_session",
|
||||||
|
"get_db_session_direct",
|
||||||
|
"get_session_factory",
|
||||||
|
"reset_session_factory",
|
||||||
|
]
|
||||||
141
src/common/database/core/engine.py
Normal file
141
src/common/database/core/engine.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""数据库引擎管理
|
||||||
|
|
||||||
|
单一职责:创建和管理SQLAlchemy异步引擎
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from ..config.database_config import get_database_config
|
||||||
|
from ..utils.exceptions import DatabaseInitializationError
|
||||||
|
|
||||||
|
logger = get_logger("database.engine")
|
||||||
|
|
||||||
|
# 全局引擎实例
|
||||||
|
_engine: Optional[AsyncEngine] = None
|
||||||
|
_engine_lock: Optional[asyncio.Lock] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_engine() -> AsyncEngine:
|
||||||
|
"""获取全局数据库引擎(单例模式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncEngine: SQLAlchemy异步引擎
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseInitializationError: 引擎初始化失败
|
||||||
|
"""
|
||||||
|
global _engine, _engine_lock
|
||||||
|
|
||||||
|
# 快速路径:引擎已初始化
|
||||||
|
if _engine is not None:
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
# 延迟创建锁(避免在导入时创建)
|
||||||
|
if _engine_lock is None:
|
||||||
|
_engine_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# 使用锁保护初始化过程
|
||||||
|
async with _engine_lock:
|
||||||
|
# 双重检查锁定模式
|
||||||
|
if _engine is not None:
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = get_database_config()
|
||||||
|
|
||||||
|
logger.info(f"正在初始化 {config.db_type.upper()} 数据库引擎...")
|
||||||
|
|
||||||
|
# 创建异步引擎
|
||||||
|
_engine = create_async_engine(
|
||||||
|
config.url,
|
||||||
|
**config.engine_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# SQLite特定优化
|
||||||
|
if config.db_type == "sqlite":
|
||||||
|
await _enable_sqlite_optimizations(_engine)
|
||||||
|
|
||||||
|
logger.info(f"✅ {config.db_type.upper()} 数据库引擎初始化成功")
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True)
|
||||||
|
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
async def close_engine():
|
||||||
|
"""关闭数据库引擎
|
||||||
|
|
||||||
|
释放所有连接池资源
|
||||||
|
"""
|
||||||
|
global _engine
|
||||||
|
|
||||||
|
if _engine is not None:
|
||||||
|
logger.info("正在关闭数据库引擎...")
|
||||||
|
await _engine.dispose()
|
||||||
|
_engine = None
|
||||||
|
logger.info("✅ 数据库引擎已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||||
|
"""启用SQLite性能优化
|
||||||
|
|
||||||
|
优化项:
|
||||||
|
- WAL模式:提高并发性能
|
||||||
|
- NORMAL同步:平衡性能和安全性
|
||||||
|
- 启用外键约束
|
||||||
|
- 设置busy_timeout:避免锁定错误
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: SQLAlchemy异步引擎
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# 启用WAL模式
|
||||||
|
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
||||||
|
# 设置适中的同步级别
|
||||||
|
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
||||||
|
# 启用外键约束
|
||||||
|
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
||||||
|
# 设置busy_timeout,避免锁定错误
|
||||||
|
await conn.execute(text("PRAGMA busy_timeout = 60000"))
|
||||||
|
# 设置缓存大小(10MB)
|
||||||
|
await conn.execute(text("PRAGMA cache_size = -10000"))
|
||||||
|
# 临时存储使用内存
|
||||||
|
await conn.execute(text("PRAGMA temp_store = MEMORY"))
|
||||||
|
|
||||||
|
logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_engine_info() -> dict:
|
||||||
|
"""获取引擎信息(用于监控和调试)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 引擎信息字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
engine = await get_engine()
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"name": engine.name,
|
||||||
|
"driver": engine.driver,
|
||||||
|
"url": str(engine.url).replace(str(engine.url.password or ""), "***"),
|
||||||
|
"pool_size": getattr(engine.pool, "size", lambda: None)(),
|
||||||
|
"pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(),
|
||||||
|
"pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取引擎信息失败: {e}")
|
||||||
|
return {}
|
||||||
118
src/common/database/core/session.py
Normal file
118
src/common/database/core/session.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""数据库会话管理
|
||||||
|
|
||||||
|
单一职责:提供数据库会话工厂和上下文管理器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from ..config.database_config import get_database_config
|
||||||
|
from .engine import get_engine
|
||||||
|
|
||||||
|
logger = get_logger("database.session")
|
||||||
|
|
||||||
|
# 全局会话工厂
|
||||||
|
_session_factory: Optional[async_sessionmaker] = None
|
||||||
|
_factory_lock: Optional[asyncio.Lock] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_factory() -> async_sessionmaker:
|
||||||
|
"""获取会话工厂(单例模式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
async_sessionmaker: SQLAlchemy异步会话工厂
|
||||||
|
"""
|
||||||
|
global _session_factory, _factory_lock
|
||||||
|
|
||||||
|
# 快速路径
|
||||||
|
if _session_factory is not None:
|
||||||
|
return _session_factory
|
||||||
|
|
||||||
|
# 延迟创建锁
|
||||||
|
if _factory_lock is None:
|
||||||
|
_factory_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async with _factory_lock:
|
||||||
|
# 双重检查
|
||||||
|
if _session_factory is not None:
|
||||||
|
return _session_factory
|
||||||
|
|
||||||
|
engine = await get_engine()
|
||||||
|
_session_factory = async_sessionmaker(
|
||||||
|
bind=engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False, # 避免在commit后访问属性时重新查询
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("会话工厂已创建")
|
||||||
|
return _session_factory
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""获取数据库会话上下文管理器
|
||||||
|
|
||||||
|
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
async with get_db_session() as session:
|
||||||
|
result = await session.execute(select(User))
|
||||||
|
users = result.scalars().all()
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
AsyncSession: SQLAlchemy异步会话对象
|
||||||
|
"""
|
||||||
|
# 延迟导入避免循环依赖
|
||||||
|
from ..optimization.connection_pool import get_connection_pool_manager
|
||||||
|
|
||||||
|
session_factory = await get_session_factory()
|
||||||
|
pool_manager = get_connection_pool_manager()
|
||||||
|
|
||||||
|
# 使用连接池管理器(透明复用连接)
|
||||||
|
async with pool_manager.get_session(session_factory) as session:
|
||||||
|
# 为SQLite设置特定的PRAGMA
|
||||||
|
config = get_database_config()
|
||||||
|
if config.db_type == "sqlite":
|
||||||
|
try:
|
||||||
|
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||||
|
await session.execute(text("PRAGMA foreign_keys = ON"))
|
||||||
|
except Exception:
|
||||||
|
# 复用连接时PRAGMA可能已设置,忽略错误
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""获取数据库会话(直接模式,不使用连接池)
|
||||||
|
|
||||||
|
用于特殊场景,如需要完全独立的连接时。
|
||||||
|
一般情况下应使用 get_db_session()。
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
AsyncSession: SQLAlchemy异步会话对象
|
||||||
|
"""
|
||||||
|
session_factory = await get_session_factory()
|
||||||
|
|
||||||
|
async with session_factory() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_session_factory():
|
||||||
|
"""重置会话工厂(用于测试)"""
|
||||||
|
global _session_factory
|
||||||
|
_session_factory = None
|
||||||
22
src/common/database/optimization/__init__.py
Normal file
22
src/common/database/optimization/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""数据库优化层
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 连接池管理
|
||||||
|
- 批量调度
|
||||||
|
- 多级缓存
|
||||||
|
- 数据预加载
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .connection_pool import (
|
||||||
|
ConnectionPoolManager,
|
||||||
|
get_connection_pool_manager,
|
||||||
|
start_connection_pool,
|
||||||
|
stop_connection_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ConnectionPoolManager",
|
||||||
|
"get_connection_pool_manager",
|
||||||
|
"start_connection_pool",
|
||||||
|
"stop_connection_pool",
|
||||||
|
]
|
||||||
274
src/common/database/optimization/connection_pool.py
Normal file
274
src/common/database/optimization/connection_pool.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""
|
||||||
|
透明连接复用管理器
|
||||||
|
|
||||||
|
在不改变原有API的情况下,实现数据库连接的智能复用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("database.connection_pool")
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionInfo:
|
||||||
|
"""连接信息包装器"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession, created_at: float):
|
||||||
|
self.session = session
|
||||||
|
self.created_at = created_at
|
||||||
|
self.last_used = created_at
|
||||||
|
self.in_use = False
|
||||||
|
self.ref_count = 0
|
||||||
|
|
||||||
|
def mark_used(self):
|
||||||
|
"""标记连接被使用"""
|
||||||
|
self.last_used = time.time()
|
||||||
|
self.in_use = True
|
||||||
|
self.ref_count += 1
|
||||||
|
|
||||||
|
def mark_released(self):
|
||||||
|
"""标记连接被释放"""
|
||||||
|
self.in_use = False
|
||||||
|
self.ref_count = max(0, self.ref_count - 1)
|
||||||
|
|
||||||
|
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
|
||||||
|
"""检查连接是否过期"""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# 检查总生命周期
|
||||||
|
if current_time - self.created_at > max_lifetime:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 检查空闲时间
|
||||||
|
if not self.in_use and current_time - self.last_used > max_idle:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭连接"""
|
||||||
|
try:
|
||||||
|
# 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭
|
||||||
|
from typing import cast
|
||||||
|
await cast(asyncio.Future, asyncio.shield(self.session.close()))
|
||||||
|
logger.debug("连接已关闭")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# 这是一个预期的行为,例如在流式聊天中断时
|
||||||
|
logger.debug("关闭连接时任务被取消")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"关闭连接时出错: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPoolManager:
|
||||||
|
"""透明的连接池管理器"""
|
||||||
|
|
||||||
|
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
|
||||||
|
self.max_pool_size = max_pool_size
|
||||||
|
self.max_lifetime = max_lifetime
|
||||||
|
self.max_idle = max_idle
|
||||||
|
|
||||||
|
# 连接池
|
||||||
|
self._connections: set[ConnectionInfo] = set()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
self._stats = {
|
||||||
|
"total_created": 0,
|
||||||
|
"total_reused": 0,
|
||||||
|
"total_expired": 0,
|
||||||
|
"active_connections": 0,
|
||||||
|
"pool_hits": 0,
|
||||||
|
"pool_misses": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 后台清理任务
|
||||||
|
self._cleanup_task: asyncio.Task | None = None
|
||||||
|
self._should_cleanup = False
|
||||||
|
|
||||||
|
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动连接池管理器"""
|
||||||
|
if self._cleanup_task is None:
|
||||||
|
self._should_cleanup = True
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info("✅ 连接池管理器已启动")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止连接池管理器"""
|
||||||
|
self._should_cleanup = False
|
||||||
|
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
# 关闭所有连接
|
||||||
|
await self._close_all_connections()
|
||||||
|
logger.info("✅ 连接池管理器已停止")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
|
||||||
|
"""
|
||||||
|
获取数据库会话的透明包装器
|
||||||
|
如果有可用连接则复用,否则创建新连接
|
||||||
|
"""
|
||||||
|
connection_info = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试获取现有连接
|
||||||
|
connection_info = await self._get_reusable_connection(session_factory)
|
||||||
|
|
||||||
|
if connection_info:
|
||||||
|
# 复用现有连接
|
||||||
|
connection_info.mark_used()
|
||||||
|
self._stats["total_reused"] += 1
|
||||||
|
self._stats["pool_hits"] += 1
|
||||||
|
logger.debug(f"♻️ 复用连接 (池大小: {len(self._connections)})")
|
||||||
|
else:
|
||||||
|
# 创建新连接
|
||||||
|
session = session_factory()
|
||||||
|
connection_info = ConnectionInfo(session, time.time())
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
self._connections.add(connection_info)
|
||||||
|
|
||||||
|
connection_info.mark_used()
|
||||||
|
self._stats["total_created"] += 1
|
||||||
|
self._stats["pool_misses"] += 1
|
||||||
|
logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")
|
||||||
|
|
||||||
|
yield connection_info.session
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# 发生错误时回滚连接
|
||||||
|
if connection_info and connection_info.session:
|
||||||
|
try:
|
||||||
|
await connection_info.session.rollback()
|
||||||
|
except Exception as rollback_error:
|
||||||
|
logger.warning(f"回滚连接时出错: {rollback_error}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 释放连接回池中
|
||||||
|
if connection_info:
|
||||||
|
connection_info.mark_released()
|
||||||
|
|
||||||
|
async def _get_reusable_connection(
|
||||||
|
self, session_factory: async_sessionmaker[AsyncSession]
|
||||||
|
) -> ConnectionInfo | None:
|
||||||
|
"""获取可复用的连接"""
|
||||||
|
async with self._lock:
|
||||||
|
# 清理过期连接
|
||||||
|
await self._cleanup_expired_connections_locked()
|
||||||
|
|
||||||
|
# 查找可复用的连接
|
||||||
|
for connection_info in list(self._connections):
|
||||||
|
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
|
||||||
|
# 验证连接是否仍然有效
|
||||||
|
try:
|
||||||
|
# 执行一个简单的查询来验证连接
|
||||||
|
await connection_info.session.execute(text("SELECT 1"))
|
||||||
|
return connection_info
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"连接验证失败,将移除: {e}")
|
||||||
|
await connection_info.close()
|
||||||
|
self._connections.remove(connection_info)
|
||||||
|
self._stats["total_expired"] += 1
|
||||||
|
|
||||||
|
# 检查是否可以创建新连接
|
||||||
|
if len(self._connections) >= self.max_pool_size:
|
||||||
|
logger.warning(f"⚠️ 连接池已满 ({len(self._connections)}/{self.max_pool_size})")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _cleanup_expired_connections_locked(self):
|
||||||
|
"""清理过期连接(需要在锁内调用)"""
|
||||||
|
expired_connections = [
|
||||||
|
connection_info for connection_info in list(self._connections)
|
||||||
|
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
|
||||||
|
]
|
||||||
|
|
||||||
|
for connection_info in expired_connections:
|
||||||
|
await connection_info.close()
|
||||||
|
self._connections.remove(connection_info)
|
||||||
|
self._stats["total_expired"] += 1
|
||||||
|
|
||||||
|
if expired_connections:
|
||||||
|
logger.debug(f"🧹 清理了 {len(expired_connections)} 个过期连接")
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
"""后台清理循环"""
|
||||||
|
while self._should_cleanup:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(30.0) # 每30秒清理一次
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
await self._cleanup_expired_connections_locked()
|
||||||
|
|
||||||
|
# 更新统计信息
|
||||||
|
self._stats["active_connections"] = len(self._connections)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"连接池清理循环出错: {e}")
|
||||||
|
await asyncio.sleep(10.0)
|
||||||
|
|
||||||
|
async def _close_all_connections(self):
|
||||||
|
"""关闭所有连接"""
|
||||||
|
async with self._lock:
|
||||||
|
for connection_info in list(self._connections):
|
||||||
|
await connection_info.close()
|
||||||
|
|
||||||
|
self._connections.clear()
|
||||||
|
logger.info("所有连接已关闭")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""获取连接池统计信息"""
|
||||||
|
total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
|
||||||
|
pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
**self._stats,
|
||||||
|
"active_connections": len(self._connections),
|
||||||
|
"max_pool_size": self.max_pool_size,
|
||||||
|
"pool_efficiency": f"{pool_efficiency:.2f}%",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局连接池管理器实例
|
||||||
|
_connection_pool_manager: ConnectionPoolManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection_pool_manager() -> ConnectionPoolManager:
|
||||||
|
"""获取全局连接池管理器实例"""
|
||||||
|
global _connection_pool_manager
|
||||||
|
if _connection_pool_manager is None:
|
||||||
|
_connection_pool_manager = ConnectionPoolManager()
|
||||||
|
return _connection_pool_manager
|
||||||
|
|
||||||
|
|
||||||
|
async def start_connection_pool():
|
||||||
|
"""启动连接池"""
|
||||||
|
manager = get_connection_pool_manager()
|
||||||
|
await manager.start()
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_connection_pool():
|
||||||
|
"""停止连接池"""
|
||||||
|
global _connection_pool_manager
|
||||||
|
if _connection_pool_manager:
|
||||||
|
await _connection_pool_manager.stop()
|
||||||
|
_connection_pool_manager = None
|
||||||
31
src/common/database/utils/__init__.py
Normal file
31
src/common/database/utils/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""数据库工具层
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 异常定义
|
||||||
|
- 装饰器工具
|
||||||
|
- 性能监控
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
BatchSchedulerError,
|
||||||
|
CacheError,
|
||||||
|
ConnectionPoolError,
|
||||||
|
DatabaseConnectionError,
|
||||||
|
DatabaseError,
|
||||||
|
DatabaseInitializationError,
|
||||||
|
DatabaseMigrationError,
|
||||||
|
DatabaseQueryError,
|
||||||
|
DatabaseTransactionError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DatabaseError",
|
||||||
|
"DatabaseInitializationError",
|
||||||
|
"DatabaseConnectionError",
|
||||||
|
"DatabaseQueryError",
|
||||||
|
"DatabaseTransactionError",
|
||||||
|
"DatabaseMigrationError",
|
||||||
|
"CacheError",
|
||||||
|
"BatchSchedulerError",
|
||||||
|
"ConnectionPoolError",
|
||||||
|
]
|
||||||
49
src/common/database/utils/exceptions.py
Normal file
49
src/common/database/utils/exceptions.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""数据库异常定义
|
||||||
|
|
||||||
|
提供统一的异常体系,便于错误处理和调试
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseError(Exception):
|
||||||
|
"""数据库基础异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseInitializationError(DatabaseError):
|
||||||
|
"""数据库初始化异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnectionError(DatabaseError):
|
||||||
|
"""数据库连接异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseQueryError(DatabaseError):
|
||||||
|
"""数据库查询异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseTransactionError(DatabaseError):
|
||||||
|
"""数据库事务异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationError(DatabaseError):
|
||||||
|
"""数据库迁移异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CacheError(DatabaseError):
|
||||||
|
"""缓存异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSchedulerError(DatabaseError):
|
||||||
|
"""批量调度器异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPoolError(DatabaseError):
|
||||||
|
"""连接池异常"""
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user