#!/usr/bin/env python3 """数据库迁移脚本 支持在不同数据库之间迁移数据: - SQLite <-> MySQL - SQLite <-> PostgreSQL - MySQL <-> PostgreSQL 使用方法: python scripts/migrate_database.py --help python scripts/migrate_database.py --source sqlite --target postgresql python scripts/migrate_database.py --source mysql --target postgresql --batch-size 5000 注意事项: 1. 迁移前请备份源数据库 2. 目标数据库应该是空的或不存在的(脚本会自动创建表) 3. 迁移过程可能需要较长时间,请耐心等待 实现细节: - 使用 SQLAlchemy 进行数据库连接和元数据管理 - 采用流式迁移,避免一次性加载过多数据 - 支持 SQLite、MySQL、PostgreSQL 之间的互相迁移 """ from __future__ import annotations import argparse import logging import os import sys from getpass import getpass # ============================================================================= # 设置日志 # ============================================================================= logging.basicConfig( level=logging.INFO, format="[%(levelname)s] %(message)s", ) logger = logging.getLogger(__name__) # ============================================================================= # 导入第三方库(延迟导入以便友好报错) # ============================================================================= try: import tomllib except ImportError: tomllib = None from typing import Any, Iterable, Callable from sqlalchemy import ( create_engine, MetaData, Table, inspect, text, types as sqltypes, ) from sqlalchemy.engine import Engine, Connection from sqlalchemy.exc import SQLAlchemyError # ====== 为了在 Windows 上更友好的输出中文,提前设置环境 ====== # 有些 Windows 终端默认编码不是 UTF-8,这里做个兼容 if os.name == "nt": try: import ctypes ctypes.windll.kernel32.SetConsoleOutputCP(65001) except Exception: pass # ============================================================================= # 配置相关工具 # ============================================================================= def get_project_root() -> str: """获取项目根目录(当前脚本的上级目录)""" return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) PROJECT_ROOT = get_project_root() def load_bot_config() -> dict: """加载 config/bot_config.toml 配置文件 返回: dict: 配置字典,如果文件不存在或解析失败,则返回空字典 """ config_path = os.path.join(PROJECT_ROOT, "config", "bot_config.toml") if not os.path.exists(config_path): logger.warning("配置文件不存在: %s", config_path) return {} if tomllib is None: logger.warning("当前 Python 版本不支持 tomllib,请使用 Python 3.11+ 或手动安装 tomli") return {} try: with open(config_path, "rb") as f: config = tomllib.load(f) return config except Exception as e: logger.error("解析配置文件失败: %s", e) return {} def get_database_config_from_toml(db_type: str) -> dict | None: """从 bot_config.toml 中读取数据库配置 Args: db_type: 数据库类型,支持 "sqlite"、"mysql"、"postgresql" Returns: dict: 数据库配置字典,如果对应配置不存在则返回 None """ config_data = load_bot_config() if not config_data: return None # 兼容旧结构和新结构 # 旧结构: 顶层直接有 db_type 相关字段 # 新结构: 在 [database] 下有 db_type 相关字段 db_config = config_data.get("database", {}) if db_type == "sqlite": sqlite_path = ( db_config.get("sqlite_path") or config_data.get("sqlite_path") or "data/MaiBot.db" ) if not os.path.isabs(sqlite_path): sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path) return {"path": sqlite_path} elif db_type == "mysql": return { "host": db_config.get("mysql_host") or config_data.get("mysql_host") or "localhost", "port": db_config.get("mysql_port") or config_data.get("mysql_port") or 3306, "database": db_config.get("mysql_database") or config_data.get("mysql_database") or "maibot", "user": db_config.get("mysql_user") or config_data.get("mysql_user") or "root", "password": db_config.get("mysql_password") or config_data.get("mysql_password") or "", "charset": db_config.get("mysql_charset") or config_data.get("mysql_charset") or "utf8mb4", } elif db_type == "postgresql": return { "host": db_config.get("postgresql_host") or config_data.get("postgresql_host") or "localhost", "port": db_config.get("postgresql_port") or config_data.get("postgresql_port") or 5432, "database": db_config.get("postgresql_database") or config_data.get("postgresql_database") or "maibot", "user": db_config.get("postgresql_user") or config_data.get("postgresql_user") or "postgres", "password": db_config.get("postgresql_password") or config_data.get("postgresql_password") or "", "schema": db_config.get("postgresql_schema") or config_data.get("postgresql_schema") or "public", } return None # ============================================================================= # 数据库连接相关 # ============================================================================= def create_sqlite_engine(sqlite_path: str) -> Engine: """���� SQLite ����""" if not os.path.isabs(sqlite_path): sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path) # 确保目录存在 os.makedirs(os.path.dirname(sqlite_path), exist_ok=True) url = f"sqlite:///{sqlite_path}" logger.info("使用 SQLite 数据库: %s", sqlite_path) engine = create_engine( url, future=True, connect_args={ "timeout": 30, # wait a bit if the db is locked "check_same_thread": False, }, ) # Increase busy timeout to reduce "database is locked" errors on SQLite with engine.connect() as conn: conn.execute(text("PRAGMA busy_timeout=30000")) return engine def create_postgresql_engine( host: str, port: int, database: str, user: str, password: str, schema: str = "public", ) -> Engine: """创建 PostgreSQL 引擎""" # 在导入 psycopg2 之前设置环境变量,解决 Windows 编码问题 # psycopg2 在 Windows 上连接时,如果客户端编码与服务器不一致可能会有问题 os.environ.setdefault("PGCLIENTENCODING", "utf-8") # 延迟导入 psycopg2,以便友好提示 try: import psycopg2 # noqa: F401 except ImportError: logger.error("需要安装 psycopg2-binary 才能连接 PostgreSQL: pip install psycopg2-binary") raise url = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}" logger.info("使用 PostgreSQL 数据库: %s@%s:%s/%s (schema=%s)", user, host, port, database, schema) engine = create_engine(url, future=True) # 为了方便,设置 search_path with engine.connect() as conn: conn.execute(text(f"SET search_path TO {schema}")) return engine def create_engine_by_type(db_type: str, config: dict) -> Engine: """根据数据库类型创建对应的 SQLAlchemy Engine Args: db_type: 数据库类型,支持 sqlite/mysql/postgresql config: 配置字典 Returns: Engine: SQLAlchemy 引擎实例 """ db_type = db_type.lower() if db_type == "sqlite": return create_sqlite_engine(config["path"]) elif db_type == "mysql": return create_mysql_engine( host=config["host"], port=config["port"], database=config["database"], user=config["user"], password=config["password"], charset=config.get("charset", "utf8mb4"), ) elif db_type == "postgresql": return create_postgresql_engine( host=config["host"], port=config["port"], database=config["database"], user=config["user"], password=config["password"], schema=config.get("schema", "public"), ) else: raise ValueError(f"不支持的数据库类型: {db_type}") # ============================================================================= # 工具函数 # ============================================================================= def chunked_iterable(iterable: Iterable, size: int) -> Iterable[list]: """将可迭代对象分块 Args: iterable: 可迭代对象 size: 每块大小 Yields: list: 分块列表 """ chunk: list[Any] = [] for item in iterable: chunk.append(item) if len(chunk) >= size: yield chunk chunk = [] if chunk: yield chunk def get_table_row_count(conn: Connection, table: Table) -> int: """获取表的行数""" try: result = conn.execute(text(f"SELECT COUNT(*) FROM {table.name}")) return int(result.scalar() or 0) except SQLAlchemyError as e: logger.warning("获取表行数失败 %s: %s", table.name, e) return 0 def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table: """复制表结构到目标数据库,使其结构保持一致""" target_is_sqlite = target_engine.dialect.name == "sqlite" target_is_pg = target_engine.dialect.name == "postgresql" columns = [] for c in source_table.columns: new_col = c.copy() # SQLite 不支持 nextval 等 server_default if target_is_sqlite: new_col.server_default = None # PostgreSQL 需要将部分 SQLite 特有类型转换 if target_is_pg: col_type = new_col.type # SQLite DATETIME -> 通用 DateTime if isinstance(col_type, sqltypes.DateTime) or col_type.__class__.__name__ in {"DATETIME", "DateTime"}: new_col.type = sqltypes.DateTime() # TEXT(50) 等长度受限的 TEXT 在 PG 无效,改用 String(length) elif isinstance(col_type, sqltypes.Text) and getattr(col_type, "length", None): new_col.type = sqltypes.String(length=col_type.length) columns.append(new_col) # 为避免迭代约束集合时出现 “Set changed size during iteration”,这里不复制表级约束 target_table = Table( source_table.name, target_metadata, *columns, ) target_metadata.create_all(target_engine, tables=[target_table]) return target_table def migrate_table_data( source_conn: Connection, target_conn: Connection, source_table: Table, target_table: Table, batch_size: int = 1000, ) -> tuple[int, int]: """迁移单个表的数据 Args: source_conn: 源数据库连接 target_conn: 目标数据库连接 source_table: 源表对象 target_table: 目标表对象 batch_size: 每批次处理大小 Returns: tuple[int, int]: (迁移行数, 错误数量) """ total_rows = get_table_row_count(source_conn, source_table) logger.info( "开始迁移表: %s (共 %s 行)", source_table.name, total_rows if total_rows else "未知", ) migrated_rows = 0 error_count = 0 # 使用流式查询,避免一次性加载太多数据 # 对于 SQLAlchemy 1.4/2.0 可以使用 yield_per try: select_stmt = source_table.select() result = source_conn.execute(select_stmt) except SQLAlchemyError as e: logger.error("查询表 %s 失败: %s", source_table.name, e) return 0, 1 def insert_batch(rows: list[dict]): nonlocal migrated_rows, error_count if not rows: return try: target_conn.execute(target_table.insert(), rows) migrated_rows += len(rows) logger.info(" 已迁移 %d/%s 行", migrated_rows, total_rows or "?") except SQLAlchemyError as e: logger.error("写入表 %s 失败: %s", target_table.name, e) error_count += len(rows) batch: list[dict] = [] null_char_replacements = 0 for row in result: # Use column objects to access row mapping to avoid quoted_name keys row_dict = {} for col in source_table.columns: val = row._mapping[col] if isinstance(val, str) and "\x00" in val: val = val.replace("\x00", "") null_char_replacements += 1 row_dict[col.key] = val batch.append(row_dict) if len(batch) >= batch_size: insert_batch(batch) batch = [] if batch: insert_batch(batch) logger.info( "完成迁移表: %s (成功: %d 行, 失败: %d 行)", source_table.name, migrated_rows, error_count, ) if null_char_replacements: logger.warning( "表 %s 中 %d 个字符串值包含 NUL 已被移除后写入目标库", source_table.name, null_char_replacements, ) return migrated_rows, error_count def confirm_action(prompt: str, default: bool = False) -> bool: """确认操作 Args: prompt: 提示信息 default: 默认值 Returns: bool: 用户是否确认 """ while True: if default: choice = input(f"{prompt} [Y/n]: ").strip().lower() if choice == "": return True else: choice = input(f"{prompt} [y/N]: ").strip().lower() if choice == "": return False if choice in ("y", "yes"): return True elif choice in ("n", "no"): return False else: print("请输入 y 或 n") # ============================================================================= # 迁移器实现 # ============================================================================= class DatabaseMigrator: """通用数据库迁移器""" def __init__( self, source_type: str, target_type: str, batch_size: int = 1000, source_config: dict | None = None, target_config: dict | None = None, ): """初始化迁移器 Args: source_type: 源数据库类型 target_type: 目标数据库类型 batch_size: 批量处理大小 source_config: 源数据库配置(可选,默认从配置文件读取) target_config: 目标数据库配置(可选,需要手动指定) """ self.source_type = source_type.lower() self.target_type = target_type.lower() self.batch_size = batch_size self.source_config = source_config self.target_config = target_config self._validate_database_types() self.source_engine: Any = None self.target_engine: Any = None self.metadata = MetaData() # 统计信息 self.stats = { "tables_migrated": 0, "rows_migrated": 0, "errors": [], "start_time": None, "end_time": None, } def _validate_database_types(self): """验证数据库类型""" supported_types = {"sqlite", "mysql", "postgresql"} if self.source_type not in supported_types: raise ValueError(f"不支持的源数据库类型: {self.source_type}") if self.target_type not in supported_types: raise ValueError(f"不支持的目标数据库类型: {self.target_type}") def _load_source_config(self) -> dict: """加载源数据库配置 如果初始化时提供了 source_config,则直接使用; 否则从 bot_config.toml 中读取。 """ if self.source_config: logger.info("使用传入的源数据库配置") return self.source_config logger.info("未提供源数据库配置,尝试从 bot_config.toml 读取") config = get_database_config_from_toml(self.source_type) if not config: raise ValueError("无法从配置文件中读取源数据库配置,请检查 config/bot_config.toml") logger.info("成功从配置文件读取源数据库配置") return config def _load_target_config(self) -> dict: """加载目标数据库配置 目标数据库配置必须通过初始化参数提供,或者通过命令行参数构建。 """ if not self.target_config: raise ValueError("未提供目标数据库配置,请通过命令行参数指定或在交互模式中输入") logger.info("使用传入的目标数据库配置") return self.target_config def _connect_databases(self): """连接源数据库和目标数据库""" # 源数据库配置 source_config = self._load_source_config() # 目标数据库配置 target_config = self._load_target_config() # 防止源/目标 SQLite 指向同一路径导致自我覆盖及锁 if ( self.source_type == "sqlite" and self.target_type == "sqlite" and os.path.abspath(source_config.get("path", "")) == os.path.abspath(target_config.get("path", "")) ): raise ValueError("源数据库与目标数据库不能是同一个 SQLite 文件,请为目标指定不同的路径") # 创建引擎 self.source_engine = create_engine_by_type(self.source_type, source_config) self.target_engine = create_engine_by_type(self.target_type, target_config) # 反射源数据库元数据 logger.info("正在反射源数据库元数据...") self.metadata.reflect(bind=self.source_engine) logger.info("发现 %d 张表: %s", len(self.metadata.tables), ", ".join(self.metadata.tables.keys())) def _get_tables_in_dependency_order(self) -> list[Table]: """获取按依赖顺序排序的表列表 为了避免外键约束问题,创建表时需要按照依赖顺序, 例如先创建被引用的表,再创建引用它们的表。 """ inspector = inspect(self.source_engine) # 构建依赖图:table -> set(dependent_tables) dependencies: dict[str, set[str]] = {} for table_name in self.metadata.tables: dependencies[table_name] = set() for table_name, table in self.metadata.tables.items(): fks = inspector.get_foreign_keys(table_name) for fk in fks: # 被引用的表 referred_table = fk["referred_table"] if referred_table in dependencies: dependencies[table_name].add(referred_table) # 拓扑排序 sorted_tables: list[Table] = [] visited: set[str] = set() temp_mark: set[str] = set() def visit(table_name: str): if table_name in visited: return if table_name in temp_mark: logger.warning("检测到循环依赖,表: %s", table_name) return temp_mark.add(table_name) for dep in dependencies[table_name]: visit(dep) temp_mark.remove(table_name) visited.add(table_name) sorted_tables.append(self.metadata.tables[table_name]) for table_name in dependencies: if table_name not in visited: visit(table_name) return sorted_tables def _drop_target_tables(self): """删除目标数据库中已有的表(如果有) 使用 Engine.begin() 进行连接以支持 autobegin 和 begin 兼容 SQLAlchemy 2.0 的写法 """ if self.target_engine is None: logger.warning("目标数据库引擎尚未初始化,无法删除表") return with self.target_engine.begin() as conn: inspector = inspect(conn) existing_tables = inspector.get_table_names() if not existing_tables: logger.info("目标数据库中没有已存在的表,无需删除") return logger.info("目标数据库中的当前表: %s", ", ".join(existing_tables)) if confirm_action("是否删除目标数据库中现有的表列表?此操作不可撤销", default=False): for table_name in existing_tables: try: logger.info("删除目标数据库表: %s", table_name) conn.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE")) except SQLAlchemyError as e: logger.error("删除 %s 失败: %s", table_name, e) self.stats["errors"].append( f"删除 {table_name} 失败: {e}" ) else: logger.info("跳过删除目标数据库中的表,继续迁移过程") def migrate(self): """执行迁移操作""" import time self.stats["start_time"] = time.time() # 连接数据库 self._connect_databases() # 获取表的依赖顺序 tables = self._get_tables_in_dependency_order() logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables)) # 删除目标库中已有表(可选) self._drop_target_tables() # 开始迁移 with self.source_engine.connect() as source_conn: for source_table in tables: try: # 在目标库中创建表结构 target_table = copy_table_structure(source_table, MetaData(), self.target_engine) # 每张表单独事务,避免退出上下文被自动回滚 with self.target_engine.begin() as target_conn: migrated_rows, error_count = migrate_table_data( source_conn, target_conn, source_table, target_table, batch_size=self.batch_size, ) self.stats["tables_migrated"] += 1 self.stats["rows_migrated"] += migrated_rows if error_count > 0: self.stats["errors"].append( f"表 {source_table.name} 迁移失败 {error_count} 行" ) except Exception as e: logger.error("迁移表 %s 时发生错误: %s", source_table.name, e) self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}") self.stats["end_time"] = time.time() def print_summary(self): """打印迁移总结""" import time duration = None if self.stats["start_time"] is not None and self.stats["end_time"] is not None: duration = self.stats["end_time"] - self.stats["start_time"] print("\n" + "=" * 60) print("迁移完成!") print(f" 迁移表数量: {self.stats['tables_migrated']}") print(f" 迁移行数量: {self.stats['rows_migrated']}") if duration is not None: print(f" 总耗时: {duration:.2f} 秒") if self.stats["errors"]: print(" ⚠️ 发生错误:") for err in self.stats["errors"]: print(f" - {err}") else: print(" 没有发生错误 🎉") print("=" * 60 + "\n") def run(self): """运行迁移并打印总结""" self.migrate() self.print_summary() return self.stats # ============================================================================= # 命令行参数解析 # ============================================================================= def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser( description="数据库迁移工具 - 在 SQLite、MySQL、PostgreSQL 之间迁移数据", formatter_class=argparse.RawDescriptionHelpFormatter, epilog="""示例: # 从 SQLite 迁移到 PostgreSQL python scripts/migrate_database.py \ --source sqlite \ --target postgresql \ --target-host localhost \ --target-port 5432 \ --target-database maibot \ --target-user postgres \ --target-password your_password # 从 SQLite 迁移到 MySQL python scripts/migrate_database.py \ --source sqlite \ --target mysql \ --target-host localhost \ --target-port 3306 \ --target-database maibot \ --target-user root \ --target-password your_password # 使用交互式向导模式(推荐) python scripts/migrate_database.py python scripts/migrate_database.py --interactive """, ) # 基本参数 parser.add_argument( "--source", type=str, choices=["sqlite", "mysql", "postgresql"], help="源数据库类型(不指定时,在交互模式中选择)", ) parser.add_argument( "--target", type=str, choices=["sqlite", "mysql", "postgresql"], help="目标数据库类型(不指定时,在交互模式中选择)", ) parser.add_argument( "--batch-size", type=int, default=1000, help="批量处理大小(默认: 1000)", ) parser.add_argument( "--interactive", action="store_true", help="启用交互式向导模式(推荐:直接运行脚本或加上此参数)", ) # 源数据库参数(可选,默认从 bot_config.toml 读取) source_group = parser.add_argument_group("源数据库配置(可选,默认从 bot_config.toml 读取)") source_group.add_argument("--source-path", type=str, help="SQLite 数据库路径") source_group.add_argument("--source-host", type=str, help="MySQL/PostgreSQL 主机") source_group.add_argument("--source-port", type=int, help="MySQL/PostgreSQL 端口") source_group.add_argument("--source-database", type=str, help="数据库名") source_group.add_argument("--source-user", type=str, help="用户名") source_group.add_argument("--source-password", type=str, help="密码") # 目标数据库参数 target_group = parser.add_argument_group("目标数据库配置") target_group.add_argument("--target-path", type=str, help="SQLite 数据库路径") target_group.add_argument("--target-host", type=str, help="MySQL/PostgreSQL 主机") target_group.add_argument("--target-port", type=int, help="MySQL/PostgreSQL 端口") target_group.add_argument("--target-database", type=str, help="数据库名") target_group.add_argument("--target-user", type=str, help="用户名") target_group.add_argument("--target-password", type=str, help="密码") target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema") target_group.add_argument("--target-charset", type=str, default="utf8mb4", help="MySQL 字符集") return parser.parse_args() def build_config_from_args(args, prefix: str, db_type: str) -> dict | None: """从命令行参数构建配置 Args: args: 命令行参数 prefix: 参数前缀 ("source" 或 "target") db_type: 数据库类型 Returns: 配置字典或 None """ if db_type == "sqlite": path = getattr(args, f"{prefix}_path", None) if path: return {"path": path} return None elif db_type in ("mysql", "postgresql"): host = getattr(args, f"{prefix}_host", None) if not host: return None config = { "host": host, "port": getattr(args, f"{prefix}_port") or (3306 if db_type == "mysql" else 5432), "database": getattr(args, f"{prefix}_database") or "maibot", "user": getattr(args, f"{prefix}_user") or ("root" if db_type == "mysql" else "postgres"), "password": getattr(args, f"{prefix}_password") or "", } if db_type == "mysql": config["charset"] = getattr(args, f"{prefix}_charset", "utf8mb4") elif db_type == "postgresql": config["schema"] = getattr(args, f"{prefix}_schema", "public") return config return None def _ask_choice(prompt: str, options: list[str], default_index: int | None = None) -> str: """在控制台中让用户从多个选项中选择一个""" while True: print() print(prompt) for i, opt in enumerate(options, start=1): default_mark = "" if default_index is not None and i - 1 == default_index: default_mark = " (默认)" print(f" {i}) {opt}{default_mark}") ans = input("请输入选项编号: ").strip() if not ans and default_index is not None: return options[default_index] if ans.isdigit(): idx = int(ans) if 1 <= idx <= len(options): return options[idx - 1] print("❌ 无效的选择,请重新输入。") def _ask_int(prompt: str, default: int | None = None) -> int: """在控制台中输入正整数""" while True: suffix = f" (默认 {default})" if default is not None else "" raw = input(f"{prompt}{suffix}: ").strip() if not raw and default is not None: return default try: value = int(raw) if value <= 0: raise ValueError() return value except ValueError: print("❌ 请输入一个大于 0 的整数。") def _ask_str( prompt: str, default: str | None = None, allow_empty: bool = False, is_password: bool = False, ) -> str: """在控制台中输入字符串,可选默认值/密码输入""" while True: suffix = f" (默认: {default})" if default is not None else "" full_prompt = f"{prompt}{suffix}: " raw = getpass(full_prompt) if is_password else input(full_prompt) raw = raw.strip() if not raw: if default is not None: return default if allow_empty: return "" print("❌ 输入不能为空,请重新输入。") continue return raw def interactive_setup() -> dict: """交互式向导,返回用于初始化 DatabaseMigrator 的参数字典""" print("=" * 60) print("🌟 数据库迁移向导") print("只需回答几个问题,我会帮你构造迁移配置。") print("=" * 60) db_types = ["sqlite", "mysql", "postgresql"] # 选择源数据库 source_type = _ask_choice("请选择【源数据库类型】:", db_types, default_index=0) # 选择目标数据库(不能与源相同) while True: default_idx = 2 if len(db_types) >= 3 else 0 target_type = _ask_choice("请选择【目标数据库类型】:", db_types, default_index=default_idx) if target_type != source_type: break print("❌ 目标数据库不能和源数据库相同,请重新选择。") # 批量大小 batch_size = _ask_int("请输入批量大小 batch-size", default=1000) # 源数据库配置:默认使用 bot_config.toml print() print("源数据库配置:") print(" 默认会从 config/bot_config.toml 中读取对应配置。") use_default_source = input("是否使用配置文件中的【源数据库】配置? [Y/n]: ").strip().lower() if use_default_source in ("", "y", "yes"): source_config = None # 让 DatabaseMigrator 自己去读配置 else: # 简单交互式配置源数据库 print("请手动输入源数据库连接信息:") if source_type == "sqlite": source_path = _ask_str("源 SQLite 文件路径", default="data/MaiBot.db") source_config = {"path": source_path} else: port_default = 3306 if source_type == "mysql" else 5432 user_default = "root" if source_type == "mysql" else "postgres" host = _ask_str("源数据库 host", default="localhost") port = _ask_int("源数据库 port", default=port_default) database = _ask_str("源数据库名", default="maibot") user = _ask_str("源数据库用户名", default=user_default) password = _ask_str("源数据库密码(输入时不回显)", default="", is_password=True) source_config = { "host": host, "port": port, "database": database, "user": user, "password": password, } if source_type == "mysql": source_config["charset"] = _ask_str("源数据库字符集", default="utf8mb4") elif source_type == "postgresql": source_config["schema"] = _ask_str("源数据库 schema", default="public") # 目标数据库配置(必须显式确认) print() print("目标数据库配置:") if target_type == "sqlite": target_path = _ask_str( "目标 SQLite 文件路径(若不存在会自动创建)", default="data/MaiBot.db", ) target_config = {"path": target_path} else: port_default = 3306 if target_type == "mysql" else 5432 user_default = "root" if target_type == "mysql" else "postgres" host = _ask_str("目标数据库 host", default="localhost") port = _ask_int("目标数据库 port", default=port_default) database = _ask_str("目标数据库名", default="maibot") user = _ask_str("目标数据库用户名", default=user_default) password = _ask_str("目标数据库密码(输入时不回显)", default="", is_password=True) target_config = { "host": host, "port": port, "database": database, "user": user, "password": password, } if target_type == "mysql": target_config["charset"] = _ask_str("目标数据库字符集", default="utf8mb4") elif target_type == "postgresql": target_config["schema"] = _ask_str("目标数据库 schema", default="public") print() print("=" * 60) print("迁移配置确认:") print(f" 源数据库类型: {source_type}") print(f" 目标数据库类型: {target_type}") print(f" 批量大小: {batch_size}") print("⚠️ 请确认目标数据库为空或可以被覆盖,并且已备份源数据库。") confirm = input("是否开始迁移?[Y/n]: ").strip().lower() if confirm not in ("", "y", "yes"): print("已取消迁移。") sys.exit(0) return { "source_type": source_type, "target_type": target_type, "batch_size": batch_size, "source_config": source_config, "target_config": target_config, } def main(): """主函数""" args = parse_args() # 如果没有任何参数,或者显式指定 --interactive,则进入交互模式 if args.interactive or len(sys.argv) == 1: params = interactive_setup() try: migrator = DatabaseMigrator(**params) stats = migrator.run() if stats["errors"]: sys.exit(1) return except KeyboardInterrupt: print("\n迁移被用户中断") sys.exit(130) except Exception as e: print(f"迁移失败: {e}") sys.exit(1) # 非交互模式:保持原有行为,但如果没给 source/target,就提示错误 if not args.source or not args.target: print("错误: 非交互模式下必须指定 --source 和 --target。") print("你也可以直接运行脚本或添加 --interactive 使用交互式向导。") sys.exit(2) # 构建配置 source_config = build_config_from_args(args, "source", args.source) target_config = build_config_from_args(args, "target", args.target) # 验证目标配置 if target_config is None: if args.target == "sqlite": if not args.target_path: print("错误: 目标数据库为 SQLite 时,必须指定 --target-path(或使用交互模式)") sys.exit(1) target_config = {"path": args.target_path} else: if not args.target_host: print(f"错误: 目标数据库为 {args.target} 时,必须指定 --target-host(或使用交互模式)") sys.exit(1) try: migrator = DatabaseMigrator( source_type=args.source, target_type=args.target, batch_size=args.batch_size, source_config=source_config, target_config=target_config, ) stats = migrator.run() # 如果有错误,返回非零退出码 if stats["errors"]: sys.exit(1) except KeyboardInterrupt: print("\n迁移被用户中断") sys.exit(130) except Exception as e: print(f"迁移失败: {e}") sys.exit(1) if __name__ == "__main__": main()