Files
Mofox-Core/src/common/database/core/session.py

167 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""数据库会话管理
单一职责:提供数据库会话工厂和上下文管理器
支持的数据库类型:
- SQLite: 设置 PRAGMA 参数优化并发
- MySQL: 无特殊会话设置
- PostgreSQL: 可选设置 schema 搜索路径
"""
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.common.logger import get_logger
from .engine import get_engine
logger = get_logger("database.session")
# 全局会话工厂
_session_factory: async_sessionmaker | None = None
_factory_lock: asyncio.Lock | None = None
async def get_session_factory() -> async_sessionmaker:
"""获取会话工厂(单例模式)
Returns:
async_sessionmaker: SQLAlchemy异步会话工厂
"""
global _session_factory, _factory_lock
# 快速路径
if _session_factory is not None:
return _session_factory
# 延迟创建锁
if _factory_lock is None:
_factory_lock = asyncio.Lock()
async with _factory_lock:
# 双重检查
if _session_factory is not None:
return _session_factory
engine = await get_engine()
_session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False, # 避免在commit后访问属性时重新查询
)
logger.debug("会话工厂已创建")
return _session_factory
async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
"""应用数据库特定的会话设置
Args:
session: 数据库会话
db_type: 数据库类型
"""
try:
if db_type == "sqlite":
# SQLite 特定的 PRAGMA 设置
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
elif db_type == "postgresql":
# PostgreSQL 特定设置(如果需要)
# 可以设置 schema 搜索路径等
from src.config.config import global_config
schema = global_config.database.postgresql_schema
if schema and schema != "public":
await session.execute(text(f"SET search_path TO {schema}"))
# MySQL 通常不需要会话级别的特殊设置
except Exception:
# 复用连接时设置可能已存在,忽略错误
pass
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话上下文管理器
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
支持的数据库:
- SQLite: 自动设置 busy_timeout 和外键约束
- MySQL: 直接使用,无特殊设置
- PostgreSQL: 支持自定义 schema
使用示例:
async with get_db_session() as session:
result = await session.execute(select(User))
users = result.scalars().all()
Yields:
AsyncSession: SQLAlchemy异步会话对象
"""
# 延迟导入避免循环依赖
from ..optimization.connection_pool import get_connection_pool_manager
session_factory = await get_session_factory()
pool_manager = get_connection_pool_manager()
# 使用连接池管理器(透明复用连接)
async with pool_manager.get_session(session_factory) as session:
# 获取数据库类型并应用特定设置
from src.config.config import global_config
await _apply_session_settings(session, global_config.database.database_type)
yield session
@asynccontextmanager
async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话(直接模式,不使用连接池)
用于特殊场景,如需要完全独立的连接时。
一般情况下应使用 get_db_session()。
事务管理说明:
- 正常退出时自动提交事务
- 发生异常时自动回滚事务
- 如果用户代码已手动调用 commit/rollback再次调用是安全的
- 适用于所有数据库类型SQLite, MySQL, PostgreSQL
Yields:
AsyncSession: SQLAlchemy异步会话对象
"""
session_factory = await get_session_factory()
async with session_factory() as session:
try:
# 应用数据库特定设置
from src.config.config import global_config
await _apply_session_settings(session, global_config.database.database_type)
yield session
# 正常退出时提交事务
# 这对所有数据库都很重要,因为 SQLAlchemy 默认不是 autocommit 模式
# 检查事务是否活动,避免在已回滚的事务上提交
if session.is_active:
await session.commit()
except Exception:
# 检查是否需要回滚(事务是否活动)
if session.is_active:
await session.rollback()
raise
finally:
await session.close()
async def reset_session_factory():
"""重置会话工厂(用于测试)"""
global _session_factory
_session_factory = None