diff --git a/scripts/migrate_database.py b/scripts/migrate_database.py index 9b13cd675..b5e5c30c6 100644 --- a/scripts/migrate_database.py +++ b/scripts/migrate_database.py @@ -10,16 +10,23 @@ python scripts/migrate_database.py --help python scripts/migrate_database.py --source sqlite --target postgresql python scripts/migrate_database.py --source mysql --target postgresql --batch-size 5000 + + # 交互式向导模式(推荐) + python scripts/migrate_database.py 注意事项: 1. 迁移前请备份源数据库 2. 目标数据库应该是空的或不存在的(脚本会自动创建表) 3. 迁移过程可能需要较长时间,请耐心等待 +4. 迁移到 PostgreSQL 时,脚本会自动: + - 修复布尔列类型(SQLite INTEGER -> PostgreSQL BOOLEAN) + - 重置序列值(避免主键冲突) 实现细节: - 使用 SQLAlchemy 进行数据库连接和元数据管理 - 采用流式迁移,避免一次性加载过多数据 - 支持 SQLite、MySQL、PostgreSQL 之间的互相迁移 +- 批量插入失败时自动降级为逐行插入,最大程度保留数据 """ from __future__ import annotations @@ -52,6 +59,8 @@ except ImportError: from typing import Any, Iterable, Callable +from datetime import datetime as dt + from sqlalchemy import ( create_engine, MetaData, @@ -314,6 +323,143 @@ def get_table_row_count(conn: Connection, table: Table) -> int: return 0 +def convert_value_for_target( + val: Any, + col_name: str, + source_col_type: Any, + target_col_type: Any, + target_dialect: str, + target_col_nullable: bool = True, +) -> Any: + """转换值以适配目标数据库类型 + + 处理以下情况: + 1. 空字符串日期时间 -> None + 2. SQLite INTEGER (0/1) -> PostgreSQL BOOLEAN + 3. 字符串日期时间 -> datetime 对象 + 4. 跳过主键 id (让目标数据库自增) + 5. 对于 NOT NULL 列,提供合适的默认值 + + Args: + val: 原始值 + col_name: 列名 + source_col_type: 源列类型 + target_col_type: 目标列类型 + target_dialect: 目标数据库方言名称 + target_col_nullable: 目标列是否允许 NULL + + Returns: + 转换后的值 + """ + # 获取目标类型的类名 + target_type_name = target_col_type.__class__.__name__.upper() + source_type_name = source_col_type.__class__.__name__.upper() + + # 处理 None 值 + if val is None: + # 如果目标列不允许 NULL,提供默认值 + if not target_col_nullable: + # Boolean 类型的默认值是 False + if target_type_name == "BOOLEAN" or isinstance(target_col_type, sqltypes.Boolean): + return False + # 数值类型的默认值 + if target_type_name in ("INTEGER", "BIGINT", "SMALLINT") or isinstance(target_col_type, sqltypes.Integer): + return 0 + if target_type_name in ("FLOAT", "DOUBLE", "REAL", "NUMERIC", "DECIMAL", "DOUBLE_PRECISION") or isinstance(target_col_type, sqltypes.Float): + return 0.0 + # 日期时间类型的默认值 + if target_type_name in ("DATETIME", "TIMESTAMP") or isinstance(target_col_type, sqltypes.DateTime): + return dt.now() + # 字符串类型的默认值 + if target_type_name in ("VARCHAR", "STRING", "TEXT") or isinstance(target_col_type, (sqltypes.String, sqltypes.Text)): + return "" + # 其他类型也返回空字符串作为兜底 + return "" + return None + + # 处理 Boolean 类型转换 + # SQLite 中 Boolean 实际存储为 INTEGER (0/1) + if target_type_name == "BOOLEAN" or isinstance(target_col_type, sqltypes.Boolean): + if isinstance(val, bool): + return val + if isinstance(val, (int, float)): + return bool(val) + if isinstance(val, str): + val_lower = val.lower().strip() + if val_lower in ("true", "1", "yes"): + return True + elif val_lower in ("false", "0", "no", ""): + return False + return bool(val) if val else False + + # 处理 DateTime 类型转换 + if target_type_name in ("DATETIME", "TIMESTAMP") or isinstance(target_col_type, sqltypes.DateTime): + if isinstance(val, dt): + return val + if isinstance(val, str): + val = val.strip() + # 空字符串 -> None + if val == "": + return None + # 尝试多种日期格式 + for fmt in [ + "%Y-%m-%d %H:%M:%S.%f", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%dT%H:%M:%S.%f", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d", + ]: + try: + return dt.strptime(val, fmt) + except ValueError: + continue + # 如果都失败,尝试 fromisoformat + try: + return dt.fromisoformat(val) + except ValueError: + logger.warning("无法解析日期时间字符串 '%s' (列: %s),设为 None", val, col_name) + return None + # 如果是数值(时间戳),尝试转换 + if isinstance(val, (int, float)) and val > 0: + try: + return dt.fromtimestamp(val) + except (OSError, ValueError, OverflowError): + return None + return None + + # 处理 Float 类型 + if target_type_name == "FLOAT" or isinstance(target_col_type, sqltypes.Float): + if isinstance(val, (int, float)): + return float(val) + if isinstance(val, str): + val = val.strip() + if val == "": + return None + try: + return float(val) + except ValueError: + return None + return val + + # 处理 Integer 类型 + if target_type_name == "INTEGER" or isinstance(target_col_type, sqltypes.Integer): + if isinstance(val, int): + return val + if isinstance(val, float): + return int(val) + if isinstance(val, str): + val = val.strip() + if val == "": + return None + try: + return int(float(val)) + except ValueError: + return None + return val + + return val + + def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table: """复制表结构到目标数据库,使其结构保持一致""" target_is_sqlite = target_engine.dialect.name == "sqlite" @@ -351,19 +497,23 @@ def copy_table_structure(source_table: Table, target_metadata: MetaData, target_ def migrate_table_data( source_conn: Connection, - target_conn: Connection, + target_engine: Engine, source_table: Table, target_table: Table, batch_size: int = 1000, + target_dialect: str = "postgresql", + row_limit: int | None = None, ) -> tuple[int, int]: """迁移单个表的数据 Args: source_conn: 源数据库连接 - target_conn: 目标数据库连接 + target_engine: 目标数据库引擎(注意:改为 engine 而不是 connection) source_table: 源表对象 target_table: 目标表对象 batch_size: 每批次处理大小 + target_dialect: 目标数据库方言 (sqlite/mysql/postgresql) + row_limit: 最大迁移行数限制,None 表示不限制 Returns: tuple[int, int]: (迁移行数, 错误数量) @@ -377,40 +527,101 @@ def migrate_table_data( migrated_rows = 0 error_count = 0 + conversion_warnings = 0 + + # 构建源列到目标列的映射 + target_cols_by_name = {c.key: c for c in target_table.columns} + + # 识别主键列(通常是 id),迁移时保留原始 ID 以避免重复数据 + primary_key_cols = {c.key for c in source_table.primary_key.columns} # 使用流式查询,避免一次性加载太多数据 - # 对于 SQLAlchemy 1.4/2.0 可以使用 yield_per + # 使用 text() 原始 SQL 查询,避免 SQLAlchemy 自动类型转换(如 DateTime)导致的错误 try: - select_stmt = source_table.select() - result = source_conn.execute(select_stmt) + # 构建原始 SQL 查询语句 + col_names = [c.key for c in source_table.columns] + if row_limit: + # 按时间或 ID 倒序取最新的 row_limit 条 + raw_sql = text(f"SELECT {', '.join(col_names)} FROM {source_table.name} ORDER BY id DESC LIMIT {row_limit}") + logger.info(" 限制迁移最新 %d 行", row_limit) + else: + raw_sql = text(f"SELECT {', '.join(col_names)} FROM {source_table.name}") + result = source_conn.execute(raw_sql) except SQLAlchemyError as e: logger.error("查询表 %s 失败: %s", source_table.name, e) return 0, 1 def insert_batch(rows: list[dict]): + """每个批次使用独立的事务,批次失败时降级为逐行插入""" nonlocal migrated_rows, error_count if not rows: return try: - target_conn.execute(target_table.insert(), rows) + # 每个批次使用独立的事务 + with target_engine.begin() as target_conn: + target_conn.execute(target_table.insert(), rows) migrated_rows += len(rows) logger.info(" 已迁移 %d/%s 行", migrated_rows, total_rows or "?") except SQLAlchemyError as e: - logger.error("写入表 %s 失败: %s", target_table.name, e) - error_count += len(rows) + # 批量插入失败,降级为逐行插入 + logger.warning("批量插入失败,降级为逐行插入 (共 %d 行): %s", len(rows), str(e)[:200]) + for row in rows: + try: + with target_engine.begin() as target_conn: + target_conn.execute(target_table.insert(), [row]) + migrated_rows += 1 + except SQLAlchemyError as row_e: + # 记录失败的行信息 + row_id = row.get("id", "unknown") + logger.error("插入行失败 (id=%s): %s", row_id, str(row_e)[:200]) + error_count += 1 + logger.info(" 逐行插入完成,已迁移 %d/%s 行", migrated_rows, total_rows or "?") batch: list[dict] = [] null_char_replacements = 0 + # 构建列名列表(用于通过索引访问原始 SQL 结果) + col_list = list(source_table.columns) + col_name_to_idx = {c.key: idx for idx, c in enumerate(col_list)} + for row in result: - # Use column objects to access row mapping to avoid quoted_name keys row_dict = {} - for col in source_table.columns: - val = row._mapping[col] + for col in col_list: + col_key = col.key + + # 保留主键列(id),确保数据一致性 + # 注意:如果目标表使用自增主键,可能需要重置序列 + + # 通过索引获取原始值(避免 SQLAlchemy 自动类型转换) + col_idx = col_name_to_idx[col_key] + val = row[col_idx] + + # 处理 NUL 字符 if isinstance(val, str) and "\x00" in val: val = val.replace("\x00", "") null_char_replacements += 1 - row_dict[col.key] = val + + # 获取目标列类型进行转换 + target_col = target_cols_by_name.get(col_key) + if target_col is not None: + try: + val = convert_value_for_target( + val=val, + col_name=col_key, + source_col_type=col.type, + target_col_type=target_col.type, + target_dialect=target_dialect, + target_col_nullable=target_col.nullable if target_col.nullable is not None else True, + ) + except Exception as e: + conversion_warnings += 1 + if conversion_warnings <= 5: + logger.warning( + "值转换异常 (表=%s, 列=%s, 值=%r): %s", + source_table.name, col_key, val, e + ) + + row_dict[col_key] = val batch.append(row_dict) if len(batch) >= batch_size: @@ -432,6 +643,12 @@ def migrate_table_data( source_table.name, null_char_replacements, ) + if conversion_warnings: + logger.warning( + "表 %s 中 %d 个值发生类型转换警告", + source_table.name, + conversion_warnings, + ) return migrated_rows, error_count @@ -479,6 +696,9 @@ class DatabaseMigrator: batch_size: int = 1000, source_config: dict | None = None, target_config: dict | None = None, + skip_tables: set | None = None, + only_tables: set | None = None, + no_create_tables: bool = False, ): """初始化迁移器 @@ -488,12 +708,18 @@ class DatabaseMigrator: batch_size: 批量处理大小 source_config: 源数据库配置(可选,默认从配置文件读取) target_config: 目标数据库配置(可选,需要手动指定) + skip_tables: 要跳过的表名集合 + only_tables: 只迁移的表名集合(设置后忽略 skip_tables) + no_create_tables: 是否跳过创建表结构(假设目标表已存在) """ self.source_type = source_type.lower() self.target_type = target_type.lower() self.batch_size = batch_size self.source_config = source_config self.target_config = target_config + self.skip_tables = skip_tables or set() + self.only_tables = only_tables or set() + self.no_create_tables = no_create_tables self._validate_database_types() @@ -659,25 +885,60 @@ class DatabaseMigrator: tables = self._get_tables_in_dependency_order() logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables)) - # 删除目标库中已有表(可选) - self._drop_target_tables() + # 如果指定了 only_tables,则过滤表列表 + if self.only_tables: + tables = [t for t in tables if t.name in self.only_tables] + logger.info("只迁移指定的表: %s", ", ".join(t.name for t in tables)) + if not tables: + logger.warning("没有找到任何匹配 --only-tables 的表") + return + + # 删除目标库中已有表(可选)- 如果是增量迁移则跳过 + if not self.no_create_tables: + self._drop_target_tables() + + # 获取目标数据库方言 + target_dialect = self.target_engine.dialect.name # 开始迁移 with self.source_engine.connect() as source_conn: for source_table in tables: - try: - # 在目标库中创建表结构 - target_table = copy_table_structure(source_table, MetaData(), self.target_engine) + # 跳过指定的表(仅在未指定 only_tables 时生效) + if not self.only_tables and source_table.name in self.skip_tables: + logger.info("跳过表: %s (在 skip_tables 列表中)", source_table.name) + continue - # 每张表单独事务,避免退出上下文被自动回滚 - with self.target_engine.begin() as target_conn: - migrated_rows, error_count = migrate_table_data( - source_conn, - target_conn, - source_table, - target_table, - batch_size=self.batch_size, - ) + try: + # 在目标库中创建表结构(除非指定了 no_create_tables) + if self.no_create_tables: + # 反射目标数据库中已存在的表结构 + target_metadata = MetaData() + target_metadata.reflect(bind=self.target_engine, only=[source_table.name]) + target_table = target_metadata.tables.get(source_table.name) + if target_table is None: + logger.error("目标数据库中不存在表: %s,请先创建表结构或移除 --no-create-tables 参数", source_table.name) + self.stats["errors"].append(f"目标数据库中不存在表: {source_table.name}") + continue + logger.info("使用目标数据库中已存在的表结构: %s", source_table.name) + else: + target_table = copy_table_structure(source_table, MetaData(), self.target_engine) + + # 对 messages 表限制迁移行数(只迁移最新 1 万条) + row_limit = None + if source_table.name == "messages": + row_limit = 10000 + logger.info("messages 表将只迁移最新 %d 条记录", row_limit) + + # 每个批次使用独立事务,传入 engine 而不是 connection + migrated_rows, error_count = migrate_table_data( + source_conn, + self.target_engine, + source_table, + target_table, + batch_size=self.batch_size, + target_dialect=target_dialect, + row_limit=row_limit, + ) self.stats["tables_migrated"] += 1 self.stats["rows_migrated"] += migrated_rows @@ -691,6 +952,11 @@ class DatabaseMigrator: self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}") self.stats["end_time"] = time.time() + + # 迁移完成后,自动修复 PostgreSQL 特有问题 + if self.target_type == "postgresql" and self.target_engine: + fix_postgresql_boolean_columns(self.target_engine) + fix_postgresql_sequences(self.target_engine) def print_summary(self): """打印迁移总结""" @@ -804,6 +1070,29 @@ def parse_args(): target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema") target_group.add_argument("--target-charset", type=str, default="utf8mb4", help="MySQL 字符集") + # 跳过表参数 + parser.add_argument( + "--skip-tables", + type=str, + default="", + help="跳过迁移的表名,多个表名用逗号分隔(如: messages,logs)", + ) + + # 只迁移指定表参数 + parser.add_argument( + "--only-tables", + type=str, + default="", + help="只迁移指定的表名,多个表名用逗号分隔(如: user_relationships,maizone_schedule_status)。设置后将忽略 --skip-tables", + ) + + # 不创建表结构,假设目标表已存在 + parser.add_argument( + "--no-create-tables", + action="store_true", + help="不创建表结构,假设目标数据库中的表已存在。用于增量迁移指定表的数据", + ) + return parser.parse_args() @@ -1012,6 +1301,112 @@ def interactive_setup() -> dict: } +def fix_postgresql_sequences(engine: Engine): + """修复 PostgreSQL 序列值 + + 迁移数据后,PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值, + 导致插入新记录时出现主键冲突。此函数会自动检测并重置所有序列。 + + Args: + engine: PostgreSQL 数据库引擎 + """ + if engine.dialect.name != "postgresql": + logger.info("非 PostgreSQL 数据库,跳过序列修复") + return + + logger.info("正在修复 PostgreSQL 序列...") + + with engine.connect() as conn: + # 获取所有带有序列的表 + result = conn.execute(text(''' + SELECT + t.table_name, + c.column_name, + pg_get_serial_sequence(t.table_name, c.column_name) as sequence_name + FROM information_schema.tables t + JOIN information_schema.columns c + ON t.table_name = c.table_name AND t.table_schema = c.table_schema + WHERE t.table_schema = 'public' + AND t.table_type = 'BASE TABLE' + AND c.column_default LIKE 'nextval%' + ORDER BY t.table_name + ''')) + + sequences = result.fetchall() + logger.info("发现 %d 个带序列的表", len(sequences)) + + fixed_count = 0 + for table_name, column_name, seq_name in sequences: + if seq_name: + try: + # 获取当前表中该列的最大值 + max_result = conn.execute(text(f'SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}')) + max_val = max_result.scalar() + + # 设置序列的下一个值 + next_val = max_val + 1 + conn.execute(text(f"SELECT setval('{seq_name}', {next_val}, false)")) + conn.commit() + + logger.info(" ✅ %s.%s: 最大值=%d, 序列设为=%d", table_name, column_name, max_val, next_val) + fixed_count += 1 + except Exception as e: + logger.warning(" ❌ %s.%s: 修复失败 - %s", table_name, column_name, e) + + logger.info("序列修复完成!共修复 %d 个序列", fixed_count) + + +def fix_postgresql_boolean_columns(engine: Engine): + """修复 PostgreSQL 布尔列类型 + + 从 SQLite 迁移后,布尔列可能是 INTEGER 类型。此函数将其转换为 BOOLEAN。 + + Args: + engine: PostgreSQL 数据库引擎 + """ + if engine.dialect.name != "postgresql": + logger.info("非 PostgreSQL 数据库,跳过布尔列修复") + return + + # 已知需要转换为 BOOLEAN 的列 + BOOLEAN_COLUMNS = { + 'messages': ['is_mentioned', 'is_emoji', 'is_picid', 'is_command', + 'is_notify', 'is_public_notice', 'should_reply', 'should_act'], + 'action_records': ['action_done', 'action_build_into_prompt'], + } + + logger.info("正在检查并修复 PostgreSQL 布尔列...") + + with engine.connect() as conn: + fixed_count = 0 + for table_name, columns in BOOLEAN_COLUMNS.items(): + for col_name in columns: + try: + # 检查当前类型 + result = conn.execute(text(f''' + SELECT data_type FROM information_schema.columns + WHERE table_name = '{table_name}' AND column_name = '{col_name}' + ''')) + row = result.fetchone() + if row and row[0] != 'boolean': + # 需要修复 + conn.execute(text(f''' + ALTER TABLE {table_name} + ALTER COLUMN {col_name} TYPE BOOLEAN + USING CASE WHEN {col_name} = 0 THEN FALSE ELSE TRUE END + ''')) + conn.commit() + logger.info(" ✅ %s.%s: %s -> BOOLEAN", table_name, col_name, row[0]) + fixed_count += 1 + except Exception as e: + logger.warning(" ⚠️ %s.%s: 检查/修复失败 - %s", table_name, col_name, e) + + if fixed_count > 0: + logger.info("布尔列修复完成!共修复 %d 列", fixed_count) + else: + logger.info("所有布尔列类型正确,无需修复") + + def main(): """主函数""" args = parse_args() @@ -1055,12 +1450,27 @@ def main(): sys.exit(1) try: + # 解析跳过的表 + skip_tables = set() + if args.skip_tables: + skip_tables = {t.strip() for t in args.skip_tables.split(",") if t.strip()} + logger.info("将跳过以下表: %s", ", ".join(skip_tables)) + + # 解析只迁移的表 + only_tables = set() + if args.only_tables: + only_tables = {t.strip() for t in args.only_tables.split(",") if t.strip()} + logger.info("将只迁移以下表: %s", ", ".join(only_tables)) + migrator = DatabaseMigrator( source_type=args.source, target_type=args.target, batch_size=args.batch_size, source_config=source_config, target_config=target_config, + skip_tables=skip_tables, + only_tables=only_tables, + no_create_tables=args.no_create_tables, ) stats = migrator.run() diff --git a/scripts/reset_pg_sequences.py b/scripts/reset_pg_sequences.py deleted file mode 100644 index cd36091f0..000000000 --- a/scripts/reset_pg_sequences.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -"""重置 PostgreSQL 序列值 - -迁移数据后,PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值, -导致插入新记录时出现主键冲突。此脚本会自动检测并重置所有序列。 - -使用方法: - python scripts/reset_pg_sequences.py --host localhost --port 5432 --database maibot --user postgres --password your_password -""" - -import argparse -import psycopg - - -def reset_sequences(host: str, port: int, database: str, user: str, password: str): - """重置所有序列值""" - conn_str = f"host={host} port={port} dbname={database} user={user} password={password}" - - print(f"连接到 PostgreSQL: {host}:{port}/{database}") - conn = psycopg.connect(conn_str) - conn.autocommit = True - - # 查询所有序列及其关联的表和列 - query = """ - SELECT - t.relname AS table_name, - a.attname AS column_name, - s.relname AS sequence_name - FROM pg_class s - JOIN pg_depend d ON d.objid = s.oid - JOIN pg_class t ON d.refobjid = t.oid - JOIN pg_attribute a ON (d.refobjid, d.refobjsubid) = (a.attrelid, a.attnum) - WHERE s.relkind = 'S' - """ - - cursor = conn.execute(query) - sequences = cursor.fetchall() - - print(f"发现 {len(sequences)} 个序列") - - reset_count = 0 - for table_name, col_name, seq_name in sequences: - try: - # 获取当前最大 ID - max_result = conn.execute(f'SELECT MAX("{col_name}") FROM "{table_name}"') - max_id = max_result.fetchone()[0] - - if max_id is not None: - # 重置序列 - conn.execute(f"SELECT setval('{seq_name}', {max_id}, true)") - print(f" ✓ {seq_name} -> {max_id}") - reset_count += 1 - else: - print(f" - {seq_name}: 表为空,跳过") - - except Exception as e: - print(f" ✗ {table_name}.{col_name}: {e}") - - conn.close() - print(f"\n✅ 重置完成!共重置 {reset_count} 个序列") - - -def main(): - parser = argparse.ArgumentParser(description="重置 PostgreSQL 序列值") - parser.add_argument("--host", default="localhost", help="PostgreSQL 主机") - parser.add_argument("--port", type=int, default=5432, help="PostgreSQL 端口") - parser.add_argument("--database", default="maibot", help="数据库名") - parser.add_argument("--user", default="postgres", help="用户名") - parser.add_argument("--password", required=True, help="密码") - - args = parser.parse_args() - - reset_sequences(args.host, args.port, args.database, args.user, args.password) - - -if __name__ == "__main__": - main()