数据库异步支持
仅仅支持还有107处待迁移
This commit is contained in:
@@ -3,16 +3,18 @@
|
||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
import time
|
||||
from typing import Iterator, Optional, Any, Dict
|
||||
from typing import Iterator, Optional, Any, Dict, AsyncGenerator
|
||||
from src.common.logger import get_logger
|
||||
from contextlib import contextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("sqlalchemy_models")
|
||||
|
||||
@@ -575,14 +577,14 @@ def get_database_url():
|
||||
# 使用Unix socket连接
|
||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||
return (
|
||||
f"mysql+pymysql://{encoded_user}:{encoded_password}"
|
||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||
f"@/{config.mysql_database}"
|
||||
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
||||
)
|
||||
else:
|
||||
# 使用标准TCP连接
|
||||
return (
|
||||
f"mysql+pymysql://{encoded_user}:{encoded_password}"
|
||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
@@ -597,11 +599,11 @@ def get_database_url():
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
return f"sqlite:///{db_path}"
|
||||
return f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""初始化数据库引擎和会话"""
|
||||
async def initialize_database():
|
||||
"""初始化异步数据库引擎和会话"""
|
||||
global _engine, _SessionLocal
|
||||
|
||||
if _engine is not None:
|
||||
@@ -654,41 +656,40 @@ def initialize_database():
|
||||
}
|
||||
)
|
||||
|
||||
_engine = create_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
|
||||
_engine = create_async_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
# 调用新的迁移函数,它会处理表的创建和列的添加
|
||||
from src.common.database.db_migration import check_and_migrate_database
|
||||
|
||||
check_and_migrate_database()
|
||||
await check_and_migrate_database()
|
||||
|
||||
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
|
||||
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
|
||||
return _engine, _SessionLocal
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session() -> Iterator[Session]:
|
||||
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
|
||||
session: Optional[Session] = None
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""异步数据库会话上下文管理器"""
|
||||
session: Optional[AsyncSession] = None
|
||||
try:
|
||||
engine, SessionLocal = initialize_database()
|
||||
engine, SessionLocal = await initialize_database()
|
||||
if not SessionLocal:
|
||||
raise RuntimeError("Database session not initialized")
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
# session.commit()
|
||||
except Exception:
|
||||
if session:
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""获取数据库引擎"""
|
||||
engine, _ = initialize_database()
|
||||
async def get_engine():
|
||||
"""获取异步数据库引擎"""
|
||||
engine, _ = await initialize_database()
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user