依旧修pyright喵~
This commit is contained in:
@@ -46,6 +46,7 @@ async def get_engine() -> AsyncEngine:
|
||||
if _engine_lock is None:
|
||||
_engine_lock = asyncio.Lock()
|
||||
|
||||
assert _engine_lock is not None
|
||||
# 使用锁保护初始化过程
|
||||
async with _engine_lock:
|
||||
# 双重检查锁定模式
|
||||
@@ -55,6 +56,7 @@ async def get_engine() -> AsyncEngine:
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
config = global_config.database
|
||||
db_type = config.database_type
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
db_type = global_config.database.database_type
|
||||
|
||||
# MySQL 索引需要指定长度的 VARCHAR
|
||||
|
||||
@@ -75,6 +75,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
|
||||
# 可以设置 schema 搜索路径等
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
schema = global_config.database.postgresql_schema
|
||||
if schema and schema != "public":
|
||||
await session.execute(text(f"SET search_path TO {schema}"))
|
||||
@@ -114,6 +115,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
# 获取数据库类型并应用特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
yield session
|
||||
@@ -142,6 +144,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
||||
# 应用数据库特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
yield session
|
||||
|
||||
Reference in New Issue
Block a user