diff --git a/src/common/database/config/__init__.py b/src/common/database/config/__init__.py index b23071e93..903651d74 100644 --- a/src/common/database/config/__init__.py +++ b/src/common/database/config/__init__.py @@ -1,14 +1,11 @@ """数据库配置层 职责: -- 数据库配置管理 +- 数据库配置现已集成到全局配置中 +- 通过 src.config.config.global_config.database 访问 - 优化参数配置 + +注意:此模块已废弃,配置已迁移到 global_config """ -from .database_config import DatabaseConfig, get_database_config, reset_database_config - -__all__ = [ - "DatabaseConfig", - "get_database_config", - "reset_database_config", -] +__all__ = [] diff --git a/src/common/database/config/database_config.py b/src/common/database/config/old/database_config.py similarity index 100% rename from src/common/database/config/database_config.py rename to src/common/database/config/old/database_config.py diff --git a/src/common/database/core/engine.py b/src/common/database/core/engine.py index 6201f60fd..4b8e0cc7a 100644 --- a/src/common/database/core/engine.py +++ b/src/common/database/core/engine.py @@ -4,14 +4,15 @@ """ import asyncio +import os from typing import Optional +from urllib.parse import quote_plus from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from src.common.logger import get_logger -from ..config.database_config import get_database_config from ..utils.exceptions import DatabaseInitializationError logger = get_logger("database.engine") @@ -47,21 +48,86 @@ async def get_engine() -> AsyncEngine: return _engine try: - config = get_database_config() + from src.config.config import global_config - logger.info(f"正在初始化 {config.db_type.upper()} 数据库引擎...") + config = global_config.database + db_type = config.database_type + + logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...") + + # 构建数据库URL和引擎参数 + if db_type == "mysql": + # MySQL配置 + encoded_user = quote_plus(config.mysql_user) + encoded_password = quote_plus(config.mysql_password) + + if config.mysql_unix_socket: + # Unix socket连接 + encoded_socket = quote_plus(config.mysql_unix_socket) + url = ( + f"mysql+aiomysql://{encoded_user}:{encoded_password}" + f"@/{config.mysql_database}" + f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" + ) + else: + # TCP连接 + url = ( + f"mysql+aiomysql://{encoded_user}:{encoded_password}" + f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" + f"?charset={config.mysql_charset}" + ) + + engine_kwargs = { + "echo": False, + "future": True, + "pool_size": config.connection_pool_size, + "max_overflow": config.connection_pool_size * 2, + "pool_timeout": config.connection_timeout, + "pool_recycle": 3600, + "pool_pre_ping": True, + "connect_args": { + "autocommit": config.mysql_autocommit, + "charset": config.mysql_charset, + "connect_timeout": config.connection_timeout, + }, + } + + logger.info( + f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" + ) + + else: + # SQLite配置 + if not os.path.isabs(config.sqlite_path): + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + db_path = os.path.join(ROOT_PATH, config.sqlite_path) + else: + db_path = config.sqlite_path + + # 确保数据库目录存在 + os.makedirs(os.path.dirname(db_path), exist_ok=True) + + url = f"sqlite+aiosqlite:///{db_path}" + + engine_kwargs = { + "echo": False, + "future": True, + "connect_args": { + "check_same_thread": False, + "timeout": 60, + }, + } + + logger.info(f"SQLite配置: {db_path}") # 创建异步引擎 - _engine = create_async_engine( - config.url, - **config.engine_kwargs - ) + _engine = create_async_engine(url, **engine_kwargs) # SQLite特定优化 - if config.db_type == "sqlite": + if db_type == "sqlite": await _enable_sqlite_optimizations(_engine) - logger.info(f"✅ {config.db_type.upper()} 数据库引擎初始化成功") + logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功") return _engine except Exception as e: diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py index 4124cdf07..c269ba9c4 100644 --- a/src/common/database/core/session.py +++ b/src/common/database/core/session.py @@ -13,7 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from src.common.logger import get_logger -from ..config.database_config import get_database_config from .engine import get_engine logger = get_logger("database.session") @@ -78,8 +77,9 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: # 使用连接池管理器(透明复用连接) async with pool_manager.get_session(session_factory) as session: # 为SQLite设置特定的PRAGMA - config = get_database_config() - if config.db_type == "sqlite": + from src.config.config import global_config + + if global_config.database.database_type == "sqlite": try: await session.execute(text("PRAGMA busy_timeout = 60000")) await session.execute(text("PRAGMA foreign_keys = ON"))