fix: 更新版本号至 0.13.0,增强数据库迁移功能,注册通知事件处理
This commit is contained in:
@@ -84,11 +84,12 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
|
||||
try:
|
||||
# 检查并添加缺失的列
|
||||
db_columns = await connection.run_sync(
|
||||
db_columns_info = await connection.run_sync(
|
||||
lambda conn: {
|
||||
col["name"] for col in inspector.get_columns(table_name)
|
||||
col["name"]: col for col in inspector.get_columns(table_name)
|
||||
}
|
||||
)
|
||||
db_columns = set(db_columns_info.keys())
|
||||
model_columns = {col.name for col in table.c}
|
||||
missing_columns = model_columns - db_columns
|
||||
|
||||
@@ -144,7 +145,12 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
# 提交列添加事务
|
||||
await connection.commit()
|
||||
else:
|
||||
logger.info(f"表 '{table_name}' 的列结构一致。")
|
||||
logger.debug(f"表 '{table_name}' 的列结构一致。")
|
||||
|
||||
# 3. 检查并修复列类型不匹配(仅 PostgreSQL)
|
||||
await _check_and_fix_column_types(
|
||||
connection, inspector, table_name, table, db_columns_info
|
||||
)
|
||||
|
||||
# 检查并创建缺失的索引
|
||||
db_indexes = await connection.run_sync(
|
||||
@@ -225,3 +231,126 @@ async def drop_all_tables(existing_engine=None):
|
||||
await connection.run_sync(Base.metadata.drop_all)
|
||||
|
||||
logger.warning("所有数据库表已删除。")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 列类型修复辅助函数
|
||||
# =============================================================================
|
||||
|
||||
# 已知需要修复的列类型映射
|
||||
# 格式: {(表名, 列名): (期望的Python类型类别, PostgreSQL USING 子句)}
|
||||
# Python类型类别: "boolean", "integer", "float", "string"
|
||||
_BOOLEAN_USING_CLAUSE = (
|
||||
"boolean",
|
||||
"USING CASE WHEN {column} IS NULL THEN FALSE "
|
||||
"WHEN {column} = 0 THEN FALSE ELSE TRUE END"
|
||||
)
|
||||
|
||||
_COLUMN_TYPE_FIXES = {
|
||||
# messages 表的布尔列
|
||||
("messages", "is_public_notice"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "should_reply"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "should_act"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "is_mentioned"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "is_emoji"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "is_picid"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "is_command"): _BOOLEAN_USING_CLAUSE,
|
||||
("messages", "is_notify"): _BOOLEAN_USING_CLAUSE,
|
||||
}
|
||||
|
||||
|
||||
def _get_expected_pg_type(python_type_category: str) -> str:
|
||||
"""获取期望的 PostgreSQL 类型名称"""
|
||||
mapping = {
|
||||
"boolean": "boolean",
|
||||
"integer": "integer",
|
||||
"float": "double precision",
|
||||
"string": "text",
|
||||
}
|
||||
return mapping.get(python_type_category, "text")
|
||||
|
||||
|
||||
def _normalize_pg_type(type_name: str) -> str:
|
||||
"""标准化 PostgreSQL 类型名称用于比较"""
|
||||
type_name = type_name.lower().strip()
|
||||
# 处理常见的别名
|
||||
aliases = {
|
||||
"bool": "boolean",
|
||||
"int": "integer",
|
||||
"int4": "integer",
|
||||
"int8": "bigint",
|
||||
"float8": "double precision",
|
||||
"float4": "real",
|
||||
"numeric": "numeric",
|
||||
"decimal": "numeric",
|
||||
}
|
||||
return aliases.get(type_name, type_name)
|
||||
|
||||
|
||||
async def _check_and_fix_column_types(connection, inspector, table_name, table, db_columns_info):
|
||||
"""检查并修复列类型不匹配的问题(仅 PostgreSQL)
|
||||
|
||||
Args:
|
||||
connection: 数据库连接
|
||||
inspector: SQLAlchemy inspector
|
||||
table_name: 表名
|
||||
table: SQLAlchemy Table 对象
|
||||
db_columns_info: 数据库中列的信息字典
|
||||
"""
|
||||
# 获取数据库方言
|
||||
def get_dialect_name(conn):
|
||||
return conn.dialect.name
|
||||
|
||||
dialect_name = await connection.run_sync(get_dialect_name)
|
||||
|
||||
# 目前只处理 PostgreSQL
|
||||
if dialect_name != "postgresql":
|
||||
return
|
||||
|
||||
for (fix_table, fix_column), (expected_type_category, using_clause) in _COLUMN_TYPE_FIXES.items():
|
||||
if fix_table != table_name:
|
||||
continue
|
||||
|
||||
if fix_column not in db_columns_info:
|
||||
continue
|
||||
|
||||
col_info = db_columns_info[fix_column]
|
||||
current_type = _normalize_pg_type(str(col_info.get("type", "")))
|
||||
expected_type = _get_expected_pg_type(expected_type_category)
|
||||
|
||||
# 如果类型已经正确,跳过
|
||||
if current_type == expected_type:
|
||||
continue
|
||||
|
||||
# 检查是否需要修复:如果当前是 numeric 但期望是 boolean
|
||||
if current_type == "numeric" and expected_type == "boolean":
|
||||
logger.warning(
|
||||
f"发现列类型不匹配: {table_name}.{fix_column} "
|
||||
f"(当前: {current_type}, 期望: {expected_type})"
|
||||
)
|
||||
|
||||
# PostgreSQL 需要先删除默认值,再修改类型,最后重新设置默认值
|
||||
using_sql = using_clause.format(column=fix_column)
|
||||
drop_default_sql = f"ALTER TABLE {table_name} ALTER COLUMN {fix_column} DROP DEFAULT"
|
||||
alter_type_sql = f"ALTER TABLE {table_name} ALTER COLUMN {fix_column} TYPE BOOLEAN {using_sql}"
|
||||
set_default_sql = f"ALTER TABLE {table_name} ALTER COLUMN {fix_column} SET DEFAULT FALSE"
|
||||
|
||||
try:
|
||||
def execute_alter(conn):
|
||||
# 步骤 1: 删除默认值
|
||||
try:
|
||||
conn.execute(text(drop_default_sql))
|
||||
except Exception:
|
||||
pass # 如果没有默认值,忽略错误
|
||||
# 步骤 2: 修改类型
|
||||
conn.execute(text(alter_type_sql))
|
||||
# 步骤 3: 重新设置默认值
|
||||
conn.execute(text(set_default_sql))
|
||||
|
||||
await connection.run_sync(execute_alter)
|
||||
await connection.commit()
|
||||
logger.info(f"成功修复列类型: {table_name}.{fix_column} -> BOOLEAN")
|
||||
except Exception as e:
|
||||
logger.error(f"修复列类型失败 {table_name}.{fix_column}: {e}")
|
||||
await connection.rollback()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user