rufffffff
This commit is contained in:
@@ -19,7 +19,6 @@ from .models import (
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Expression,
|
||||
get_string_field,
|
||||
GraphEdges,
|
||||
GraphNodes,
|
||||
ImageDescriptions,
|
||||
@@ -37,30 +36,17 @@ from .models import (
|
||||
UserPermissions,
|
||||
UserRelationships,
|
||||
Videos,
|
||||
get_string_field,
|
||||
)
|
||||
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
|
||||
|
||||
__all__ = [
|
||||
# Engine
|
||||
"get_engine",
|
||||
"close_engine",
|
||||
"get_engine_info",
|
||||
# Session
|
||||
"get_db_session",
|
||||
"get_db_session_direct",
|
||||
"get_session_factory",
|
||||
"reset_session_factory",
|
||||
# Migration
|
||||
"check_and_migrate_database",
|
||||
"create_all_tables",
|
||||
"drop_all_tables",
|
||||
# Models - Base
|
||||
"Base",
|
||||
"get_string_field",
|
||||
# Models - Tables (按字母顺序)
|
||||
"ActionRecords",
|
||||
"AntiInjectionStats",
|
||||
"BanUser",
|
||||
# Models - Base
|
||||
"Base",
|
||||
"BotPersonalityInterests",
|
||||
"CacheEntries",
|
||||
"ChatStreams",
|
||||
@@ -83,4 +69,18 @@ __all__ = [
|
||||
"UserPermissions",
|
||||
"UserRelationships",
|
||||
"Videos",
|
||||
# Migration
|
||||
"check_and_migrate_database",
|
||||
"close_engine",
|
||||
"create_all_tables",
|
||||
"drop_all_tables",
|
||||
# Session
|
||||
"get_db_session",
|
||||
"get_db_session_direct",
|
||||
# Engine
|
||||
"get_engine",
|
||||
"get_engine_info",
|
||||
"get_session_factory",
|
||||
"get_string_field",
|
||||
"reset_session_factory",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from sqlalchemy import text
|
||||
@@ -18,49 +17,49 @@ from ..utils.exceptions import DatabaseInitializationError
|
||||
logger = get_logger("database.engine")
|
||||
|
||||
# 全局引擎实例
|
||||
_engine: Optional[AsyncEngine] = None
|
||||
_engine_lock: Optional[asyncio.Lock] = None
|
||||
_engine: AsyncEngine | None = None
|
||||
_engine_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
async def get_engine() -> AsyncEngine:
|
||||
"""获取全局数据库引擎(单例模式)
|
||||
|
||||
|
||||
Returns:
|
||||
AsyncEngine: SQLAlchemy异步引擎
|
||||
|
||||
|
||||
Raises:
|
||||
DatabaseInitializationError: 引擎初始化失败
|
||||
"""
|
||||
global _engine, _engine_lock
|
||||
|
||||
|
||||
# 快速路径:引擎已初始化
|
||||
if _engine is not None:
|
||||
return _engine
|
||||
|
||||
|
||||
# 延迟创建锁(避免在导入时创建)
|
||||
if _engine_lock is None:
|
||||
_engine_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# 使用锁保护初始化过程
|
||||
async with _engine_lock:
|
||||
# 双重检查锁定模式
|
||||
if _engine is not None:
|
||||
return _engine
|
||||
|
||||
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
config = global_config.database
|
||||
db_type = config.database_type
|
||||
|
||||
|
||||
logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
|
||||
|
||||
|
||||
# 构建数据库URL和引擎参数
|
||||
if db_type == "mysql":
|
||||
# MySQL配置
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
|
||||
if config.mysql_unix_socket:
|
||||
# Unix socket连接
|
||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||
@@ -76,7 +75,7 @@ async def get_engine() -> AsyncEngine:
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
@@ -91,11 +90,11 @@ async def get_engine() -> AsyncEngine:
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
logger.info(
|
||||
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# SQLite配置
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
@@ -103,12 +102,12 @@ async def get_engine() -> AsyncEngine:
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
|
||||
url = f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
@@ -117,19 +116,19 @@ async def get_engine() -> AsyncEngine:
|
||||
"timeout": 60,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"SQLite配置: {db_path}")
|
||||
|
||||
|
||||
# 创建异步引擎
|
||||
_engine = create_async_engine(url, **engine_kwargs)
|
||||
|
||||
|
||||
# SQLite特定优化
|
||||
if db_type == "sqlite":
|
||||
await _enable_sqlite_optimizations(_engine)
|
||||
|
||||
|
||||
logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
|
||||
return _engine
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True)
|
||||
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
|
||||
@@ -137,11 +136,11 @@ async def get_engine() -> AsyncEngine:
|
||||
|
||||
async def close_engine():
|
||||
"""关闭数据库引擎
|
||||
|
||||
|
||||
释放所有连接池资源
|
||||
"""
|
||||
global _engine
|
||||
|
||||
|
||||
if _engine is not None:
|
||||
logger.info("正在关闭数据库引擎...")
|
||||
await _engine.dispose()
|
||||
@@ -151,13 +150,13 @@ async def close_engine():
|
||||
|
||||
async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
"""启用SQLite性能优化
|
||||
|
||||
|
||||
优化项:
|
||||
- WAL模式:提高并发性能
|
||||
- NORMAL同步:平衡性能和安全性
|
||||
- 启用外键约束
|
||||
- 设置busy_timeout:避免锁定错误
|
||||
|
||||
|
||||
Args:
|
||||
engine: SQLAlchemy异步引擎
|
||||
"""
|
||||
@@ -175,22 +174,22 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
await conn.execute(text("PRAGMA cache_size = -10000"))
|
||||
# 临时存储使用内存
|
||||
await conn.execute(text("PRAGMA temp_store = MEMORY"))
|
||||
|
||||
|
||||
logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
|
||||
|
||||
|
||||
async def get_engine_info() -> dict:
|
||||
"""获取引擎信息(用于监控和调试)
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 引擎信息字典
|
||||
"""
|
||||
try:
|
||||
engine = await get_engine()
|
||||
|
||||
|
||||
info = {
|
||||
"name": engine.name,
|
||||
"driver": engine.driver,
|
||||
@@ -199,9 +198,9 @@ async def get_engine_info() -> dict:
|
||||
"pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(),
|
||||
"pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(),
|
||||
}
|
||||
|
||||
|
||||
return info
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取引擎信息失败: {e}")
|
||||
return {}
|
||||
|
||||
@@ -20,15 +20,15 @@ logger = get_logger("db_migration")
|
||||
|
||||
async def check_and_migrate_database(existing_engine=None):
|
||||
"""异步检查数据库结构并自动迁移
|
||||
|
||||
|
||||
自动执行以下操作:
|
||||
- 创建不存在的表
|
||||
- 为现有表添加缺失的列
|
||||
- 为现有表创建缺失的索引
|
||||
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎
|
||||
|
||||
|
||||
Note:
|
||||
此函数是幂等的,可以安全地多次调用
|
||||
"""
|
||||
@@ -65,7 +65,7 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
for table in tables_to_create:
|
||||
logger.info(f"表 '{table.name}' 创建成功。")
|
||||
db_table_names.add(table.name) # 将新创建的表添加到集合中
|
||||
|
||||
|
||||
# 提交表创建事务
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
@@ -191,40 +191,40 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
|
||||
async def create_all_tables(existing_engine=None):
|
||||
"""创建所有表(不进行迁移检查)
|
||||
|
||||
|
||||
直接创建所有在 Base.metadata 中定义的表。
|
||||
如果表已存在,将被跳过。
|
||||
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎
|
||||
|
||||
|
||||
Note:
|
||||
生产环境建议使用 check_and_migrate_database()
|
||||
"""
|
||||
logger.info("正在创建所有数据库表...")
|
||||
engine = existing_engine if existing_engine is not None else await get_engine()
|
||||
|
||||
|
||||
async with engine.begin() as connection:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
logger.info("数据库表创建完成。")
|
||||
|
||||
|
||||
async def drop_all_tables(existing_engine=None):
|
||||
"""删除所有表(危险操作!)
|
||||
|
||||
|
||||
删除所有在 Base.metadata 中定义的表。
|
||||
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎
|
||||
|
||||
|
||||
Warning:
|
||||
此操作将删除所有数据,不可恢复!仅用于测试环境!
|
||||
"""
|
||||
logger.warning("⚠️ 正在删除所有数据库表...")
|
||||
engine = existing_engine if existing_engine is not None else await get_engine()
|
||||
|
||||
|
||||
async with engine.begin() as connection:
|
||||
await connection.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
logger.warning("所有数据库表已删除。")
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
@@ -18,38 +17,38 @@ from .engine import get_engine
|
||||
logger = get_logger("database.session")
|
||||
|
||||
# 全局会话工厂
|
||||
_session_factory: Optional[async_sessionmaker] = None
|
||||
_factory_lock: Optional[asyncio.Lock] = None
|
||||
_session_factory: async_sessionmaker | None = None
|
||||
_factory_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
async def get_session_factory() -> async_sessionmaker:
|
||||
"""获取会话工厂(单例模式)
|
||||
|
||||
|
||||
Returns:
|
||||
async_sessionmaker: SQLAlchemy异步会话工厂
|
||||
"""
|
||||
global _session_factory, _factory_lock
|
||||
|
||||
|
||||
# 快速路径
|
||||
if _session_factory is not None:
|
||||
return _session_factory
|
||||
|
||||
|
||||
# 延迟创建锁
|
||||
if _factory_lock is None:
|
||||
_factory_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async with _factory_lock:
|
||||
# 双重检查
|
||||
if _session_factory is not None:
|
||||
return _session_factory
|
||||
|
||||
|
||||
engine = await get_engine()
|
||||
_session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False, # 避免在commit后访问属性时重新查询
|
||||
)
|
||||
|
||||
|
||||
logger.debug("会话工厂已创建")
|
||||
return _session_factory
|
||||
|
||||
@@ -57,28 +56,28 @@ async def get_session_factory() -> async_sessionmaker:
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话上下文管理器
|
||||
|
||||
|
||||
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
|
||||
|
||||
|
||||
使用示例:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(User))
|
||||
users = result.scalars().all()
|
||||
|
||||
|
||||
Yields:
|
||||
AsyncSession: SQLAlchemy异步会话对象
|
||||
"""
|
||||
# 延迟导入避免循环依赖
|
||||
from ..optimization.connection_pool import get_connection_pool_manager
|
||||
|
||||
|
||||
session_factory = await get_session_factory()
|
||||
pool_manager = get_connection_pool_manager()
|
||||
|
||||
|
||||
# 使用连接池管理器(透明复用连接)
|
||||
async with pool_manager.get_session(session_factory) as session:
|
||||
# 为SQLite设置特定的PRAGMA
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
if global_config.database.database_type == "sqlite":
|
||||
try:
|
||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||
@@ -86,22 +85,22 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
except Exception:
|
||||
# 复用连接时PRAGMA可能已设置,忽略错误
|
||||
pass
|
||||
|
||||
|
||||
yield session
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话(直接模式,不使用连接池)
|
||||
|
||||
|
||||
用于特殊场景,如需要完全独立的连接时。
|
||||
一般情况下应使用 get_db_session()。
|
||||
|
||||
|
||||
Yields:
|
||||
AsyncSession: SQLAlchemy异步会话对象
|
||||
"""
|
||||
session_factory = await get_session_factory()
|
||||
|
||||
|
||||
async with session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
|
||||
Reference in New Issue
Block a user