Files
Mofox-Core/scripts/migrate_database.py

1451 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""数据库迁移脚本
支持在不同数据库之间迁移数据:
- SQLite <-> PostgreSQL
使用方法:
python scripts/migrate_database.py --help
python scripts/migrate_database.py --source sqlite --target postgresql
python scripts/migrate_database.py --source postgresql --target sqlite --batch-size 5000
# 交互式向导模式(推荐)
python scripts/migrate_database.py
注意事项:
1. 迁移前请备份源数据库
2. 目标数据库应该是空的或不存在的(脚本会自动创建表)
3. 迁移过程可能需要较长时间,请耐心等待
4. 迁移到 PostgreSQL 时,脚本会自动:
- 修复布尔列类型SQLite INTEGER -> PostgreSQL BOOLEAN
- 重置序列值(避免主键冲突)
实现细节:
- 使用 SQLAlchemy 进行数据库连接和元数据管理
- 采用流式迁移,避免一次性加载过多数据
- 支持 SQLite、PostgreSQL 之间的互相迁移
- 批量插入失败时自动降级为逐行插入,最大程度保留数据
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
from getpass import getpass
# =============================================================================
# 设置日志
# =============================================================================
logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)
# =============================================================================
# 导入第三方库(延迟导入以便友好报错)
# =============================================================================
try:
import tomllib
except ImportError:
tomllib = None
from typing import Any, Iterable, Callable
from datetime import datetime as dt
from sqlalchemy import (
create_engine,
MetaData,
Table,
inspect,
text,
types as sqltypes,
)
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.exc import SQLAlchemyError
# ====== 为了在 Windows 上更友好的输出中文,提前设置环境 ======
# 有些 Windows 终端默认编码不是 UTF-8这里做个兼容
if os.name == "nt":
try:
import ctypes
ctypes.windll.kernel32.SetConsoleOutputCP(65001)
except Exception:
pass
# =============================================================================
# 配置相关工具
# =============================================================================
def get_project_root() -> str:
"""获取项目根目录(当前脚本的上级目录)"""
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_ROOT = get_project_root()
def load_bot_config() -> dict:
"""加载 config/bot_config.toml 配置文件
返回:
dict: 配置字典,如果文件不存在或解析失败,则返回空字典
"""
config_path = os.path.join(PROJECT_ROOT, "config", "bot_config.toml")
if not os.path.exists(config_path):
logger.warning("配置文件不存在: %s", config_path)
return {}
if tomllib is None:
logger.warning("当前 Python 版本不支持 tomllib请使用 Python 3.11+ 或手动安装 tomli")
return {}
try:
with open(config_path, "rb") as f:
config = tomllib.load(f)
return config
except Exception as e:
logger.error("解析配置文件失败: %s", e)
return {}
def get_database_config_from_toml(db_type: str) -> dict | None:
"""从 bot_config.toml 中读取数据库配置
Args:
db_type: 数据库类型,支持 "sqlite""postgresql"
Returns:
dict: 数据库配置字典,如果对应配置不存在则返回 None
"""
config_data = load_bot_config()
if not config_data:
return None
# 兼容旧结构和新结构
# 旧结构: 顶层直接有 db_type 相关字段
# 新结构: 在 [database] 下有 db_type 相关字段
db_config = config_data.get("database", {})
if db_type == "sqlite":
sqlite_path = (
db_config.get("sqlite_path")
or config_data.get("sqlite_path")
or "data/MaiBot.db"
)
if not os.path.isabs(sqlite_path):
sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path)
return {"path": sqlite_path}
elif db_type == "postgresql":
return {
"host": db_config.get("postgresql_host")
or config_data.get("postgresql_host")
or "localhost",
"port": db_config.get("postgresql_port")
or config_data.get("postgresql_port")
or 5432,
"database": db_config.get("postgresql_database")
or config_data.get("postgresql_database")
or "maibot",
"user": db_config.get("postgresql_user")
or config_data.get("postgresql_user")
or "postgres",
"password": db_config.get("postgresql_password")
or config_data.get("postgresql_password")
or "",
"schema": db_config.get("postgresql_schema")
or config_data.get("postgresql_schema")
or "public",
}
return None
# =============================================================================
# 数据库连接相关
# =============================================================================
def create_sqlite_engine(sqlite_path: str) -> Engine:
"""<EFBFBD><EFBFBD><EFBFBD><EFBFBD> SQLite <20><><EFBFBD><EFBFBD>"""
if not os.path.isabs(sqlite_path):
sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path)
# 确保目录存在
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
url = f"sqlite:///{sqlite_path}"
logger.info("使用 SQLite 数据库: %s", sqlite_path)
engine = create_engine(
url,
future=True,
connect_args={
"timeout": 30, # wait a bit if the db is locked
"check_same_thread": False,
},
)
# Increase busy timeout to reduce "database is locked" errors on SQLite
with engine.connect() as conn:
conn.execute(text("PRAGMA busy_timeout=30000"))
return engine
def create_postgresql_engine(
host: str,
port: int,
database: str,
user: str,
password: str,
schema: str = "public",
) -> Engine:
"""创建 PostgreSQL 引擎"""
# 在导入 psycopg2 之前设置环境变量,解决 Windows 编码问题
# psycopg2 在 Windows 上连接时,如果客户端编码与服务器不一致可能会有问题
os.environ.setdefault("PGCLIENTENCODING", "utf-8")
# 延迟导入 psycopg2以便友好提示
try:
import psycopg2 # noqa: F401
except ImportError:
logger.error("需要安装 psycopg2-binary 才能连接 PostgreSQL: pip install psycopg2-binary")
raise
url = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
logger.info("使用 PostgreSQL 数据库: %s@%s:%s/%s (schema=%s)", user, host, port, database, schema)
engine = create_engine(url, future=True)
# 为了方便,设置 search_path
with engine.connect() as conn:
conn.execute(text(f"SET search_path TO {schema}"))
return engine
def create_engine_by_type(db_type: str, config: dict) -> Engine:
"""根据数据库类型创建对应的 SQLAlchemy Engine
Args:
db_type: 数据库类型,支持 sqlite/postgresql
config: 配置字典
Returns:
Engine: SQLAlchemy 引擎实例
"""
db_type = db_type.lower()
if db_type == "sqlite":
return create_sqlite_engine(config["path"])
elif db_type == "postgresql":
return create_postgresql_engine(
host=config["host"],
port=config["port"],
database=config["database"],
user=config["user"],
password=config["password"],
schema=config.get("schema", "public"),
)
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
# =============================================================================
# 工具函数
# =============================================================================
def chunked_iterable(iterable: Iterable, size: int) -> Iterable[list]:
"""将可迭代对象分块
Args:
iterable: 可迭代对象
size: 每块大小
Yields:
list: 分块列表
"""
chunk: list[Any] = []
for item in iterable:
chunk.append(item)
if len(chunk) >= size:
yield chunk
chunk = []
if chunk:
yield chunk
def get_table_row_count(conn: Connection, table: Table) -> int:
"""获取表的行数"""
try:
result = conn.execute(text(f"SELECT COUNT(*) FROM {table.name}"))
return int(result.scalar() or 0)
except SQLAlchemyError as e:
logger.warning("获取表行数失败 %s: %s", table.name, e)
return 0
def convert_value_for_target(
val: Any,
col_name: str,
source_col_type: Any,
target_col_type: Any,
target_dialect: str,
target_col_nullable: bool = True,
) -> Any:
"""转换值以适配目标数据库类型
处理以下情况:
1. 空字符串日期时间 -> None
2. SQLite INTEGER (0/1) -> PostgreSQL BOOLEAN
3. 字符串日期时间 -> datetime 对象
4. 跳过主键 id (让目标数据库自增)
5. 对于 NOT NULL 列,提供合适的默认值
Args:
val: 原始值
col_name: 列名
source_col_type: 源列类型
target_col_type: 目标列类型
target_dialect: 目标数据库方言名称
target_col_nullable: 目标列是否允许 NULL
Returns:
转换后的值
"""
# 获取目标类型的类名
target_type_name = target_col_type.__class__.__name__.upper()
source_type_name = source_col_type.__class__.__name__.upper()
# 处理 None 值
if val is None:
# 如果目标列不允许 NULL提供默认值
if not target_col_nullable:
# Boolean 类型的默认值是 False
if target_type_name == "BOOLEAN" or isinstance(target_col_type, sqltypes.Boolean):
return False
# 数值类型的默认值
if target_type_name in ("INTEGER", "BIGINT", "SMALLINT") or isinstance(target_col_type, sqltypes.Integer):
return 0
if target_type_name in ("FLOAT", "DOUBLE", "REAL", "NUMERIC", "DECIMAL", "DOUBLE_PRECISION") or isinstance(target_col_type, sqltypes.Float):
return 0.0
# 日期时间类型的默认值
if target_type_name in ("DATETIME", "TIMESTAMP") or isinstance(target_col_type, sqltypes.DateTime):
return dt.now()
# 字符串类型的默认值
if target_type_name in ("VARCHAR", "STRING", "TEXT") or isinstance(target_col_type, (sqltypes.String, sqltypes.Text)):
return ""
# 其他类型也返回空字符串作为兜底
return ""
return None
# 处理 Boolean 类型转换
# SQLite 中 Boolean 实际存储为 INTEGER (0/1)
if target_type_name == "BOOLEAN" or isinstance(target_col_type, sqltypes.Boolean):
if isinstance(val, bool):
return val
if isinstance(val, (int, float)):
return bool(val)
if isinstance(val, str):
val_lower = val.lower().strip()
if val_lower in ("true", "1", "yes"):
return True
elif val_lower in ("false", "0", "no", ""):
return False
return bool(val) if val else False
# 处理 DateTime 类型转换
if target_type_name in ("DATETIME", "TIMESTAMP") or isinstance(target_col_type, sqltypes.DateTime):
if isinstance(val, dt):
return val
if isinstance(val, str):
val = val.strip()
# 空字符串 -> None
if val == "":
return None
# 尝试多种日期格式
for fmt in [
"%Y-%m-%d %H:%M:%S.%f",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d",
]:
try:
return dt.strptime(val, fmt)
except ValueError:
continue
# 如果都失败,尝试 fromisoformat
try:
return dt.fromisoformat(val)
except ValueError:
logger.warning("无法解析日期时间字符串 '%s' (列: %s),设为 None", val, col_name)
return None
# 如果是数值(时间戳),尝试转换
if isinstance(val, (int, float)) and val > 0:
try:
return dt.fromtimestamp(val)
except (OSError, ValueError, OverflowError):
return None
return None
# 处理 Float 类型
if target_type_name == "FLOAT" or isinstance(target_col_type, sqltypes.Float):
if isinstance(val, (int, float)):
return float(val)
if isinstance(val, str):
val = val.strip()
if val == "":
return None
try:
return float(val)
except ValueError:
return None
return val
# 处理 Integer 类型
if target_type_name == "INTEGER" or isinstance(target_col_type, sqltypes.Integer):
if isinstance(val, int):
return val
if isinstance(val, float):
return int(val)
if isinstance(val, str):
val = val.strip()
if val == "":
return None
try:
return int(float(val))
except ValueError:
return None
return val
return val
def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table:
"""复制表结构到目标数据库,使其结构保持一致"""
target_is_sqlite = target_engine.dialect.name == "sqlite"
target_is_pg = target_engine.dialect.name == "postgresql"
columns = []
for c in source_table.columns:
new_col = c.copy()
# SQLite 不支持 nextval 等 server_default
if target_is_sqlite:
new_col.server_default = None
# PostgreSQL 需要将部分 SQLite 特有类型转换
if target_is_pg:
col_type = new_col.type
# SQLite DATETIME -> 通用 DateTime
if isinstance(col_type, sqltypes.DateTime) or col_type.__class__.__name__ in {"DATETIME", "DateTime"}:
new_col.type = sqltypes.DateTime()
# TEXT(50) 等长度受限的 TEXT 在 PG 无效,改用 String(length)
elif isinstance(col_type, sqltypes.Text) and getattr(col_type, "length", None):
new_col.type = sqltypes.String(length=col_type.length)
columns.append(new_col)
# 为避免迭代约束集合时出现 “Set changed size during iteration”这里不复制表级约束
target_table = Table(
source_table.name,
target_metadata,
*columns,
)
target_metadata.create_all(target_engine, tables=[target_table])
return target_table
def migrate_table_data(
source_conn: Connection,
target_engine: Engine,
source_table: Table,
target_table: Table,
batch_size: int = 1000,
target_dialect: str = "postgresql",
row_limit: int | None = None,
) -> tuple[int, int]:
"""迁移单个表的数据
Args:
source_conn: 源数据库连接
target_engine: 目标数据库引擎(注意:改为 engine 而不是 connection
source_table: 源表对象
target_table: 目标表对象
batch_size: 每批次处理大小
target_dialect: 目标数据库方言 (sqlite/postgresql)
row_limit: 最大迁移行数限制None 表示不限制
Returns:
tuple[int, int]: (迁移行数, 错误数量)
"""
total_rows = get_table_row_count(source_conn, source_table)
logger.info(
"开始迁移表: %s (共 %s 行)",
source_table.name,
total_rows if total_rows else "未知",
)
migrated_rows = 0
error_count = 0
conversion_warnings = 0
# 构建源列到目标列的映射
target_cols_by_name = {c.key: c for c in target_table.columns}
# 识别主键列(通常是 id迁移时保留原始 ID 以避免重复数据
primary_key_cols = {c.key for c in source_table.primary_key.columns}
# 使用流式查询,避免一次性加载太多数据
# 使用 text() 原始 SQL 查询,避免 SQLAlchemy 自动类型转换(如 DateTime导致的错误
try:
# 构建原始 SQL 查询语句
col_names = [c.key for c in source_table.columns]
if row_limit:
# 按时间或 ID 倒序取最新的 row_limit 条
raw_sql = text(f"SELECT {', '.join(col_names)} FROM {source_table.name} ORDER BY id DESC LIMIT {row_limit}")
logger.info(" 限制迁移最新 %d", row_limit)
else:
raw_sql = text(f"SELECT {', '.join(col_names)} FROM {source_table.name}")
result = source_conn.execute(raw_sql)
except SQLAlchemyError as e:
logger.error("查询表 %s 失败: %s", source_table.name, e)
return 0, 1
def insert_batch(rows: list[dict]):
"""每个批次使用独立的事务,批次失败时降级为逐行插入"""
nonlocal migrated_rows, error_count
if not rows:
return
try:
# 每个批次使用独立的事务
with target_engine.begin() as target_conn:
target_conn.execute(target_table.insert(), rows)
migrated_rows += len(rows)
logger.info(" 已迁移 %d/%s", migrated_rows, total_rows or "?")
except SQLAlchemyError as e:
# 批量插入失败,降级为逐行插入
logger.warning("批量插入失败,降级为逐行插入 (共 %d 行): %s", len(rows), str(e)[:200])
for row in rows:
try:
with target_engine.begin() as target_conn:
target_conn.execute(target_table.insert(), [row])
migrated_rows += 1
except SQLAlchemyError as row_e:
# 记录失败的行信息
row_id = row.get("id", "unknown")
logger.error("插入行失败 (id=%s): %s", row_id, str(row_e)[:200])
error_count += 1
logger.info(" 逐行插入完成,已迁移 %d/%s", migrated_rows, total_rows or "?")
batch: list[dict] = []
null_char_replacements = 0
# 构建列名列表(用于通过索引访问原始 SQL 结果)
col_list = list(source_table.columns)
col_name_to_idx = {c.key: idx for idx, c in enumerate(col_list)}
for row in result:
row_dict = {}
for col in col_list:
col_key = col.key
# 保留主键列id确保数据一致性
# 注意:如果目标表使用自增主键,可能需要重置序列
# 通过索引获取原始值(避免 SQLAlchemy 自动类型转换)
col_idx = col_name_to_idx[col_key]
val = row[col_idx]
# 处理 NUL 字符
if isinstance(val, str) and "\x00" in val:
val = val.replace("\x00", "")
null_char_replacements += 1
# 获取目标列类型进行转换
target_col = target_cols_by_name.get(col_key)
if target_col is not None:
try:
val = convert_value_for_target(
val=val,
col_name=col_key,
source_col_type=col.type,
target_col_type=target_col.type,
target_dialect=target_dialect,
target_col_nullable=target_col.nullable if target_col.nullable is not None else True,
)
except Exception as e:
conversion_warnings += 1
if conversion_warnings <= 5:
logger.warning(
"值转换异常 (表=%s, 列=%s, 值=%r): %s",
source_table.name, col_key, val, e
)
row_dict[col_key] = val
batch.append(row_dict)
if len(batch) >= batch_size:
insert_batch(batch)
batch = []
if batch:
insert_batch(batch)
logger.info(
"完成迁移表: %s (成功: %d 行, 失败: %d 行)",
source_table.name,
migrated_rows,
error_count,
)
if null_char_replacements:
logger.warning(
"%s%d 个字符串值包含 NUL 已被移除后写入目标库",
source_table.name,
null_char_replacements,
)
if conversion_warnings:
logger.warning(
"%s%d 个值发生类型转换警告",
source_table.name,
conversion_warnings,
)
return migrated_rows, error_count
def confirm_action(prompt: str, default: bool = False) -> bool:
"""确认操作
Args:
prompt: 提示信息
default: 默认值
Returns:
bool: 用户是否确认
"""
while True:
if default:
choice = input(f"{prompt} [Y/n]: ").strip().lower()
if choice == "":
return True
else:
choice = input(f"{prompt} [y/N]: ").strip().lower()
if choice == "":
return False
if choice in ("y", "yes"):
return True
elif choice in ("n", "no"):
return False
else:
print("请输入 y 或 n")
# =============================================================================
# 迁移器实现
# =============================================================================
class DatabaseMigrator:
"""通用数据库迁移器"""
def __init__(
self,
source_type: str,
target_type: str,
batch_size: int = 1000,
source_config: dict | None = None,
target_config: dict | None = None,
skip_tables: set | None = None,
only_tables: set | None = None,
no_create_tables: bool = False,
):
"""初始化迁移器
Args:
source_type: 源数据库类型
target_type: 目标数据库类型
batch_size: 批量处理大小
source_config: 源数据库配置(可选,默认从配置文件读取)
target_config: 目标数据库配置(可选,需要手动指定)
skip_tables: 要跳过的表名集合
only_tables: 只迁移的表名集合(设置后忽略 skip_tables
no_create_tables: 是否跳过创建表结构(假设目标表已存在)
"""
self.source_type = source_type.lower()
self.target_type = target_type.lower()
self.batch_size = batch_size
self.source_config = source_config
self.target_config = target_config
self.skip_tables = skip_tables or set()
self.only_tables = only_tables or set()
self.no_create_tables = no_create_tables
self._validate_database_types()
self.source_engine: Any = None
self.target_engine: Any = None
self.metadata = MetaData()
# 统计信息
self.stats = {
"tables_migrated": 0,
"rows_migrated": 0,
"errors": [],
"start_time": None,
"end_time": None,
}
def _validate_database_types(self):
"""验证数据库类型"""
supported_types = {"sqlite", "postgresql"}
if self.source_type not in supported_types:
raise ValueError(f"不支持的源数据库类型: {self.source_type}")
if self.target_type not in supported_types:
raise ValueError(f"不支持的目标数据库类型: {self.target_type}")
def _load_source_config(self) -> dict:
"""加载源数据库配置
如果初始化时提供了 source_config则直接使用
否则从 bot_config.toml 中读取。
"""
if self.source_config:
logger.info("使用传入的源数据库配置")
return self.source_config
logger.info("未提供源数据库配置,尝试从 bot_config.toml 读取")
config = get_database_config_from_toml(self.source_type)
if not config:
raise ValueError("无法从配置文件中读取源数据库配置,请检查 config/bot_config.toml")
logger.info("成功从配置文件读取源数据库配置")
return config
def _load_target_config(self) -> dict:
"""加载目标数据库配置
目标数据库配置必须通过初始化参数提供,或者通过命令行参数构建。
"""
if not self.target_config:
raise ValueError("未提供目标数据库配置,请通过命令行参数指定或在交互模式中输入")
logger.info("使用传入的目标数据库配置")
return self.target_config
def _connect_databases(self):
"""连接源数据库和目标数据库"""
# 源数据库配置
source_config = self._load_source_config()
# 目标数据库配置
target_config = self._load_target_config()
# 防止源/目标 SQLite 指向同一路径导致自我覆盖及锁
if (
self.source_type == "sqlite"
and self.target_type == "sqlite"
and os.path.abspath(source_config.get("path", "")) == os.path.abspath(target_config.get("path", ""))
):
raise ValueError("源数据库与目标数据库不能是同一个 SQLite 文件,请为目标指定不同的路径")
# 创建引擎
self.source_engine = create_engine_by_type(self.source_type, source_config)
self.target_engine = create_engine_by_type(self.target_type, target_config)
# 反射源数据库元数据
logger.info("正在反射源数据库元数据...")
self.metadata.reflect(bind=self.source_engine)
logger.info("发现 %d 张表: %s", len(self.metadata.tables), ", ".join(self.metadata.tables.keys()))
def _get_tables_in_dependency_order(self) -> list[Table]:
"""获取按依赖顺序排序的表列表
为了避免外键约束问题,创建表时需要按照依赖顺序,
例如先创建被引用的表,再创建引用它们的表。
"""
inspector = inspect(self.source_engine)
# 构建依赖图table -> set(dependent_tables)
dependencies: dict[str, set[str]] = {}
for table_name in self.metadata.tables:
dependencies[table_name] = set()
for table_name, table in self.metadata.tables.items():
fks = inspector.get_foreign_keys(table_name)
for fk in fks:
# 被引用的表
referred_table = fk["referred_table"]
if referred_table in dependencies:
dependencies[table_name].add(referred_table)
# 拓扑排序
sorted_tables: list[Table] = []
visited: set[str] = set()
temp_mark: set[str] = set()
def visit(table_name: str):
if table_name in visited:
return
if table_name in temp_mark:
logger.warning("检测到循环依赖,表: %s", table_name)
return
temp_mark.add(table_name)
for dep in dependencies[table_name]:
visit(dep)
temp_mark.remove(table_name)
visited.add(table_name)
sorted_tables.append(self.metadata.tables[table_name])
for table_name in dependencies:
if table_name not in visited:
visit(table_name)
return sorted_tables
def _drop_target_tables(self):
"""删除目标数据库中已有的表(如果有)
使用 Engine.begin() 进行连接以支持 autobegin 和 begin 兼容 SQLAlchemy 2.0 的写法
"""
if self.target_engine is None:
logger.warning("目标数据库引擎尚未初始化,无法删除表")
return
with self.target_engine.begin() as conn:
inspector = inspect(conn)
existing_tables = inspector.get_table_names()
if not existing_tables:
logger.info("目标数据库中没有已存在的表,无需删除")
return
logger.info("目标数据库中的当前表: %s", ", ".join(existing_tables))
if confirm_action("是否删除目标数据库中现有的表列表?此操作不可撤销", default=False):
for table_name in existing_tables:
try:
logger.info("删除目标数据库表: %s", table_name)
conn.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE"))
except SQLAlchemyError as e:
logger.error("删除 %s 失败: %s", table_name, e)
self.stats["errors"].append(
f"删除 {table_name} 失败: {e}"
)
else:
logger.info("跳过删除目标数据库中的表,继续迁移过程")
def migrate(self):
"""执行迁移操作"""
import time
self.stats["start_time"] = time.time()
# 连接数据库
self._connect_databases()
# 获取表的依赖顺序
tables = self._get_tables_in_dependency_order()
logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables))
# 如果指定了 only_tables则过滤表列表
if self.only_tables:
tables = [t for t in tables if t.name in self.only_tables]
logger.info("只迁移指定的表: %s", ", ".join(t.name for t in tables))
if not tables:
logger.warning("没有找到任何匹配 --only-tables 的表")
return
# 删除目标库中已有表(可选)- 如果是增量迁移则跳过
if not self.no_create_tables:
self._drop_target_tables()
# 获取目标数据库方言
target_dialect = self.target_engine.dialect.name
# 开始迁移
with self.source_engine.connect() as source_conn:
for source_table in tables:
# 跳过指定的表(仅在未指定 only_tables 时生效)
if not self.only_tables and source_table.name in self.skip_tables:
logger.info("跳过表: %s (在 skip_tables 列表中)", source_table.name)
continue
try:
# 在目标库中创建表结构(除非指定了 no_create_tables
if self.no_create_tables:
# 反射目标数据库中已存在的表结构
target_metadata = MetaData()
target_metadata.reflect(bind=self.target_engine, only=[source_table.name])
target_table = target_metadata.tables.get(source_table.name)
if target_table is None:
logger.error("目标数据库中不存在表: %s,请先创建表结构或移除 --no-create-tables 参数", source_table.name)
self.stats["errors"].append(f"目标数据库中不存在表: {source_table.name}")
continue
logger.info("使用目标数据库中已存在的表结构: %s", source_table.name)
else:
target_table = copy_table_structure(source_table, MetaData(), self.target_engine)
# 对 messages 表限制迁移行数(只迁移最新 1 万条)
row_limit = None
if source_table.name == "messages":
row_limit = 10000
logger.info("messages 表将只迁移最新 %d 条记录", row_limit)
# 每个批次使用独立事务,传入 engine 而不是 connection
migrated_rows, error_count = migrate_table_data(
source_conn,
self.target_engine,
source_table,
target_table,
batch_size=self.batch_size,
target_dialect=target_dialect,
row_limit=row_limit,
)
self.stats["tables_migrated"] += 1
self.stats["rows_migrated"] += migrated_rows
if error_count > 0:
self.stats["errors"].append(
f"{source_table.name} 迁移失败 {error_count}"
)
except Exception as e:
logger.error("迁移表 %s 时发生错误: %s", source_table.name, e)
self.stats["errors"].append(f"{source_table.name} 迁移失败: {e}")
self.stats["end_time"] = time.time()
# 迁移完成后,自动修复 PostgreSQL 特有问题
if self.target_type == "postgresql" and self.target_engine:
fix_postgresql_boolean_columns(self.target_engine)
fix_postgresql_sequences(self.target_engine)
def print_summary(self):
"""打印迁移总结"""
import time
duration = None
if self.stats["start_time"] is not None and self.stats["end_time"] is not None:
duration = self.stats["end_time"] - self.stats["start_time"]
print("\n" + "=" * 60)
print("迁移完成!")
print(f" 迁移表数量: {self.stats['tables_migrated']}")
print(f" 迁移行数量: {self.stats['rows_migrated']}")
if duration is not None:
print(f" 总耗时: {duration:.2f}")
if self.stats["errors"]:
print(" ⚠️ 发生错误:")
for err in self.stats["errors"]:
print(f" - {err}")
else:
print(" 没有发生错误 🎉")
print("=" * 60 + "\n")
def run(self):
"""运行迁移并打印总结"""
self.migrate()
self.print_summary()
return self.stats
# =============================================================================
# 命令行参数解析
# =============================================================================
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(
description="数据库迁移工具 - 在 SQLite、PostgreSQL 之间迁移数据",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""示例:
# 从 SQLite 迁移到 PostgreSQL
python scripts/migrate_database.py \
--source sqlite \
--target postgresql \
--target-host localhost \
--target-port 5432 \
--target-database maibot \
--target-user postgres \
--target-password your_password
# 从 PostgreSQL 迁移到 SQLite
python scripts/migrate_database.py \
--source postgresql \
--source-host localhost \
--source-port 5432 \
--source-database maibot \
--source-user postgres \
--source-password your_password \
--target sqlite \
--target-path data/MaiBot_backup.db
# 使用交互式向导模式(推荐)
python scripts/migrate_database.py
python scripts/migrate_database.py --interactive
""",
)
# 基本参数
parser.add_argument(
"--source",
type=str,
choices=["sqlite", "postgresql"],
help="源数据库类型(不指定时,在交互模式中选择)",
)
parser.add_argument(
"--target",
type=str,
choices=["sqlite", "postgresql"],
help="目标数据库类型(不指定时,在交互模式中选择)",
)
parser.add_argument(
"--batch-size",
type=int,
default=1000,
help="批量处理大小(默认: 1000",
)
parser.add_argument(
"--interactive",
action="store_true",
help="启用交互式向导模式(推荐:直接运行脚本或加上此参数)",
)
# 源数据库参数(可选,默认从 bot_config.toml 读取)
source_group = parser.add_argument_group("源数据库配置(可选,默认从 bot_config.toml 读取)")
source_group.add_argument("--source-path", type=str, help="SQLite 数据库路径")
source_group.add_argument("--source-host", type=str, help="PostgreSQL 主机")
source_group.add_argument("--source-port", type=int, help="PostgreSQL 端口")
source_group.add_argument("--source-database", type=str, help="数据库名")
source_group.add_argument("--source-user", type=str, help="用户名")
source_group.add_argument("--source-password", type=str, help="密码")
# 目标数据库参数
target_group = parser.add_argument_group("目标数据库配置")
target_group.add_argument("--target-path", type=str, help="SQLite 数据库路径")
target_group.add_argument("--target-host", type=str, help="PostgreSQL 主机")
target_group.add_argument("--target-port", type=int, help="PostgreSQL 端口")
target_group.add_argument("--target-database", type=str, help="数据库名")
target_group.add_argument("--target-user", type=str, help="用户名")
target_group.add_argument("--target-password", type=str, help="密码")
target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema")
# 跳过表参数
parser.add_argument(
"--skip-tables",
type=str,
default="",
help="跳过迁移的表名,多个表名用逗号分隔(如: messages,logs",
)
# 只迁移指定表参数
parser.add_argument(
"--only-tables",
type=str,
default="",
help="只迁移指定的表名,多个表名用逗号分隔(如: user_relationships,maizone_schedule_status。设置后将忽略 --skip-tables",
)
# 不创建表结构,假设目标表已存在
parser.add_argument(
"--no-create-tables",
action="store_true",
help="不创建表结构,假设目标数据库中的表已存在。用于增量迁移指定表的数据",
)
return parser.parse_args()
def build_config_from_args(args, prefix: str, db_type: str) -> dict | None:
"""从命令行参数构建配置
Args:
args: 命令行参数
prefix: 参数前缀 ("source""target")
db_type: 数据库类型
Returns:
配置字典或 None
"""
if db_type == "sqlite":
path = getattr(args, f"{prefix}_path", None)
if path:
return {"path": path}
return None
elif db_type == "postgresql":
host = getattr(args, f"{prefix}_host", None)
if not host:
return None
config = {
"host": host,
"port": getattr(args, f"{prefix}_port") or 5432,
"database": getattr(args, f"{prefix}_database") or "maibot",
"user": getattr(args, f"{prefix}_user") or "postgres",
"password": getattr(args, f"{prefix}_password") or "",
"schema": getattr(args, f"{prefix}_schema", "public"),
}
return config
return None
def _ask_choice(prompt: str, options: list[str], default_index: int | None = None) -> str:
"""在控制台中让用户从多个选项中选择一个"""
while True:
print()
print(prompt)
for i, opt in enumerate(options, start=1):
default_mark = ""
if default_index is not None and i - 1 == default_index:
default_mark = " (默认)"
print(f" {i}) {opt}{default_mark}")
ans = input("请输入选项编号: ").strip()
if not ans and default_index is not None:
return options[default_index]
if ans.isdigit():
idx = int(ans)
if 1 <= idx <= len(options):
return options[idx - 1]
print("❌ 无效的选择,请重新输入。")
def _ask_int(prompt: str, default: int | None = None) -> int:
"""在控制台中输入正整数"""
while True:
suffix = f" (默认 {default})" if default is not None else ""
raw = input(f"{prompt}{suffix}: ").strip()
if not raw and default is not None:
return default
try:
value = int(raw)
if value <= 0:
raise ValueError()
return value
except ValueError:
print("❌ 请输入一个大于 0 的整数。")
def _ask_str(
prompt: str,
default: str | None = None,
allow_empty: bool = False,
is_password: bool = False,
) -> str:
"""在控制台中输入字符串,可选默认值/密码输入"""
while True:
suffix = f" (默认: {default})" if default is not None else ""
full_prompt = f"{prompt}{suffix}: "
raw = getpass(full_prompt) if is_password else input(full_prompt)
raw = raw.strip()
if not raw:
if default is not None:
return default
if allow_empty:
return ""
print("❌ 输入不能为空,请重新输入。")
continue
return raw
def interactive_setup() -> dict:
"""交互式向导,返回用于初始化 DatabaseMigrator 的参数字典"""
print("=" * 60)
print("🌟 数据库迁移向导")
print("只需回答几个问题,我会帮你构造迁移配置。")
print("=" * 60)
db_types = ["sqlite", "postgresql"]
# 选择源数据库
source_type = _ask_choice("请选择【源数据库类型】:", db_types, default_index=0)
# 选择目标数据库(不能与源相同)
while True:
default_idx = 1 if len(db_types) >= 2 else 0
target_type = _ask_choice("请选择【目标数据库类型】:", db_types, default_index=default_idx)
if target_type != source_type:
break
print("❌ 目标数据库不能和源数据库相同,请重新选择。")
# 批量大小
batch_size = _ask_int("请输入批量大小 batch-size", default=1000)
# 源数据库配置:默认使用 bot_config.toml
print()
print("源数据库配置:")
print(" 默认会从 config/bot_config.toml 中读取对应配置。")
use_default_source = input("是否使用配置文件中的【源数据库】配置? [Y/n]: ").strip().lower()
if use_default_source in ("", "y", "yes"):
source_config = None # 让 DatabaseMigrator 自己去读配置
else:
# 简单交互式配置源数据库
print("请手动输入源数据库连接信息:")
if source_type == "sqlite":
source_path = _ask_str("源 SQLite 文件路径", default="data/MaiBot.db")
source_config = {"path": source_path}
else:
port_default = 5432
user_default = "postgres"
host = _ask_str("源数据库 host", default="localhost")
port = _ask_int("源数据库 port", default=port_default)
database = _ask_str("源数据库名", default="maibot")
user = _ask_str("源数据库用户名", default=user_default)
password = _ask_str("源数据库密码(输入时不回显)", default="", is_password=True)
source_config = {
"host": host,
"port": port,
"database": database,
"user": user,
"password": password,
}
if source_type == "postgresql":
source_config["schema"] = _ask_str("源数据库 schema", default="public")
# 目标数据库配置(必须显式确认)
print()
print("目标数据库配置:")
if target_type == "sqlite":
target_path = _ask_str(
"目标 SQLite 文件路径(若不存在会自动创建)",
default="data/MaiBot.db",
)
target_config = {"path": target_path}
else:
port_default = 5432
user_default = "postgres"
host = _ask_str("目标数据库 host", default="localhost")
port = _ask_int("目标数据库 port", default=port_default)
database = _ask_str("目标数据库名", default="maibot")
user = _ask_str("目标数据库用户名", default=user_default)
password = _ask_str("目标数据库密码(输入时不回显)", default="", is_password=True)
target_config = {
"host": host,
"port": port,
"database": database,
"user": user,
"password": password,
}
if target_type == "postgresql":
target_config["schema"] = _ask_str("目标数据库 schema", default="public")
print()
print("=" * 60)
print("迁移配置确认:")
print(f" 源数据库类型: {source_type}")
print(f" 目标数据库类型: {target_type}")
print(f" 批量大小: {batch_size}")
print("⚠️ 请确认目标数据库为空或可以被覆盖,并且已备份源数据库。")
confirm = input("是否开始迁移?[Y/n]: ").strip().lower()
if confirm not in ("", "y", "yes"):
print("已取消迁移。")
sys.exit(0)
return {
"source_type": source_type,
"target_type": target_type,
"batch_size": batch_size,
"source_config": source_config,
"target_config": target_config,
}
def fix_postgresql_sequences(engine: Engine):
"""修复 PostgreSQL 序列值
迁移数据后PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值,
导致插入新记录时出现主键冲突。此函数会自动检测并重置所有序列。
Args:
engine: PostgreSQL 数据库引擎
"""
if engine.dialect.name != "postgresql":
logger.info("非 PostgreSQL 数据库,跳过序列修复")
return
logger.info("正在修复 PostgreSQL 序列...")
with engine.connect() as conn:
# 获取所有带有序列的表
result = conn.execute(text('''
SELECT
t.table_name,
c.column_name,
pg_get_serial_sequence(t.table_name, c.column_name) as sequence_name
FROM information_schema.tables t
JOIN information_schema.columns c
ON t.table_name = c.table_name AND t.table_schema = c.table_schema
WHERE t.table_schema = 'public'
AND t.table_type = 'BASE TABLE'
AND c.column_default LIKE 'nextval%'
ORDER BY t.table_name
'''))
sequences = result.fetchall()
logger.info("发现 %d 个带序列的表", len(sequences))
fixed_count = 0
for table_name, column_name, seq_name in sequences:
if seq_name:
try:
# 获取当前表中该列的最大值
max_result = conn.execute(text(f'SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}'))
max_val = max_result.scalar()
# 设置序列的下一个值
next_val = max_val + 1
conn.execute(text(f"SELECT setval('{seq_name}', {next_val}, false)"))
conn.commit()
logger.info("%s.%s: 最大值=%d, 序列设为=%d", table_name, column_name, max_val, next_val)
fixed_count += 1
except Exception as e:
logger.warning("%s.%s: 修复失败 - %s", table_name, column_name, e)
logger.info("序列修复完成!共修复 %d 个序列", fixed_count)
def fix_postgresql_boolean_columns(engine: Engine):
"""修复 PostgreSQL 布尔列类型
从 SQLite 迁移后,布尔列可能是 INTEGER 类型。此函数将其转换为 BOOLEAN。
Args:
engine: PostgreSQL 数据库引擎
"""
if engine.dialect.name != "postgresql":
logger.info("非 PostgreSQL 数据库,跳过布尔列修复")
return
# 已知需要转换为 BOOLEAN 的列
BOOLEAN_COLUMNS = {
'messages': ['is_mentioned', 'is_emoji', 'is_picid', 'is_command',
'is_notify', 'is_public_notice', 'should_reply', 'should_act'],
'action_records': ['action_done', 'action_build_into_prompt'],
}
logger.info("正在检查并修复 PostgreSQL 布尔列...")
with engine.connect() as conn:
fixed_count = 0
for table_name, columns in BOOLEAN_COLUMNS.items():
for col_name in columns:
try:
# 检查当前类型
result = conn.execute(text(f'''
SELECT data_type FROM information_schema.columns
WHERE table_name = '{table_name}' AND column_name = '{col_name}'
'''))
row = result.fetchone()
if row and row[0] != 'boolean':
# 需要修复
conn.execute(text(f'''
ALTER TABLE {table_name}
ALTER COLUMN {col_name} TYPE BOOLEAN
USING CASE WHEN {col_name} = 0 THEN FALSE ELSE TRUE END
'''))
conn.commit()
logger.info("%s.%s: %s -> BOOLEAN", table_name, col_name, row[0])
fixed_count += 1
except Exception as e:
logger.warning(" ⚠️ %s.%s: 检查/修复失败 - %s", table_name, col_name, e)
if fixed_count > 0:
logger.info("布尔列修复完成!共修复 %d", fixed_count)
else:
logger.info("所有布尔列类型正确,无需修复")
def main():
"""主函数"""
args = parse_args()
# 如果没有任何参数,或者显式指定 --interactive则进入交互模式
if args.interactive or len(sys.argv) == 1:
params = interactive_setup()
try:
migrator = DatabaseMigrator(**params)
stats = migrator.run()
if stats["errors"]:
sys.exit(1)
return
except KeyboardInterrupt:
print("\n迁移被用户中断")
sys.exit(130)
except Exception as e:
print(f"迁移失败: {e}")
sys.exit(1)
# 非交互模式:保持原有行为,但如果没给 source/target就提示错误
if not args.source or not args.target:
print("错误: 非交互模式下必须指定 --source 和 --target。")
print("你也可以直接运行脚本或添加 --interactive 使用交互式向导。")
sys.exit(2)
# 构建配置
source_config = build_config_from_args(args, "source", args.source)
target_config = build_config_from_args(args, "target", args.target)
# 验证目标配置
if target_config is None:
if args.target == "sqlite":
if not args.target_path:
print("错误: 目标数据库为 SQLite 时,必须指定 --target-path或使用交互模式")
sys.exit(1)
target_config = {"path": args.target_path}
else:
if not args.target_host:
print(f"错误: 目标数据库为 {args.target} 时,必须指定 --target-host或使用交互模式")
sys.exit(1)
try:
# 解析跳过的表
skip_tables = set()
if args.skip_tables:
skip_tables = {t.strip() for t in args.skip_tables.split(",") if t.strip()}
logger.info("将跳过以下表: %s", ", ".join(skip_tables))
# 解析只迁移的表
only_tables = set()
if args.only_tables:
only_tables = {t.strip() for t in args.only_tables.split(",") if t.strip()}
logger.info("将只迁移以下表: %s", ", ".join(only_tables))
migrator = DatabaseMigrator(
source_type=args.source,
target_type=args.target,
batch_size=args.batch_size,
source_config=source_config,
target_config=target_config,
skip_tables=skip_tables,
only_tables=only_tables,
no_create_tables=args.no_create_tables,
)
stats = migrator.run()
# 如果有错误,返回非零退出码
if stats["errors"]:
sys.exit(1)
except KeyboardInterrupt:
print("\n迁移被用户中断")
sys.exit(130)
except Exception as e:
print(f"迁移失败: {e}")
sys.exit(1)
if __name__ == "__main__":
main()