feat: 更新机器人配置并添加数据库迁移脚本
- 将bot_config_template.toml中的版本升级至7.9.0 - 增强数据库配置选项以支持PostgreSQL - 引入一个新脚本,用于在SQLite、MySQL和PostgreSQL之间迁移数据 - 实现一个方言适配器,用于处理特定于数据库的行为和配置
This commit is contained in:
@@ -5,8 +5,22 @@
|
||||
- 会话管理
|
||||
- 模型定义
|
||||
- 数据库迁移
|
||||
- 方言适配
|
||||
|
||||
支持的数据库:
|
||||
- SQLite (默认)
|
||||
- MySQL
|
||||
- PostgreSQL
|
||||
"""
|
||||
|
||||
from .dialect_adapter import (
|
||||
DatabaseDialect,
|
||||
DialectAdapter,
|
||||
DialectConfig,
|
||||
get_dialect_adapter,
|
||||
get_indexed_string_field,
|
||||
get_text_field,
|
||||
)
|
||||
from .engine import close_engine, get_engine, get_engine_info
|
||||
from .migration import check_and_migrate_database, create_all_tables, drop_all_tables
|
||||
from .models import (
|
||||
@@ -50,6 +64,10 @@ __all__ = [
|
||||
"BotPersonalityInterests",
|
||||
"CacheEntries",
|
||||
"ChatStreams",
|
||||
# Dialect Adapter
|
||||
"DatabaseDialect",
|
||||
"DialectAdapter",
|
||||
"DialectConfig",
|
||||
"Emoji",
|
||||
"Expression",
|
||||
"GraphEdges",
|
||||
@@ -77,10 +95,13 @@ __all__ = [
|
||||
# Session
|
||||
"get_db_session",
|
||||
"get_db_session_direct",
|
||||
"get_dialect_adapter",
|
||||
# Engine
|
||||
"get_engine",
|
||||
"get_engine_info",
|
||||
"get_indexed_string_field",
|
||||
"get_session_factory",
|
||||
"get_string_field",
|
||||
"get_text_field",
|
||||
"reset_session_factory",
|
||||
]
|
||||
|
||||
230
src/common/database/core/dialect_adapter.py
Normal file
230
src/common/database/core/dialect_adapter.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""数据库方言适配器
|
||||
|
||||
提供跨数据库兼容性支持,处理不同数据库之间的差异:
|
||||
- SQLite: 轻量级本地数据库
|
||||
- MySQL: 高性能关系型数据库
|
||||
- PostgreSQL: 功能丰富的开源数据库
|
||||
|
||||
主要职责:
|
||||
1. 提供数据库特定的类型映射
|
||||
2. 处理方言特定的查询语法
|
||||
3. 提供数据库特定的优化配置
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
|
||||
class DatabaseDialect(Enum):
|
||||
"""数据库方言枚举"""
|
||||
|
||||
SQLITE = "sqlite"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialectConfig:
|
||||
"""方言配置"""
|
||||
|
||||
dialect: DatabaseDialect
|
||||
# 连接验证查询
|
||||
ping_query: str
|
||||
# 是否支持 RETURNING 子句
|
||||
supports_returning: bool
|
||||
# 是否支持原生 JSON 类型
|
||||
supports_native_json: bool
|
||||
# 是否支持数组类型
|
||||
supports_arrays: bool
|
||||
# 是否需要指定字符串长度用于索引
|
||||
requires_length_for_index: bool
|
||||
# 默认字符串长度(用于索引列)
|
||||
default_string_length: int
|
||||
# 事务隔离级别
|
||||
isolation_level: str
|
||||
# 额外的引擎参数
|
||||
engine_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# 预定义的方言配置
|
||||
DIALECT_CONFIGS: dict[DatabaseDialect, DialectConfig] = {
|
||||
DatabaseDialect.SQLITE: DialectConfig(
|
||||
dialect=DatabaseDialect.SQLITE,
|
||||
ping_query="SELECT 1",
|
||||
supports_returning=True, # SQLite 3.35+ 支持
|
||||
supports_native_json=False,
|
||||
supports_arrays=False,
|
||||
requires_length_for_index=False,
|
||||
default_string_length=255,
|
||||
isolation_level="SERIALIZABLE",
|
||||
engine_kwargs={
|
||||
"connect_args": {
|
||||
"check_same_thread": False,
|
||||
"timeout": 60,
|
||||
}
|
||||
},
|
||||
),
|
||||
DatabaseDialect.MYSQL: DialectConfig(
|
||||
dialect=DatabaseDialect.MYSQL,
|
||||
ping_query="SELECT 1",
|
||||
supports_returning=False, # MySQL 8.0.21+ 有限支持
|
||||
supports_native_json=True, # MySQL 5.7+
|
||||
supports_arrays=False,
|
||||
requires_length_for_index=True, # MySQL 索引需要指定长度
|
||||
default_string_length=255,
|
||||
isolation_level="READ COMMITTED",
|
||||
engine_kwargs={
|
||||
"pool_pre_ping": True,
|
||||
"pool_recycle": 3600,
|
||||
},
|
||||
),
|
||||
DatabaseDialect.POSTGRESQL: DialectConfig(
|
||||
dialect=DatabaseDialect.POSTGRESQL,
|
||||
ping_query="SELECT 1",
|
||||
supports_returning=True,
|
||||
supports_native_json=True,
|
||||
supports_arrays=True,
|
||||
requires_length_for_index=False,
|
||||
default_string_length=255,
|
||||
isolation_level="READ COMMITTED",
|
||||
engine_kwargs={
|
||||
"pool_pre_ping": True,
|
||||
"pool_recycle": 3600,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class DialectAdapter:
|
||||
"""数据库方言适配器
|
||||
|
||||
根据当前配置的数据库类型,提供相应的类型映射和查询支持
|
||||
"""
|
||||
|
||||
_current_dialect: DatabaseDialect | None = None
|
||||
_config: DialectConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, db_type: str) -> None:
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
db_type: 数据库类型字符串 ("sqlite", "mysql", "postgresql")
|
||||
"""
|
||||
try:
|
||||
cls._current_dialect = DatabaseDialect(db_type.lower())
|
||||
cls._config = DIALECT_CONFIGS[cls._current_dialect]
|
||||
except ValueError:
|
||||
raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql, postgresql")
|
||||
|
||||
@classmethod
|
||||
def get_dialect(cls) -> DatabaseDialect:
|
||||
"""获取当前数据库方言"""
|
||||
if cls._current_dialect is None:
|
||||
# 延迟初始化:从配置获取
|
||||
from src.config.config import global_config
|
||||
|
||||
if global_config is None:
|
||||
raise RuntimeError("配置尚未初始化,无法获取数据库方言")
|
||||
cls.initialize(global_config.database.database_type)
|
||||
return cls._current_dialect # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_config(cls) -> DialectConfig:
|
||||
"""获取当前方言配置"""
|
||||
if cls._config is None:
|
||||
cls.get_dialect() # 触发初始化
|
||||
return cls._config # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_string_type(cls, max_length: int = 255, indexed: bool = False) -> TypeEngine:
|
||||
"""获取适合当前数据库的字符串类型
|
||||
|
||||
Args:
|
||||
max_length: 最大长度
|
||||
indexed: 是否用于索引
|
||||
|
||||
Returns:
|
||||
SQLAlchemy 类型
|
||||
"""
|
||||
config = cls.get_config()
|
||||
|
||||
# MySQL 索引列需要指定长度
|
||||
if config.requires_length_for_index and indexed:
|
||||
return String(max_length)
|
||||
|
||||
# SQLite 和 PostgreSQL 可以使用 Text
|
||||
if config.dialect in (DatabaseDialect.SQLITE, DatabaseDialect.POSTGRESQL):
|
||||
return Text() if not indexed else String(max_length)
|
||||
|
||||
# MySQL 使用 VARCHAR
|
||||
return String(max_length)
|
||||
|
||||
@classmethod
|
||||
def get_ping_query(cls) -> str:
|
||||
"""获取连接验证查询"""
|
||||
return cls.get_config().ping_query
|
||||
|
||||
@classmethod
|
||||
def supports_returning(cls) -> bool:
|
||||
"""是否支持 RETURNING 子句"""
|
||||
return cls.get_config().supports_returning
|
||||
|
||||
@classmethod
|
||||
def supports_native_json(cls) -> bool:
|
||||
"""是否支持原生 JSON 类型"""
|
||||
return cls.get_config().supports_native_json
|
||||
|
||||
@classmethod
|
||||
def get_engine_kwargs(cls) -> dict[str, Any]:
|
||||
"""获取引擎额外参数"""
|
||||
return cls.get_config().engine_kwargs.copy()
|
||||
|
||||
@classmethod
|
||||
def is_sqlite(cls) -> bool:
|
||||
"""是否为 SQLite"""
|
||||
return cls.get_dialect() == DatabaseDialect.SQLITE
|
||||
|
||||
@classmethod
|
||||
def is_mysql(cls) -> bool:
|
||||
"""是否为 MySQL"""
|
||||
return cls.get_dialect() == DatabaseDialect.MYSQL
|
||||
|
||||
@classmethod
|
||||
def is_postgresql(cls) -> bool:
|
||||
"""是否为 PostgreSQL"""
|
||||
return cls.get_dialect() == DatabaseDialect.POSTGRESQL
|
||||
|
||||
|
||||
def get_dialect_adapter() -> type[DialectAdapter]:
|
||||
"""获取方言适配器类"""
|
||||
return DialectAdapter
|
||||
|
||||
|
||||
def get_indexed_string_field(max_length: int = 255) -> TypeEngine:
|
||||
"""获取用于索引的字符串字段类型
|
||||
|
||||
这是一个便捷函数,用于在模型定义中获取适合当前数据库的字符串类型
|
||||
|
||||
Args:
|
||||
max_length: 最大长度(对于 MySQL 是必需的)
|
||||
|
||||
Returns:
|
||||
SQLAlchemy 类型
|
||||
"""
|
||||
return DialectAdapter.get_string_type(max_length, indexed=True)
|
||||
|
||||
|
||||
def get_text_field() -> TypeEngine:
|
||||
"""获取文本字段类型
|
||||
|
||||
用于不需要索引的大文本字段
|
||||
|
||||
Returns:
|
||||
SQLAlchemy Text 类型
|
||||
"""
|
||||
return Text()
|
||||
@@ -1,6 +1,11 @@
|
||||
"""数据库引擎管理
|
||||
|
||||
单一职责:创建和管理SQLAlchemy异步引擎
|
||||
|
||||
支持的数据库类型:
|
||||
- SQLite: 轻量级本地数据库,使用 aiosqlite 驱动
|
||||
- MySQL: 高性能关系型数据库,使用 aiomysql 驱动
|
||||
- PostgreSQL: 功能丰富的开源数据库,使用 asyncpg 驱动
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -13,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..utils.exceptions import DatabaseInitializationError
|
||||
from .dialect_adapter import DialectAdapter
|
||||
|
||||
logger = get_logger("database.engine")
|
||||
|
||||
@@ -52,79 +58,27 @@ async def get_engine() -> AsyncEngine:
|
||||
config = global_config.database
|
||||
db_type = config.database_type
|
||||
|
||||
# 初始化方言适配器
|
||||
DialectAdapter.initialize(db_type)
|
||||
|
||||
logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
|
||||
|
||||
# 构建数据库URL和引擎参数
|
||||
# 根据数据库类型构建URL和引擎参数
|
||||
if db_type == "mysql":
|
||||
# MySQL配置
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
if config.mysql_unix_socket:
|
||||
# Unix socket连接
|
||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||
url = (
|
||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||
f"@/{config.mysql_database}"
|
||||
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
||||
)
|
||||
else:
|
||||
# TCP连接
|
||||
url = (
|
||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
"pool_size": config.connection_pool_size,
|
||||
"max_overflow": config.connection_pool_size * 2,
|
||||
"pool_timeout": config.connection_timeout,
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
"connect_args": {
|
||||
"autocommit": config.mysql_autocommit,
|
||||
"charset": config.mysql_charset,
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
)
|
||||
|
||||
url, engine_kwargs = _build_mysql_config(config)
|
||||
elif db_type == "postgresql":
|
||||
url, engine_kwargs = _build_postgresql_config(config)
|
||||
else:
|
||||
# SQLite配置
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
url = f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
"connect_args": {
|
||||
"check_same_thread": False,
|
||||
"timeout": 60,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"SQLite配置: {db_path}")
|
||||
url, engine_kwargs = _build_sqlite_config(config)
|
||||
|
||||
# 创建异步引擎
|
||||
_engine = create_async_engine(url, **engine_kwargs)
|
||||
|
||||
# SQLite特定优化
|
||||
# 数据库特定优化
|
||||
if db_type == "sqlite":
|
||||
await _enable_sqlite_optimizations(_engine)
|
||||
elif db_type == "postgresql":
|
||||
await _enable_postgresql_optimizations(_engine)
|
||||
|
||||
logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
|
||||
return _engine
|
||||
@@ -134,6 +88,141 @@ async def get_engine() -> AsyncEngine:
|
||||
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
|
||||
|
||||
|
||||
def _build_sqlite_config(config) -> tuple[str, dict]:
|
||||
"""构建 SQLite 配置
|
||||
|
||||
Args:
|
||||
config: 数据库配置对象
|
||||
|
||||
Returns:
|
||||
(url, engine_kwargs) 元组
|
||||
"""
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
url = f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
"connect_args": {
|
||||
"check_same_thread": False,
|
||||
"timeout": 60,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"SQLite配置: {db_path}")
|
||||
return url, engine_kwargs
|
||||
|
||||
|
||||
def _build_mysql_config(config) -> tuple[str, dict]:
|
||||
"""构建 MySQL 配置
|
||||
|
||||
Args:
|
||||
config: 数据库配置对象
|
||||
|
||||
Returns:
|
||||
(url, engine_kwargs) 元组
|
||||
"""
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
if config.mysql_unix_socket:
|
||||
# Unix socket连接
|
||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||
url = (
|
||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||
f"@/{config.mysql_database}"
|
||||
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
||||
)
|
||||
else:
|
||||
# TCP连接
|
||||
url = (
|
||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
"pool_size": config.connection_pool_size,
|
||||
"max_overflow": config.connection_pool_size * 2,
|
||||
"pool_timeout": config.connection_timeout,
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
"connect_args": {
|
||||
"autocommit": config.mysql_autocommit,
|
||||
"charset": config.mysql_charset,
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
)
|
||||
return url, engine_kwargs
|
||||
|
||||
|
||||
def _build_postgresql_config(config) -> tuple[str, dict]:
|
||||
"""构建 PostgreSQL 配置
|
||||
|
||||
Args:
|
||||
config: 数据库配置对象
|
||||
|
||||
Returns:
|
||||
(url, engine_kwargs) 元组
|
||||
"""
|
||||
encoded_user = quote_plus(config.postgresql_user)
|
||||
encoded_password = quote_plus(config.postgresql_password)
|
||||
|
||||
# 构建基本 URL
|
||||
url = (
|
||||
f"postgresql+asyncpg://{encoded_user}:{encoded_password}"
|
||||
f"@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
|
||||
)
|
||||
|
||||
# SSL 配置
|
||||
connect_args = {}
|
||||
if config.postgresql_ssl_mode != "disable":
|
||||
ssl_config = {"ssl": config.postgresql_ssl_mode}
|
||||
if config.postgresql_ssl_ca:
|
||||
ssl_config["ssl_ca"] = config.postgresql_ssl_ca
|
||||
if config.postgresql_ssl_cert:
|
||||
ssl_config["ssl_cert"] = config.postgresql_ssl_cert
|
||||
if config.postgresql_ssl_key:
|
||||
ssl_config["ssl_key"] = config.postgresql_ssl_key
|
||||
connect_args.update(ssl_config)
|
||||
|
||||
# 设置 schema(如果不是 public)
|
||||
if config.postgresql_schema and config.postgresql_schema != "public":
|
||||
connect_args["server_settings"] = {"search_path": config.postgresql_schema}
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
"pool_size": config.connection_pool_size,
|
||||
"max_overflow": config.connection_pool_size * 2,
|
||||
"pool_timeout": config.connection_timeout,
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
|
||||
if connect_args:
|
||||
engine_kwargs["connect_args"] = connect_args
|
||||
|
||||
logger.info(
|
||||
f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
|
||||
)
|
||||
return url, engine_kwargs
|
||||
|
||||
|
||||
async def close_engine():
|
||||
"""关闭数据库引擎
|
||||
|
||||
@@ -181,6 +270,33 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
|
||||
|
||||
|
||||
async def _enable_postgresql_optimizations(engine: AsyncEngine):
|
||||
"""启用PostgreSQL性能优化
|
||||
|
||||
优化项:
|
||||
- 设置合适的 work_mem
|
||||
- 启用 JIT 编译(如果可用)
|
||||
- 设置合适的 statement_timeout
|
||||
|
||||
Args:
|
||||
engine: SQLAlchemy异步引擎
|
||||
"""
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
# 设置会话级别的参数
|
||||
# work_mem: 排序和哈希操作的内存(64MB)
|
||||
await conn.execute(text("SET work_mem = '64MB'"))
|
||||
# 设置语句超时(5分钟)
|
||||
await conn.execute(text("SET statement_timeout = '300000'"))
|
||||
# 启用自动 EXPLAIN(可选,用于调试)
|
||||
# await conn.execute(text("SET auto_explain.log_min_duration = '1000'"))
|
||||
|
||||
logger.info("✅ PostgreSQL性能优化已启用")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置")
|
||||
|
||||
|
||||
async def get_engine_info() -> dict:
|
||||
"""获取引擎信息(用于监控和调试)
|
||||
|
||||
|
||||
@@ -99,12 +99,17 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
|
||||
def add_columns_sync(conn):
|
||||
dialect = conn.dialect
|
||||
compiler = dialect.ddl_compiler(dialect, None)
|
||||
|
||||
|
||||
for column_name in missing_columns:
|
||||
column = table.c[column_name]
|
||||
column_type = compiler.get_column_specification(column)
|
||||
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
|
||||
|
||||
# 获取列类型的 SQL 表示
|
||||
# 使用 compile 方法获取正确的类型字符串
|
||||
type_compiler = dialect.type_compiler(dialect)
|
||||
column_type_sql = column.type.compile(dialect=dialect)
|
||||
|
||||
# 构建 ALTER TABLE 语句
|
||||
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type_sql}"
|
||||
|
||||
if column.default:
|
||||
# 手动处理不同方言的默认值
|
||||
@@ -114,26 +119,18 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
):
|
||||
# SQLite 将布尔值存储为 0 或 1
|
||||
default_value = "1" if default_arg else "0"
|
||||
elif hasattr(compiler, "render_literal_value"):
|
||||
try:
|
||||
# 尝试使用 render_literal_value
|
||||
default_value = compiler.render_literal_value(
|
||||
default_arg, column.type
|
||||
)
|
||||
except AttributeError:
|
||||
# 如果失败,则回退到简单的字符串转换
|
||||
default_value = (
|
||||
f"'{default_arg}'"
|
||||
if isinstance(default_arg, str)
|
||||
else str(default_arg)
|
||||
)
|
||||
elif dialect.name == "mysql" and isinstance(default_arg, bool):
|
||||
# MySQL 也使用 1/0 表示布尔值
|
||||
default_value = "1" if default_arg else "0"
|
||||
elif isinstance(default_arg, bool):
|
||||
# PostgreSQL 使用 TRUE/FALSE
|
||||
default_value = "TRUE" if default_arg else "FALSE"
|
||||
elif isinstance(default_arg, str):
|
||||
default_value = f"'{default_arg}'"
|
||||
elif default_arg is None:
|
||||
default_value = "NULL"
|
||||
else:
|
||||
# 对于没有 render_literal_value 的旧版或特定方言
|
||||
default_value = (
|
||||
f"'{default_arg}'"
|
||||
if isinstance(default_arg, str)
|
||||
else str(default_arg)
|
||||
)
|
||||
default_value = str(default_arg)
|
||||
|
||||
sql += f" DEFAULT {default_value}"
|
||||
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。
|
||||
引擎和会话管理已移至core/engine.py和core/session.py。
|
||||
|
||||
支持的数据库类型:
|
||||
- SQLite: 使用 Text 类型
|
||||
- MySQL: 使用 VARCHAR(max_length) 用于索引字段
|
||||
- PostgreSQL: 使用 Text 类型(PostgreSQL 的 Text 类型性能与 VARCHAR 相当)
|
||||
|
||||
所有模型使用统一的类型注解风格:
|
||||
field_name: Mapped[PyType] = mapped_column(Type, ...)
|
||||
|
||||
@@ -20,16 +25,34 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# MySQL兼容的字段类型辅助函数
|
||||
# 数据库兼容的字段类型辅助函数
|
||||
def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
根据数据库类型返回合适的字符串字段
|
||||
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
||||
根据数据库类型返回合适的字符串字段类型
|
||||
|
||||
对于需要索引的字段:
|
||||
- MySQL: 必须使用 VARCHAR(max_length),因为索引需要指定长度
|
||||
- PostgreSQL: 可以使用 Text,但为了兼容性使用 VARCHAR
|
||||
- SQLite: 可以使用 Text,无长度限制
|
||||
|
||||
Args:
|
||||
max_length: 最大长度(对于 MySQL 是必需的)
|
||||
**kwargs: 传递给 String/Text 的额外参数
|
||||
|
||||
Returns:
|
||||
SQLAlchemy 类型
|
||||
"""
|
||||
from src.config.config import global_config
|
||||
|
||||
if global_config.database.database_type == "mysql":
|
||||
db_type = global_config.database.database_type
|
||||
|
||||
# MySQL 索引需要指定长度的 VARCHAR
|
||||
if db_type == "mysql":
|
||||
return String(max_length, **kwargs)
|
||||
# PostgreSQL 可以使用 Text,但为了跨数据库迁移兼容性,使用 VARCHAR
|
||||
elif db_type == "postgresql":
|
||||
return String(max_length, **kwargs)
|
||||
# SQLite 使用 Text(无长度限制)
|
||||
else:
|
||||
return Text(**kwargs)
|
||||
|
||||
@@ -477,7 +500,7 @@ class BanUser(Base):
|
||||
__tablename__ = "ban_users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False) # 使用有限长度,以便创建索引
|
||||
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
"""数据库会话管理
|
||||
|
||||
单一职责:提供数据库会话工厂和上下文管理器
|
||||
|
||||
支持的数据库类型:
|
||||
- SQLite: 设置 PRAGMA 参数优化并发
|
||||
- MySQL: 无特殊会话设置
|
||||
- PostgreSQL: 可选设置 schema 搜索路径
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -53,12 +58,43 @@ async def get_session_factory() -> async_sessionmaker:
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
|
||||
"""应用数据库特定的会话设置
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
db_type: 数据库类型
|
||||
"""
|
||||
try:
|
||||
if db_type == "sqlite":
|
||||
# SQLite 特定的 PRAGMA 设置
|
||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||
await session.execute(text("PRAGMA foreign_keys = ON"))
|
||||
elif db_type == "postgresql":
|
||||
# PostgreSQL 特定设置(如果需要)
|
||||
# 可以设置 schema 搜索路径等
|
||||
from src.config.config import global_config
|
||||
|
||||
schema = global_config.database.postgresql_schema
|
||||
if schema and schema != "public":
|
||||
await session.execute(text(f"SET search_path TO {schema}"))
|
||||
# MySQL 通常不需要会话级别的特殊设置
|
||||
except Exception:
|
||||
# 复用连接时设置可能已存在,忽略错误
|
||||
pass
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话上下文管理器
|
||||
|
||||
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
|
||||
|
||||
支持的数据库:
|
||||
- SQLite: 自动设置 busy_timeout 和外键约束
|
||||
- MySQL: 直接使用,无特殊设置
|
||||
- PostgreSQL: 支持自定义 schema
|
||||
|
||||
使用示例:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(User))
|
||||
@@ -75,16 +111,10 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
# 使用连接池管理器(透明复用连接)
|
||||
async with pool_manager.get_session(session_factory) as session:
|
||||
# 为SQLite设置特定的PRAGMA
|
||||
# 获取数据库类型并应用特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
if global_config.database.database_type == "sqlite":
|
||||
try:
|
||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||
await session.execute(text("PRAGMA foreign_keys = ON"))
|
||||
except Exception:
|
||||
# 复用连接时PRAGMA可能已设置,忽略错误
|
||||
pass
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
yield session
|
||||
|
||||
@@ -103,6 +133,11 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
async with session_factory() as session:
|
||||
try:
|
||||
# 应用数据库特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
|
||||
Reference in New Issue
Block a user