ruff
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
python scripts/migrate_database.py --help
|
||||
python scripts/migrate_database.py --source sqlite --target postgresql
|
||||
python scripts/migrate_database.py --source postgresql --target sqlite --batch-size 5000
|
||||
|
||||
|
||||
# 交互式向导模式(推荐)
|
||||
python scripts/migrate_database.py
|
||||
|
||||
@@ -55,19 +55,21 @@ try:
|
||||
except ImportError:
|
||||
tomllib = None
|
||||
|
||||
from typing import Any, Iterable, Callable
|
||||
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime as dt
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
create_engine,
|
||||
MetaData,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy import (
|
||||
types as sqltypes,
|
||||
)
|
||||
from sqlalchemy.engine import Engine, Connection
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
# ====== 为了在 Windows 上更友好的输出中文,提前设置环境 ======
|
||||
@@ -320,7 +322,7 @@ def convert_value_for_target(
|
||||
"""
|
||||
# 获取目标类型的类名
|
||||
target_type_name = target_col_type.__class__.__name__.upper()
|
||||
source_type_name = source_col_type.__class__.__name__.upper()
|
||||
source_col_type.__class__.__name__.upper()
|
||||
|
||||
# 处理 None 值
|
||||
if val is None:
|
||||
@@ -500,7 +502,7 @@ def migrate_table_data(
|
||||
target_cols_by_name = {c.key: c for c in target_table.columns}
|
||||
|
||||
# 识别主键列(通常是 id),迁移时保留原始 ID 以避免重复数据
|
||||
primary_key_cols = {c.key for c in source_table.primary_key.columns}
|
||||
{c.key for c in source_table.primary_key.columns}
|
||||
|
||||
# 使用流式查询,避免一次性加载太多数据
|
||||
# 使用 text() 原始 SQL 查询,避免 SQLAlchemy 自动类型转换(如 DateTime)导致的错误
|
||||
@@ -776,7 +778,7 @@ class DatabaseMigrator:
|
||||
for table_name in self.metadata.tables:
|
||||
dependencies[table_name] = set()
|
||||
|
||||
for table_name, table in self.metadata.tables.items():
|
||||
for table_name in self.metadata.tables.keys():
|
||||
fks = inspector.get_foreign_keys(table_name)
|
||||
for fk in fks:
|
||||
# 被引用的表
|
||||
@@ -919,7 +921,7 @@ class DatabaseMigrator:
|
||||
self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}")
|
||||
|
||||
self.stats["end_time"] = time.time()
|
||||
|
||||
|
||||
# 迁移完成后,自动修复 PostgreSQL 特有问题
|
||||
if self.target_type == "postgresql" and self.target_engine:
|
||||
fix_postgresql_boolean_columns(self.target_engine)
|
||||
@@ -927,7 +929,6 @@ class DatabaseMigrator:
|
||||
|
||||
def print_summary(self):
|
||||
"""打印迁移总结"""
|
||||
import time
|
||||
|
||||
duration = None
|
||||
if self.stats["start_time"] is not None and self.stats["end_time"] is not None:
|
||||
@@ -1262,104 +1263,104 @@ def interactive_setup() -> dict:
|
||||
|
||||
def fix_postgresql_sequences(engine: Engine):
|
||||
"""修复 PostgreSQL 序列值
|
||||
|
||||
|
||||
迁移数据后,PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值,
|
||||
导致插入新记录时出现主键冲突。此函数会自动检测并重置所有序列。
|
||||
|
||||
|
||||
Args:
|
||||
engine: PostgreSQL 数据库引擎
|
||||
"""
|
||||
if engine.dialect.name != "postgresql":
|
||||
logger.info("非 PostgreSQL 数据库,跳过序列修复")
|
||||
return
|
||||
|
||||
|
||||
logger.info("正在修复 PostgreSQL 序列...")
|
||||
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 获取所有带有序列的表
|
||||
result = conn.execute(text('''
|
||||
SELECT
|
||||
result = conn.execute(text("""
|
||||
SELECT
|
||||
t.table_name,
|
||||
c.column_name,
|
||||
pg_get_serial_sequence(t.table_name, c.column_name) as sequence_name
|
||||
FROM information_schema.tables t
|
||||
JOIN information_schema.columns c
|
||||
JOIN information_schema.columns c
|
||||
ON t.table_name = c.table_name AND t.table_schema = c.table_schema
|
||||
WHERE t.table_schema = 'public'
|
||||
WHERE t.table_schema = 'public'
|
||||
AND t.table_type = 'BASE TABLE'
|
||||
AND c.column_default LIKE 'nextval%'
|
||||
ORDER BY t.table_name
|
||||
'''))
|
||||
|
||||
"""))
|
||||
|
||||
sequences = result.fetchall()
|
||||
logger.info("发现 %d 个带序列的表", len(sequences))
|
||||
|
||||
|
||||
fixed_count = 0
|
||||
for table_name, column_name, seq_name in sequences:
|
||||
if seq_name:
|
||||
try:
|
||||
# 获取当前表中该列的最大值
|
||||
max_result = conn.execute(text(f'SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}'))
|
||||
max_result = conn.execute(text(f"SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}"))
|
||||
max_val = max_result.scalar()
|
||||
|
||||
|
||||
# 设置序列的下一个值
|
||||
next_val = max_val + 1
|
||||
conn.execute(text(f"SELECT setval('{seq_name}', {next_val}, false)"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
logger.info(" ✅ %s.%s: 最大值=%d, 序列设为=%d", table_name, column_name, max_val, next_val)
|
||||
fixed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(" ❌ %s.%s: 修复失败 - %s", table_name, column_name, e)
|
||||
|
||||
|
||||
logger.info("序列修复完成!共修复 %d 个序列", fixed_count)
|
||||
|
||||
|
||||
def fix_postgresql_boolean_columns(engine: Engine):
|
||||
"""修复 PostgreSQL 布尔列类型
|
||||
|
||||
|
||||
从 SQLite 迁移后,布尔列可能是 INTEGER 类型。此函数将其转换为 BOOLEAN。
|
||||
|
||||
|
||||
Args:
|
||||
engine: PostgreSQL 数据库引擎
|
||||
"""
|
||||
if engine.dialect.name != "postgresql":
|
||||
logger.info("非 PostgreSQL 数据库,跳过布尔列修复")
|
||||
return
|
||||
|
||||
|
||||
# 已知需要转换为 BOOLEAN 的列
|
||||
BOOLEAN_COLUMNS = {
|
||||
'messages': ['is_mentioned', 'is_emoji', 'is_picid', 'is_command',
|
||||
'is_notify', 'is_public_notice', 'should_reply', 'should_act'],
|
||||
'action_records': ['action_done', 'action_build_into_prompt'],
|
||||
"messages": ["is_mentioned", "is_emoji", "is_picid", "is_command",
|
||||
"is_notify", "is_public_notice", "should_reply", "should_act"],
|
||||
"action_records": ["action_done", "action_build_into_prompt"],
|
||||
}
|
||||
|
||||
|
||||
logger.info("正在检查并修复 PostgreSQL 布尔列...")
|
||||
|
||||
|
||||
with engine.connect() as conn:
|
||||
fixed_count = 0
|
||||
for table_name, columns in BOOLEAN_COLUMNS.items():
|
||||
for col_name in columns:
|
||||
try:
|
||||
# 检查当前类型
|
||||
result = conn.execute(text(f'''
|
||||
SELECT data_type FROM information_schema.columns
|
||||
result = conn.execute(text(f"""
|
||||
SELECT data_type FROM information_schema.columns
|
||||
WHERE table_name = '{table_name}' AND column_name = '{col_name}'
|
||||
'''))
|
||||
"""))
|
||||
row = result.fetchone()
|
||||
if row and row[0] != 'boolean':
|
||||
if row and row[0] != "boolean":
|
||||
# 需要修复
|
||||
conn.execute(text(f'''
|
||||
ALTER TABLE {table_name}
|
||||
ALTER COLUMN {col_name} TYPE BOOLEAN
|
||||
conn.execute(text(f"""
|
||||
ALTER TABLE {table_name}
|
||||
ALTER COLUMN {col_name} TYPE BOOLEAN
|
||||
USING CASE WHEN {col_name} = 0 THEN FALSE ELSE TRUE END
|
||||
'''))
|
||||
"""))
|
||||
conn.commit()
|
||||
logger.info(" ✅ %s.%s: %s -> BOOLEAN", table_name, col_name, row[0])
|
||||
fixed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(" ⚠️ %s.%s: 检查/修复失败 - %s", table_name, col_name, e)
|
||||
|
||||
|
||||
if fixed_count > 0:
|
||||
logger.info("布尔列修复完成!共修复 %d 列", fixed_count)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user