refactor(database): Phase 1 - create the new architecture foundation

- Create the layered directory structure (core/api/optimization/config/utils)
- Implement the core layer: engine.py, session.py
- Implement the config layer: database_config.py
- Implement the utils layer: exceptions.py
- Migrate the connection pool manager into the optimization layer
- Add a detailed refactoring plan document
Windpicker-owo
2025-11-01 12:35:39 +08:00
parent 5b1cbb49b0
commit fbe6fb759d
11 changed files with 2303 additions and 0 deletions

File diff suppressed because it is too large

View File

@@ -0,0 +1,9 @@
"""数据库API层
职责:
- CRUD操作
- 查询构建
- 特殊业务操作
"""
__all__ = []

View File

@@ -0,0 +1,14 @@
"""数据库配置层
职责:
- 数据库配置管理
- 优化参数配置
"""
from .database_config import DatabaseConfig, get_database_config, reset_database_config
__all__ = [
"DatabaseConfig",
"get_database_config",
"reset_database_config",
]

View File

@@ -0,0 +1,149 @@
"""数据库配置管理
统一管理数据库连接配置
"""
import os
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import quote_plus
from src.common.logger import get_logger
logger = get_logger("database_config")
@dataclass
class DatabaseConfig:
"""数据库配置"""
# 基础配置
db_type: str # "sqlite" 或 "mysql"
url: str # 数据库连接URL
# 引擎配置
engine_kwargs: dict[str, Any]
# SQLite特定配置
sqlite_path: Optional[str] = None
# MySQL特定配置
mysql_host: Optional[str] = None
mysql_port: Optional[int] = None
mysql_user: Optional[str] = None
mysql_password: Optional[str] = None
mysql_database: Optional[str] = None
mysql_charset: str = "utf8mb4"
mysql_unix_socket: Optional[str] = None
_database_config: Optional[DatabaseConfig] = None
def get_database_config() -> DatabaseConfig:
"""获取数据库配置
从全局配置中读取数据库设置并构建配置对象
"""
global _database_config
if _database_config is not None:
return _database_config
from src.config.config import global_config
config = global_config.database
# 构建数据库URL
if config.database_type == "mysql":
# MySQL配置
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
if config.mysql_unix_socket:
# Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket)
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@/{config.mysql_database}"
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
)
else:
# TCP连接
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
engine_kwargs = {
"echo": False,
"future": True,
"pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout,
"pool_recycle": 3600,
"pool_pre_ping": True,
"connect_args": {
"autocommit": config.mysql_autocommit,
"charset": config.mysql_charset,
"connect_timeout": config.connection_timeout,
},
}
_database_config = DatabaseConfig(
db_type="mysql",
url=url,
engine_kwargs=engine_kwargs,
mysql_host=config.mysql_host,
mysql_port=config.mysql_port,
mysql_user=config.mysql_user,
mysql_password=config.mysql_password,
mysql_database=config.mysql_database,
mysql_charset=config.mysql_charset,
mysql_unix_socket=config.mysql_unix_socket,
)
logger.info(
f"MySQL配置已加载: "
f"{config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
)
else:
# SQLite配置
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = f"sqlite+aiosqlite:///{db_path}"
engine_kwargs = {
"echo": False,
"future": True,
"connect_args": {
"check_same_thread": False,
"timeout": 60,
},
}
_database_config = DatabaseConfig(
db_type="sqlite",
url=url,
engine_kwargs=engine_kwargs,
sqlite_path=db_path,
)
logger.info(f"SQLite配置已加载: {db_path}")
return _database_config
def reset_database_config():
"""重置数据库配置(用于测试)"""
global _database_config
_database_config = None
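
A minimal usage sketch for this config layer. The import path below is a hypothetical guess at the package location (the actual path is not shown in this diff view); everything else only uses names defined in the file above:

from src.common.database.config import get_database_config, reset_database_config  # hypothetical path

cfg = get_database_config()                  # built once from global_config.database, then cached
print(cfg.db_type, cfg.url)                  # e.g. "sqlite", "sqlite+aiosqlite:////abs/path/to.db"
print(cfg.engine_kwargs.get("connect_args"))

reset_database_config()                      # tests call this so the next call rebuilds the config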

View File

@@ -0,0 +1,21 @@
"""数据库核心层
职责:
- 数据库引擎管理
- 会话管理
- 模型定义
- 数据库迁移
"""
from .engine import close_engine, get_engine, get_engine_info
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
__all__ = [
"get_engine",
"close_engine",
"get_engine_info",
"get_db_session",
"get_db_session_direct",
"get_session_factory",
"reset_session_factory",
]

View File

@@ -0,0 +1,141 @@
"""数据库引擎管理
单一职责创建和管理SQLAlchemy异步引擎
"""
import asyncio
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from src.common.logger import get_logger
from ..config.database_config import get_database_config
from ..utils.exceptions import DatabaseInitializationError
logger = get_logger("database.engine")
# 全局引擎实例
_engine: Optional[AsyncEngine] = None
_engine_lock: Optional[asyncio.Lock] = None
async def get_engine() -> AsyncEngine:
"""获取全局数据库引擎(单例模式)
Returns:
AsyncEngine: SQLAlchemy异步引擎
Raises:
DatabaseInitializationError: 引擎初始化失败
"""
global _engine, _engine_lock
# 快速路径:引擎已初始化
if _engine is not None:
return _engine
# 延迟创建锁(避免在导入时创建)
if _engine_lock is None:
_engine_lock = asyncio.Lock()
# 使用锁保护初始化过程
async with _engine_lock:
# 双重检查锁定模式
if _engine is not None:
return _engine
try:
config = get_database_config()
logger.info(f"正在初始化 {config.db_type.upper()} 数据库引擎...")
# 创建异步引擎
_engine = create_async_engine(
config.url,
**config.engine_kwargs
)
# SQLite特定优化
if config.db_type == "sqlite":
await _enable_sqlite_optimizations(_engine)
logger.info(f"{config.db_type.upper()} 数据库引擎初始化成功")
return _engine
except Exception as e:
logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True)
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
async def close_engine():
"""关闭数据库引擎
释放所有连接池资源
"""
global _engine
if _engine is not None:
logger.info("正在关闭数据库引擎...")
await _engine.dispose()
_engine = None
logger.info("✅ 数据库引擎已关闭")
async def _enable_sqlite_optimizations(engine: AsyncEngine):
"""启用SQLite性能优化
优化项:
- WAL模式提高并发性能
- NORMAL同步平衡性能和安全性
- 启用外键约束
- 设置busy_timeout避免锁定错误
Args:
engine: SQLAlchemy异步引擎
"""
try:
async with engine.begin() as conn:
# 启用WAL模式
await conn.execute(text("PRAGMA journal_mode = WAL"))
# 设置适中的同步级别
await conn.execute(text("PRAGMA synchronous = NORMAL"))
# 启用外键约束
await conn.execute(text("PRAGMA foreign_keys = ON"))
# 设置busy_timeout避免锁定错误
await conn.execute(text("PRAGMA busy_timeout = 60000"))
# 设置缓存大小10MB
await conn.execute(text("PRAGMA cache_size = -10000"))
# 临时存储使用内存
await conn.execute(text("PRAGMA temp_store = MEMORY"))
logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
except Exception as e:
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
async def get_engine_info() -> dict:
"""获取引擎信息(用于监控和调试)
Returns:
dict: 引擎信息字典
"""
try:
engine = await get_engine()
info = {
"name": engine.name,
"driver": engine.driver,
"url": str(engine.url).replace(str(engine.url.password or ""), "***"),
"pool_size": getattr(engine.pool, "size", lambda: None)(),
"pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(),
"pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(),
}
return info
except Exception as e:
logger.error(f"获取引擎信息失败: {e}")
return {}
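
A minimal lifecycle sketch for the engine singleton, assuming a hypothetical import path for the core layer (not shown in this diff view); it only calls functions defined above:

import asyncio

from src.common.database.core import close_engine, get_engine, get_engine_info  # hypothetical path


async def main():
    engine = await get_engine()    # first call builds the engine and applies the SQLite PRAGMAs
    again = await get_engine()     # later calls return the same singleton
    assert engine is again
    print(await get_engine_info()) # driver/pool info with the password masked
    await close_engine()           # dispose the pool on shutdown


asyncio.run(main())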

View File

@@ -0,0 +1,118 @@
"""数据库会话管理
单一职责:提供数据库会话工厂和上下文管理器
"""
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.common.logger import get_logger
from ..config.database_config import get_database_config
from .engine import get_engine
logger = get_logger("database.session")
# 全局会话工厂
_session_factory: Optional[async_sessionmaker] = None
_factory_lock: Optional[asyncio.Lock] = None
async def get_session_factory() -> async_sessionmaker:
"""获取会话工厂(单例模式)
Returns:
async_sessionmaker: SQLAlchemy异步会话工厂
"""
global _session_factory, _factory_lock
# 快速路径
if _session_factory is not None:
return _session_factory
# 延迟创建锁
if _factory_lock is None:
_factory_lock = asyncio.Lock()
async with _factory_lock:
# 双重检查
if _session_factory is not None:
return _session_factory
engine = await get_engine()
_session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False, # 避免在commit后访问属性时重新查询
)
logger.debug("会话工厂已创建")
return _session_factory
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话上下文管理器
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
使用示例:
async with get_db_session() as session:
result = await session.execute(select(User))
users = result.scalars().all()
Yields:
AsyncSession: SQLAlchemy异步会话对象
"""
# 延迟导入避免循环依赖
from ..optimization.connection_pool import get_connection_pool_manager
session_factory = await get_session_factory()
pool_manager = get_connection_pool_manager()
# 使用连接池管理器(透明复用连接)
async with pool_manager.get_session(session_factory) as session:
# 为SQLite设置特定的PRAGMA
config = get_database_config()
if config.db_type == "sqlite":
try:
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
except Exception:
# 复用连接时PRAGMA可能已设置忽略错误
pass
yield session
@asynccontextmanager
async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话(直接模式,不使用连接池)
用于特殊场景,如需要完全独立的连接时。
一般情况下应使用 get_db_session()。
Yields:
AsyncSession: SQLAlchemy异步会话对象
"""
session_factory = await get_session_factory()
async with session_factory() as session:
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def reset_session_factory():
"""重置会话工厂(用于测试)"""
global _session_factory
_session_factory = None
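
A short sketch contrasting the two entry points above, assuming the same hypothetical package path; it issues only a trivial SELECT so it does not depend on any models:

from sqlalchemy import text

from src.common.database.core import get_db_session, get_db_session_direct  # hypothetical path


async def ping_database() -> bool:
    # Pooled path: the pool manager may hand back a reused connection.
    async with get_db_session() as session:
        await session.execute(text("SELECT 1"))

    # Direct path: a fully independent connection that bypasses the pool manager.
    async with get_db_session_direct() as session:
        await session.execute(text("SELECT 1"))
    return True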

View File

@@ -0,0 +1,22 @@
"""数据库优化层
职责:
- 连接池管理
- 批量调度
- 多级缓存
- 数据预加载
"""
from .connection_pool import (
ConnectionPoolManager,
get_connection_pool_manager,
start_connection_pool,
stop_connection_pool,
)
__all__ = [
"ConnectionPoolManager",
"get_connection_pool_manager",
"start_connection_pool",
"stop_connection_pool",
]

View File

@@ -0,0 +1,274 @@
"""
透明连接复用管理器
在不改变原有API的情况下实现数据库连接的智能复用
"""
import asyncio
import time
from contextlib import asynccontextmanager
from typing import Any
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.common.logger import get_logger
logger = get_logger("database.connection_pool")
class ConnectionInfo:
"""连接信息包装器"""
def __init__(self, session: AsyncSession, created_at: float):
self.session = session
self.created_at = created_at
self.last_used = created_at
self.in_use = False
self.ref_count = 0
def mark_used(self):
"""标记连接被使用"""
self.last_used = time.time()
self.in_use = True
self.ref_count += 1
def mark_released(self):
"""标记连接被释放"""
self.in_use = False
self.ref_count = max(0, self.ref_count - 1)
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
"""检查连接是否过期"""
current_time = time.time()
# 检查总生命周期
if current_time - self.created_at > max_lifetime:
return True
# 检查空闲时间
if not self.in_use and current_time - self.last_used > max_idle:
return True
return False
async def close(self):
"""关闭连接"""
try:
# 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭
from typing import cast
await cast(asyncio.Future, asyncio.shield(self.session.close()))
logger.debug("连接已关闭")
except asyncio.CancelledError:
# 这是一个预期的行为,例如在流式聊天中断时
logger.debug("关闭连接时任务被取消")
raise
except Exception as e:
logger.warning(f"关闭连接时出错: {e}")
class ConnectionPoolManager:
"""透明的连接池管理器"""
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
self.max_pool_size = max_pool_size
self.max_lifetime = max_lifetime
self.max_idle = max_idle
# 连接池
self._connections: set[ConnectionInfo] = set()
self._lock = asyncio.Lock()
# 统计信息
self._stats = {
"total_created": 0,
"total_reused": 0,
"total_expired": 0,
"active_connections": 0,
"pool_hits": 0,
"pool_misses": 0,
}
# 后台清理任务
self._cleanup_task: asyncio.Task | None = None
self._should_cleanup = False
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
async def start(self):
"""启动连接池管理器"""
if self._cleanup_task is None:
self._should_cleanup = True
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("✅ 连接池管理器已启动")
async def stop(self):
"""停止连接池管理器"""
self._should_cleanup = False
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
# 关闭所有连接
await self._close_all_connections()
logger.info("✅ 连接池管理器已停止")
@asynccontextmanager
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
"""
获取数据库会话的透明包装器
如果有可用连接则复用,否则创建新连接
"""
connection_info = None
try:
# 尝试获取现有连接
connection_info = await self._get_reusable_connection(session_factory)
if connection_info:
# 复用现有连接
connection_info.mark_used()
self._stats["total_reused"] += 1
self._stats["pool_hits"] += 1
logger.debug(f"♻️ 复用连接 (池大小: {len(self._connections)})")
else:
# 创建新连接
session = session_factory()
connection_info = ConnectionInfo(session, time.time())
async with self._lock:
self._connections.add(connection_info)
connection_info.mark_used()
self._stats["total_created"] += 1
self._stats["pool_misses"] += 1
logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")
yield connection_info.session
except Exception:
# 发生错误时回滚连接
if connection_info and connection_info.session:
try:
await connection_info.session.rollback()
except Exception as rollback_error:
logger.warning(f"回滚连接时出错: {rollback_error}")
raise
finally:
# 释放连接回池中
if connection_info:
connection_info.mark_released()
async def _get_reusable_connection(
self, session_factory: async_sessionmaker[AsyncSession]
) -> ConnectionInfo | None:
"""获取可复用的连接"""
async with self._lock:
# 清理过期连接
await self._cleanup_expired_connections_locked()
# 查找可复用的连接
for connection_info in list(self._connections):
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
# 验证连接是否仍然有效
try:
# 执行一个简单的查询来验证连接
await connection_info.session.execute(text("SELECT 1"))
return connection_info
except Exception as e:
logger.debug(f"连接验证失败,将移除: {e}")
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
# 检查是否可以创建新连接
if len(self._connections) >= self.max_pool_size:
logger.warning(f"⚠️ 连接池已满 ({len(self._connections)}/{self.max_pool_size})")
return None
return None
async def _cleanup_expired_connections_locked(self):
"""清理过期连接(需要在锁内调用)"""
expired_connections = [
connection_info for connection_info in list(self._connections)
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
]
for connection_info in expired_connections:
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
if expired_connections:
logger.debug(f"🧹 清理了 {len(expired_connections)} 个过期连接")
async def _cleanup_loop(self):
"""后台清理循环"""
while self._should_cleanup:
try:
await asyncio.sleep(30.0) # 每30秒清理一次
async with self._lock:
await self._cleanup_expired_connections_locked()
# 更新统计信息
self._stats["active_connections"] = len(self._connections)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"连接池清理循环出错: {e}")
await asyncio.sleep(10.0)
async def _close_all_connections(self):
"""关闭所有连接"""
async with self._lock:
for connection_info in list(self._connections):
await connection_info.close()
self._connections.clear()
logger.info("所有连接已关闭")
def get_stats(self) -> dict[str, Any]:
"""获取连接池统计信息"""
total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
return {
**self._stats,
"active_connections": len(self._connections),
"max_pool_size": self.max_pool_size,
"pool_efficiency": f"{pool_efficiency:.2f}%",
}
# 全局连接池管理器实例
_connection_pool_manager: ConnectionPoolManager | None = None
def get_connection_pool_manager() -> ConnectionPoolManager:
"""获取全局连接池管理器实例"""
global _connection_pool_manager
if _connection_pool_manager is None:
_connection_pool_manager = ConnectionPoolManager()
return _connection_pool_manager
async def start_connection_pool():
"""启动连接池"""
manager = get_connection_pool_manager()
await manager.start()
async def stop_connection_pool():
"""停止连接池"""
global _connection_pool_manager
if _connection_pool_manager:
await _connection_pool_manager.stop()
_connection_pool_manager = None
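
A startup/shutdown wiring sketch for the pool manager, again assuming a hypothetical import path for the optimization layer; it uses only the module-level helpers defined above:

from src.common.database.optimization import (  # hypothetical path
    get_connection_pool_manager,
    start_connection_pool,
    stop_connection_pool,
)


async def on_startup():
    await start_connection_pool()    # spawns the 30-second background cleanup task


async def on_shutdown():
    # e.g. {"pool_hits": ..., "pool_misses": ..., "pool_efficiency": "83.33%", ...}
    print(get_connection_pool_manager().get_stats())
    await stop_connection_pool()     # cancels the cleanup task and closes every pooled session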

View File

@@ -0,0 +1,31 @@
"""数据库工具层
职责:
- 异常定义
- 装饰器工具
- 性能监控
"""
from .exceptions import (
BatchSchedulerError,
CacheError,
ConnectionPoolError,
DatabaseConnectionError,
DatabaseError,
DatabaseInitializationError,
DatabaseMigrationError,
DatabaseQueryError,
DatabaseTransactionError,
)
__all__ = [
"DatabaseError",
"DatabaseInitializationError",
"DatabaseConnectionError",
"DatabaseQueryError",
"DatabaseTransactionError",
"DatabaseMigrationError",
"CacheError",
"BatchSchedulerError",
"ConnectionPoolError",
]

View File

@@ -0,0 +1,49 @@
"""数据库异常定义
提供统一的异常体系,便于错误处理和调试
"""
class DatabaseError(Exception):
"""数据库基础异常"""
pass
class DatabaseInitializationError(DatabaseError):
"""数据库初始化异常"""
pass
class DatabaseConnectionError(DatabaseError):
"""数据库连接异常"""
pass
class DatabaseQueryError(DatabaseError):
"""数据库查询异常"""
pass
class DatabaseTransactionError(DatabaseError):
"""数据库事务异常"""
pass
class DatabaseMigrationError(DatabaseError):
"""数据库迁移异常"""
pass
class CacheError(DatabaseError):
"""缓存异常"""
pass
class BatchSchedulerError(DatabaseError):
"""批量调度器异常"""
pass
class ConnectionPoolError(DatabaseError):
"""连接池异常"""
pass
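
An error-handling sketch showing how the hierarchy above is meant to be consumed; the import paths are hypothetical guesses at the package location, and only exception classes and functions from this commit are referenced:

from src.common.database.core import get_engine              # hypothetical path
from src.common.database.utils import (                      # hypothetical path
    DatabaseError,
    DatabaseInitializationError,
)


async def init_db_or_die():
    try:
        await get_engine()
    except DatabaseInitializationError as e:
        # raised by get_engine() when engine creation fails
        raise SystemExit(f"cannot start without a database: {e}") from e
    except DatabaseError:
        # every exception in this module derives from DatabaseError, so this is the catch-all
        raise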