数据库异步支持

仅仅支持还有107处待迁移
This commit is contained in:
雅诺狐
2025-09-19 20:20:20 +08:00
parent 5a0a63464a
commit 4dbc651d74
6 changed files with 169 additions and 136 deletions

View File

@@ -3,16 +3,18 @@
替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力
"""
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.pool import QueuePool
import os
import datetime
import time
from typing import Iterator, Optional, Any, Dict
from typing import Iterator, Optional, Any, Dict, AsyncGenerator
from src.common.logger import get_logger
from contextlib import contextmanager
from contextlib import asynccontextmanager
import asyncio
logger = get_logger("sqlalchemy_models")
@@ -575,14 +577,14 @@ def get_database_url():
# 使用Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket)
return (
f"mysql+pymysql://{encoded_user}:{encoded_password}"
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@/{config.mysql_database}"
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
)
else:
# 使用标准TCP连接
return (
f"mysql+pymysql://{encoded_user}:{encoded_password}"
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
@@ -597,11 +599,11 @@ def get_database_url():
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
return f"sqlite:///{db_path}"
return f"sqlite+aiosqlite:///{db_path}"
def initialize_database():
"""初始化数据库引擎和会话"""
async def initialize_database():
"""初始化异步数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None:
@@ -654,41 +656,40 @@ def initialize_database():
}
)
_engine = create_engine(database_url, **engine_kwargs)
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
_engine = create_async_engine(database_url, **engine_kwargs)
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 调用新的迁移函数,它会处理表的创建和列的添加
from src.common.database.db_migration import check_and_migrate_database
check_and_migrate_database()
await check_and_migrate_database()
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal
@contextmanager
def get_db_session() -> Iterator[Session]:
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session: Optional[Session] = None
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""异步数据库会话上下文管理器"""
session: Optional[AsyncSession] = None
try:
engine, SessionLocal = initialize_database()
engine, SessionLocal = await initialize_database()
if not SessionLocal:
raise RuntimeError("Database session not initialized")
session = SessionLocal()
yield session
# session.commit()
except Exception:
if session:
session.rollback()
await session.rollback()
raise
finally:
if session:
session.close()
await session.close()
def get_engine():
"""获取数据库引擎"""
engine, _ = initialize_database()
async def get_engine():
"""获取异步数据库引擎"""
engine, _ = await initialize_database()
return engine