Merge branch 'feature/kfc' of https://github.com/MoFox-Studio/MoFox-Core into feature/kfc
This commit is contained in:
@@ -10,16 +10,23 @@
|
|||||||
python scripts/migrate_database.py --help
|
python scripts/migrate_database.py --help
|
||||||
python scripts/migrate_database.py --source sqlite --target postgresql
|
python scripts/migrate_database.py --source sqlite --target postgresql
|
||||||
python scripts/migrate_database.py --source mysql --target postgresql --batch-size 5000
|
python scripts/migrate_database.py --source mysql --target postgresql --batch-size 5000
|
||||||
|
|
||||||
|
# 交互式向导模式(推荐)
|
||||||
|
python scripts/migrate_database.py
|
||||||
|
|
||||||
注意事项:
|
注意事项:
|
||||||
1. 迁移前请备份源数据库
|
1. 迁移前请备份源数据库
|
||||||
2. 目标数据库应该是空的或不存在的(脚本会自动创建表)
|
2. 目标数据库应该是空的或不存在的(脚本会自动创建表)
|
||||||
3. 迁移过程可能需要较长时间,请耐心等待
|
3. 迁移过程可能需要较长时间,请耐心等待
|
||||||
|
4. 迁移到 PostgreSQL 时,脚本会自动:
|
||||||
|
- 修复布尔列类型(SQLite INTEGER -> PostgreSQL BOOLEAN)
|
||||||
|
- 重置序列值(避免主键冲突)
|
||||||
|
|
||||||
实现细节:
|
实现细节:
|
||||||
- 使用 SQLAlchemy 进行数据库连接和元数据管理
|
- 使用 SQLAlchemy 进行数据库连接和元数据管理
|
||||||
- 采用流式迁移,避免一次性加载过多数据
|
- 采用流式迁移,避免一次性加载过多数据
|
||||||
- 支持 SQLite、MySQL、PostgreSQL 之间的互相迁移
|
- 支持 SQLite、MySQL、PostgreSQL 之间的互相迁移
|
||||||
|
- 批量插入失败时自动降级为逐行插入,最大程度保留数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -52,6 +59,8 @@ except ImportError:
|
|||||||
|
|
||||||
from typing import Any, Iterable, Callable
|
from typing import Any, Iterable, Callable
|
||||||
|
|
||||||
|
from datetime import datetime as dt
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
create_engine,
|
create_engine,
|
||||||
MetaData,
|
MetaData,
|
||||||
@@ -314,6 +323,143 @@ def get_table_row_count(conn: Connection, table: Table) -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def convert_value_for_target(
|
||||||
|
val: Any,
|
||||||
|
col_name: str,
|
||||||
|
source_col_type: Any,
|
||||||
|
target_col_type: Any,
|
||||||
|
target_dialect: str,
|
||||||
|
target_col_nullable: bool = True,
|
||||||
|
) -> Any:
|
||||||
|
"""转换值以适配目标数据库类型
|
||||||
|
|
||||||
|
处理以下情况:
|
||||||
|
1. 空字符串日期时间 -> None
|
||||||
|
2. SQLite INTEGER (0/1) -> PostgreSQL BOOLEAN
|
||||||
|
3. 字符串日期时间 -> datetime 对象
|
||||||
|
4. 跳过主键 id (让目标数据库自增)
|
||||||
|
5. 对于 NOT NULL 列,提供合适的默认值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
val: 原始值
|
||||||
|
col_name: 列名
|
||||||
|
source_col_type: 源列类型
|
||||||
|
target_col_type: 目标列类型
|
||||||
|
target_dialect: 目标数据库方言名称
|
||||||
|
target_col_nullable: 目标列是否允许 NULL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转换后的值
|
||||||
|
"""
|
||||||
|
# 获取目标类型的类名
|
||||||
|
target_type_name = target_col_type.__class__.__name__.upper()
|
||||||
|
source_type_name = source_col_type.__class__.__name__.upper()
|
||||||
|
|
||||||
|
# 处理 None 值
|
||||||
|
if val is None:
|
||||||
|
# 如果目标列不允许 NULL,提供默认值
|
||||||
|
if not target_col_nullable:
|
||||||
|
# Boolean 类型的默认值是 False
|
||||||
|
if target_type_name == "BOOLEAN" or isinstance(target_col_type, sqltypes.Boolean):
|
||||||
|
return False
|
||||||
|
# 数值类型的默认值
|
||||||
|
if target_type_name in ("INTEGER", "BIGINT", "SMALLINT") or isinstance(target_col_type, sqltypes.Integer):
|
||||||
|
return 0
|
||||||
|
if target_type_name in ("FLOAT", "DOUBLE", "REAL", "NUMERIC", "DECIMAL", "DOUBLE_PRECISION") or isinstance(target_col_type, sqltypes.Float):
|
||||||
|
return 0.0
|
||||||
|
# 日期时间类型的默认值
|
||||||
|
if target_type_name in ("DATETIME", "TIMESTAMP") or isinstance(target_col_type, sqltypes.DateTime):
|
||||||
|
return dt.now()
|
||||||
|
# 字符串类型的默认值
|
||||||
|
if target_type_name in ("VARCHAR", "STRING", "TEXT") or isinstance(target_col_type, (sqltypes.String, sqltypes.Text)):
|
||||||
|
return ""
|
||||||
|
# 其他类型也返回空字符串作为兜底
|
||||||
|
return ""
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 处理 Boolean 类型转换
|
||||||
|
# SQLite 中 Boolean 实际存储为 INTEGER (0/1)
|
||||||
|
if target_type_name == "BOOLEAN" or isinstance(target_col_type, sqltypes.Boolean):
|
||||||
|
if isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return bool(val)
|
||||||
|
if isinstance(val, str):
|
||||||
|
val_lower = val.lower().strip()
|
||||||
|
if val_lower in ("true", "1", "yes"):
|
||||||
|
return True
|
||||||
|
elif val_lower in ("false", "0", "no", ""):
|
||||||
|
return False
|
||||||
|
return bool(val) if val else False
|
||||||
|
|
||||||
|
# 处理 DateTime 类型转换
|
||||||
|
if target_type_name in ("DATETIME", "TIMESTAMP") or isinstance(target_col_type, sqltypes.DateTime):
|
||||||
|
if isinstance(val, dt):
|
||||||
|
return val
|
||||||
|
if isinstance(val, str):
|
||||||
|
val = val.strip()
|
||||||
|
# 空字符串 -> None
|
||||||
|
if val == "":
|
||||||
|
return None
|
||||||
|
# 尝试多种日期格式
|
||||||
|
for fmt in [
|
||||||
|
"%Y-%m-%d %H:%M:%S.%f",
|
||||||
|
"%Y-%m-%d %H:%M:%S",
|
||||||
|
"%Y-%m-%dT%H:%M:%S.%f",
|
||||||
|
"%Y-%m-%dT%H:%M:%S",
|
||||||
|
"%Y-%m-%d",
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
return dt.strptime(val, fmt)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
# 如果都失败,尝试 fromisoformat
|
||||||
|
try:
|
||||||
|
return dt.fromisoformat(val)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("无法解析日期时间字符串 '%s' (列: %s),设为 None", val, col_name)
|
||||||
|
return None
|
||||||
|
# 如果是数值(时间戳),尝试转换
|
||||||
|
if isinstance(val, (int, float)) and val > 0:
|
||||||
|
try:
|
||||||
|
return dt.fromtimestamp(val)
|
||||||
|
except (OSError, ValueError, OverflowError):
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 处理 Float 类型
|
||||||
|
if target_type_name == "FLOAT" or isinstance(target_col_type, sqltypes.Float):
|
||||||
|
if isinstance(val, (int, float)):
|
||||||
|
return float(val)
|
||||||
|
if isinstance(val, str):
|
||||||
|
val = val.strip()
|
||||||
|
if val == "":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(val)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return val
|
||||||
|
|
||||||
|
# 处理 Integer 类型
|
||||||
|
if target_type_name == "INTEGER" or isinstance(target_col_type, sqltypes.Integer):
|
||||||
|
if isinstance(val, int):
|
||||||
|
return val
|
||||||
|
if isinstance(val, float):
|
||||||
|
return int(val)
|
||||||
|
if isinstance(val, str):
|
||||||
|
val = val.strip()
|
||||||
|
if val == "":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return int(float(val))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return val
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table:
|
def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table:
|
||||||
"""复制表结构到目标数据库,使其结构保持一致"""
|
"""复制表结构到目标数据库,使其结构保持一致"""
|
||||||
target_is_sqlite = target_engine.dialect.name == "sqlite"
|
target_is_sqlite = target_engine.dialect.name == "sqlite"
|
||||||
@@ -351,19 +497,23 @@ def copy_table_structure(source_table: Table, target_metadata: MetaData, target_
|
|||||||
|
|
||||||
def migrate_table_data(
|
def migrate_table_data(
|
||||||
source_conn: Connection,
|
source_conn: Connection,
|
||||||
target_conn: Connection,
|
target_engine: Engine,
|
||||||
source_table: Table,
|
source_table: Table,
|
||||||
target_table: Table,
|
target_table: Table,
|
||||||
batch_size: int = 1000,
|
batch_size: int = 1000,
|
||||||
|
target_dialect: str = "postgresql",
|
||||||
|
row_limit: int | None = None,
|
||||||
) -> tuple[int, int]:
|
) -> tuple[int, int]:
|
||||||
"""迁移单个表的数据
|
"""迁移单个表的数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_conn: 源数据库连接
|
source_conn: 源数据库连接
|
||||||
target_conn: 目标数据库连接
|
target_engine: 目标数据库引擎(注意:改为 engine 而不是 connection)
|
||||||
source_table: 源表对象
|
source_table: 源表对象
|
||||||
target_table: 目标表对象
|
target_table: 目标表对象
|
||||||
batch_size: 每批次处理大小
|
batch_size: 每批次处理大小
|
||||||
|
target_dialect: 目标数据库方言 (sqlite/mysql/postgresql)
|
||||||
|
row_limit: 最大迁移行数限制,None 表示不限制
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[int, int]: (迁移行数, 错误数量)
|
tuple[int, int]: (迁移行数, 错误数量)
|
||||||
@@ -377,40 +527,101 @@ def migrate_table_data(
|
|||||||
|
|
||||||
migrated_rows = 0
|
migrated_rows = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
|
conversion_warnings = 0
|
||||||
|
|
||||||
|
# 构建源列到目标列的映射
|
||||||
|
target_cols_by_name = {c.key: c for c in target_table.columns}
|
||||||
|
|
||||||
|
# 识别主键列(通常是 id),迁移时保留原始 ID 以避免重复数据
|
||||||
|
primary_key_cols = {c.key for c in source_table.primary_key.columns}
|
||||||
|
|
||||||
# 使用流式查询,避免一次性加载太多数据
|
# 使用流式查询,避免一次性加载太多数据
|
||||||
# 对于 SQLAlchemy 1.4/2.0 可以使用 yield_per
|
# 使用 text() 原始 SQL 查询,避免 SQLAlchemy 自动类型转换(如 DateTime)导致的错误
|
||||||
try:
|
try:
|
||||||
select_stmt = source_table.select()
|
# 构建原始 SQL 查询语句
|
||||||
result = source_conn.execute(select_stmt)
|
col_names = [c.key for c in source_table.columns]
|
||||||
|
if row_limit:
|
||||||
|
# 按时间或 ID 倒序取最新的 row_limit 条
|
||||||
|
raw_sql = text(f"SELECT {', '.join(col_names)} FROM {source_table.name} ORDER BY id DESC LIMIT {row_limit}")
|
||||||
|
logger.info(" 限制迁移最新 %d 行", row_limit)
|
||||||
|
else:
|
||||||
|
raw_sql = text(f"SELECT {', '.join(col_names)} FROM {source_table.name}")
|
||||||
|
result = source_conn.execute(raw_sql)
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
logger.error("查询表 %s 失败: %s", source_table.name, e)
|
logger.error("查询表 %s 失败: %s", source_table.name, e)
|
||||||
return 0, 1
|
return 0, 1
|
||||||
|
|
||||||
def insert_batch(rows: list[dict]):
|
def insert_batch(rows: list[dict]):
|
||||||
|
"""每个批次使用独立的事务,批次失败时降级为逐行插入"""
|
||||||
nonlocal migrated_rows, error_count
|
nonlocal migrated_rows, error_count
|
||||||
if not rows:
|
if not rows:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
target_conn.execute(target_table.insert(), rows)
|
# 每个批次使用独立的事务
|
||||||
|
with target_engine.begin() as target_conn:
|
||||||
|
target_conn.execute(target_table.insert(), rows)
|
||||||
migrated_rows += len(rows)
|
migrated_rows += len(rows)
|
||||||
logger.info(" 已迁移 %d/%s 行", migrated_rows, total_rows or "?")
|
logger.info(" 已迁移 %d/%s 行", migrated_rows, total_rows or "?")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
logger.error("写入表 %s 失败: %s", target_table.name, e)
|
# 批量插入失败,降级为逐行插入
|
||||||
error_count += len(rows)
|
logger.warning("批量插入失败,降级为逐行插入 (共 %d 行): %s", len(rows), str(e)[:200])
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
with target_engine.begin() as target_conn:
|
||||||
|
target_conn.execute(target_table.insert(), [row])
|
||||||
|
migrated_rows += 1
|
||||||
|
except SQLAlchemyError as row_e:
|
||||||
|
# 记录失败的行信息
|
||||||
|
row_id = row.get("id", "unknown")
|
||||||
|
logger.error("插入行失败 (id=%s): %s", row_id, str(row_e)[:200])
|
||||||
|
error_count += 1
|
||||||
|
logger.info(" 逐行插入完成,已迁移 %d/%s 行", migrated_rows, total_rows or "?")
|
||||||
|
|
||||||
batch: list[dict] = []
|
batch: list[dict] = []
|
||||||
null_char_replacements = 0
|
null_char_replacements = 0
|
||||||
|
|
||||||
|
# 构建列名列表(用于通过索引访问原始 SQL 结果)
|
||||||
|
col_list = list(source_table.columns)
|
||||||
|
col_name_to_idx = {c.key: idx for idx, c in enumerate(col_list)}
|
||||||
|
|
||||||
for row in result:
|
for row in result:
|
||||||
# Use column objects to access row mapping to avoid quoted_name keys
|
|
||||||
row_dict = {}
|
row_dict = {}
|
||||||
for col in source_table.columns:
|
for col in col_list:
|
||||||
val = row._mapping[col]
|
col_key = col.key
|
||||||
|
|
||||||
|
# 保留主键列(id),确保数据一致性
|
||||||
|
# 注意:如果目标表使用自增主键,可能需要重置序列
|
||||||
|
|
||||||
|
# 通过索引获取原始值(避免 SQLAlchemy 自动类型转换)
|
||||||
|
col_idx = col_name_to_idx[col_key]
|
||||||
|
val = row[col_idx]
|
||||||
|
|
||||||
|
# 处理 NUL 字符
|
||||||
if isinstance(val, str) and "\x00" in val:
|
if isinstance(val, str) and "\x00" in val:
|
||||||
val = val.replace("\x00", "")
|
val = val.replace("\x00", "")
|
||||||
null_char_replacements += 1
|
null_char_replacements += 1
|
||||||
row_dict[col.key] = val
|
|
||||||
|
# 获取目标列类型进行转换
|
||||||
|
target_col = target_cols_by_name.get(col_key)
|
||||||
|
if target_col is not None:
|
||||||
|
try:
|
||||||
|
val = convert_value_for_target(
|
||||||
|
val=val,
|
||||||
|
col_name=col_key,
|
||||||
|
source_col_type=col.type,
|
||||||
|
target_col_type=target_col.type,
|
||||||
|
target_dialect=target_dialect,
|
||||||
|
target_col_nullable=target_col.nullable if target_col.nullable is not None else True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
conversion_warnings += 1
|
||||||
|
if conversion_warnings <= 5:
|
||||||
|
logger.warning(
|
||||||
|
"值转换异常 (表=%s, 列=%s, 值=%r): %s",
|
||||||
|
source_table.name, col_key, val, e
|
||||||
|
)
|
||||||
|
|
||||||
|
row_dict[col_key] = val
|
||||||
|
|
||||||
batch.append(row_dict)
|
batch.append(row_dict)
|
||||||
if len(batch) >= batch_size:
|
if len(batch) >= batch_size:
|
||||||
@@ -432,6 +643,12 @@ def migrate_table_data(
|
|||||||
source_table.name,
|
source_table.name,
|
||||||
null_char_replacements,
|
null_char_replacements,
|
||||||
)
|
)
|
||||||
|
if conversion_warnings:
|
||||||
|
logger.warning(
|
||||||
|
"表 %s 中 %d 个值发生类型转换警告",
|
||||||
|
source_table.name,
|
||||||
|
conversion_warnings,
|
||||||
|
)
|
||||||
|
|
||||||
return migrated_rows, error_count
|
return migrated_rows, error_count
|
||||||
|
|
||||||
@@ -479,6 +696,9 @@ class DatabaseMigrator:
|
|||||||
batch_size: int = 1000,
|
batch_size: int = 1000,
|
||||||
source_config: dict | None = None,
|
source_config: dict | None = None,
|
||||||
target_config: dict | None = None,
|
target_config: dict | None = None,
|
||||||
|
skip_tables: set | None = None,
|
||||||
|
only_tables: set | None = None,
|
||||||
|
no_create_tables: bool = False,
|
||||||
):
|
):
|
||||||
"""初始化迁移器
|
"""初始化迁移器
|
||||||
|
|
||||||
@@ -488,12 +708,18 @@ class DatabaseMigrator:
|
|||||||
batch_size: 批量处理大小
|
batch_size: 批量处理大小
|
||||||
source_config: 源数据库配置(可选,默认从配置文件读取)
|
source_config: 源数据库配置(可选,默认从配置文件读取)
|
||||||
target_config: 目标数据库配置(可选,需要手动指定)
|
target_config: 目标数据库配置(可选,需要手动指定)
|
||||||
|
skip_tables: 要跳过的表名集合
|
||||||
|
only_tables: 只迁移的表名集合(设置后忽略 skip_tables)
|
||||||
|
no_create_tables: 是否跳过创建表结构(假设目标表已存在)
|
||||||
"""
|
"""
|
||||||
self.source_type = source_type.lower()
|
self.source_type = source_type.lower()
|
||||||
self.target_type = target_type.lower()
|
self.target_type = target_type.lower()
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.source_config = source_config
|
self.source_config = source_config
|
||||||
self.target_config = target_config
|
self.target_config = target_config
|
||||||
|
self.skip_tables = skip_tables or set()
|
||||||
|
self.only_tables = only_tables or set()
|
||||||
|
self.no_create_tables = no_create_tables
|
||||||
|
|
||||||
self._validate_database_types()
|
self._validate_database_types()
|
||||||
|
|
||||||
@@ -659,25 +885,60 @@ class DatabaseMigrator:
|
|||||||
tables = self._get_tables_in_dependency_order()
|
tables = self._get_tables_in_dependency_order()
|
||||||
logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables))
|
logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables))
|
||||||
|
|
||||||
# 删除目标库中已有表(可选)
|
# 如果指定了 only_tables,则过滤表列表
|
||||||
self._drop_target_tables()
|
if self.only_tables:
|
||||||
|
tables = [t for t in tables if t.name in self.only_tables]
|
||||||
|
logger.info("只迁移指定的表: %s", ", ".join(t.name for t in tables))
|
||||||
|
if not tables:
|
||||||
|
logger.warning("没有找到任何匹配 --only-tables 的表")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 删除目标库中已有表(可选)- 如果是增量迁移则跳过
|
||||||
|
if not self.no_create_tables:
|
||||||
|
self._drop_target_tables()
|
||||||
|
|
||||||
|
# 获取目标数据库方言
|
||||||
|
target_dialect = self.target_engine.dialect.name
|
||||||
|
|
||||||
# 开始迁移
|
# 开始迁移
|
||||||
with self.source_engine.connect() as source_conn:
|
with self.source_engine.connect() as source_conn:
|
||||||
for source_table in tables:
|
for source_table in tables:
|
||||||
try:
|
# 跳过指定的表(仅在未指定 only_tables 时生效)
|
||||||
# 在目标库中创建表结构
|
if not self.only_tables and source_table.name in self.skip_tables:
|
||||||
target_table = copy_table_structure(source_table, MetaData(), self.target_engine)
|
logger.info("跳过表: %s (在 skip_tables 列表中)", source_table.name)
|
||||||
|
continue
|
||||||
|
|
||||||
# 每张表单独事务,避免退出上下文被自动回滚
|
try:
|
||||||
with self.target_engine.begin() as target_conn:
|
# 在目标库中创建表结构(除非指定了 no_create_tables)
|
||||||
migrated_rows, error_count = migrate_table_data(
|
if self.no_create_tables:
|
||||||
source_conn,
|
# 反射目标数据库中已存在的表结构
|
||||||
target_conn,
|
target_metadata = MetaData()
|
||||||
source_table,
|
target_metadata.reflect(bind=self.target_engine, only=[source_table.name])
|
||||||
target_table,
|
target_table = target_metadata.tables.get(source_table.name)
|
||||||
batch_size=self.batch_size,
|
if target_table is None:
|
||||||
)
|
logger.error("目标数据库中不存在表: %s,请先创建表结构或移除 --no-create-tables 参数", source_table.name)
|
||||||
|
self.stats["errors"].append(f"目标数据库中不存在表: {source_table.name}")
|
||||||
|
continue
|
||||||
|
logger.info("使用目标数据库中已存在的表结构: %s", source_table.name)
|
||||||
|
else:
|
||||||
|
target_table = copy_table_structure(source_table, MetaData(), self.target_engine)
|
||||||
|
|
||||||
|
# 对 messages 表限制迁移行数(只迁移最新 1 万条)
|
||||||
|
row_limit = None
|
||||||
|
if source_table.name == "messages":
|
||||||
|
row_limit = 10000
|
||||||
|
logger.info("messages 表将只迁移最新 %d 条记录", row_limit)
|
||||||
|
|
||||||
|
# 每个批次使用独立事务,传入 engine 而不是 connection
|
||||||
|
migrated_rows, error_count = migrate_table_data(
|
||||||
|
source_conn,
|
||||||
|
self.target_engine,
|
||||||
|
source_table,
|
||||||
|
target_table,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
target_dialect=target_dialect,
|
||||||
|
row_limit=row_limit,
|
||||||
|
)
|
||||||
|
|
||||||
self.stats["tables_migrated"] += 1
|
self.stats["tables_migrated"] += 1
|
||||||
self.stats["rows_migrated"] += migrated_rows
|
self.stats["rows_migrated"] += migrated_rows
|
||||||
@@ -691,6 +952,11 @@ class DatabaseMigrator:
|
|||||||
self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}")
|
self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}")
|
||||||
|
|
||||||
self.stats["end_time"] = time.time()
|
self.stats["end_time"] = time.time()
|
||||||
|
|
||||||
|
# 迁移完成后,自动修复 PostgreSQL 特有问题
|
||||||
|
if self.target_type == "postgresql" and self.target_engine:
|
||||||
|
fix_postgresql_boolean_columns(self.target_engine)
|
||||||
|
fix_postgresql_sequences(self.target_engine)
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
"""打印迁移总结"""
|
"""打印迁移总结"""
|
||||||
@@ -804,6 +1070,29 @@ def parse_args():
|
|||||||
target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema")
|
target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema")
|
||||||
target_group.add_argument("--target-charset", type=str, default="utf8mb4", help="MySQL 字符集")
|
target_group.add_argument("--target-charset", type=str, default="utf8mb4", help="MySQL 字符集")
|
||||||
|
|
||||||
|
# 跳过表参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-tables",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="跳过迁移的表名,多个表名用逗号分隔(如: messages,logs)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 只迁移指定表参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--only-tables",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="只迁移指定的表名,多个表名用逗号分隔(如: user_relationships,maizone_schedule_status)。设置后将忽略 --skip-tables",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 不创建表结构,假设目标表已存在
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-create-tables",
|
||||||
|
action="store_true",
|
||||||
|
help="不创建表结构,假设目标数据库中的表已存在。用于增量迁移指定表的数据",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -1012,6 +1301,112 @@ def interactive_setup() -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def fix_postgresql_sequences(engine: Engine):
|
||||||
|
"""修复 PostgreSQL 序列值
|
||||||
|
|
||||||
|
迁移数据后,PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值,
|
||||||
|
导致插入新记录时出现主键冲突。此函数会自动检测并重置所有序列。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: PostgreSQL 数据库引擎
|
||||||
|
"""
|
||||||
|
if engine.dialect.name != "postgresql":
|
||||||
|
logger.info("非 PostgreSQL 数据库,跳过序列修复")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("正在修复 PostgreSQL 序列...")
|
||||||
|
|
||||||
|
with engine.connect() as conn:
|
||||||
|
# 获取所有带有序列的表
|
||||||
|
result = conn.execute(text('''
|
||||||
|
SELECT
|
||||||
|
t.table_name,
|
||||||
|
c.column_name,
|
||||||
|
pg_get_serial_sequence(t.table_name, c.column_name) as sequence_name
|
||||||
|
FROM information_schema.tables t
|
||||||
|
JOIN information_schema.columns c
|
||||||
|
ON t.table_name = c.table_name AND t.table_schema = c.table_schema
|
||||||
|
WHERE t.table_schema = 'public'
|
||||||
|
AND t.table_type = 'BASE TABLE'
|
||||||
|
AND c.column_default LIKE 'nextval%'
|
||||||
|
ORDER BY t.table_name
|
||||||
|
'''))
|
||||||
|
|
||||||
|
sequences = result.fetchall()
|
||||||
|
logger.info("发现 %d 个带序列的表", len(sequences))
|
||||||
|
|
||||||
|
fixed_count = 0
|
||||||
|
for table_name, column_name, seq_name in sequences:
|
||||||
|
if seq_name:
|
||||||
|
try:
|
||||||
|
# 获取当前表中该列的最大值
|
||||||
|
max_result = conn.execute(text(f'SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}'))
|
||||||
|
max_val = max_result.scalar()
|
||||||
|
|
||||||
|
# 设置序列的下一个值
|
||||||
|
next_val = max_val + 1
|
||||||
|
conn.execute(text(f"SELECT setval('{seq_name}', {next_val}, false)"))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
logger.info(" ✅ %s.%s: 最大值=%d, 序列设为=%d", table_name, column_name, max_val, next_val)
|
||||||
|
fixed_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(" ❌ %s.%s: 修复失败 - %s", table_name, column_name, e)
|
||||||
|
|
||||||
|
logger.info("序列修复完成!共修复 %d 个序列", fixed_count)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_postgresql_boolean_columns(engine: Engine):
|
||||||
|
"""修复 PostgreSQL 布尔列类型
|
||||||
|
|
||||||
|
从 SQLite 迁移后,布尔列可能是 INTEGER 类型。此函数将其转换为 BOOLEAN。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: PostgreSQL 数据库引擎
|
||||||
|
"""
|
||||||
|
if engine.dialect.name != "postgresql":
|
||||||
|
logger.info("非 PostgreSQL 数据库,跳过布尔列修复")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 已知需要转换为 BOOLEAN 的列
|
||||||
|
BOOLEAN_COLUMNS = {
|
||||||
|
'messages': ['is_mentioned', 'is_emoji', 'is_picid', 'is_command',
|
||||||
|
'is_notify', 'is_public_notice', 'should_reply', 'should_act'],
|
||||||
|
'action_records': ['action_done', 'action_build_into_prompt'],
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("正在检查并修复 PostgreSQL 布尔列...")
|
||||||
|
|
||||||
|
with engine.connect() as conn:
|
||||||
|
fixed_count = 0
|
||||||
|
for table_name, columns in BOOLEAN_COLUMNS.items():
|
||||||
|
for col_name in columns:
|
||||||
|
try:
|
||||||
|
# 检查当前类型
|
||||||
|
result = conn.execute(text(f'''
|
||||||
|
SELECT data_type FROM information_schema.columns
|
||||||
|
WHERE table_name = '{table_name}' AND column_name = '{col_name}'
|
||||||
|
'''))
|
||||||
|
row = result.fetchone()
|
||||||
|
if row and row[0] != 'boolean':
|
||||||
|
# 需要修复
|
||||||
|
conn.execute(text(f'''
|
||||||
|
ALTER TABLE {table_name}
|
||||||
|
ALTER COLUMN {col_name} TYPE BOOLEAN
|
||||||
|
USING CASE WHEN {col_name} = 0 THEN FALSE ELSE TRUE END
|
||||||
|
'''))
|
||||||
|
conn.commit()
|
||||||
|
logger.info(" ✅ %s.%s: %s -> BOOLEAN", table_name, col_name, row[0])
|
||||||
|
fixed_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(" ⚠️ %s.%s: 检查/修复失败 - %s", table_name, col_name, e)
|
||||||
|
|
||||||
|
if fixed_count > 0:
|
||||||
|
logger.info("布尔列修复完成!共修复 %d 列", fixed_count)
|
||||||
|
else:
|
||||||
|
logger.info("所有布尔列类型正确,无需修复")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
"""主函数"""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
@@ -1055,12 +1450,27 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 解析跳过的表
|
||||||
|
skip_tables = set()
|
||||||
|
if args.skip_tables:
|
||||||
|
skip_tables = {t.strip() for t in args.skip_tables.split(",") if t.strip()}
|
||||||
|
logger.info("将跳过以下表: %s", ", ".join(skip_tables))
|
||||||
|
|
||||||
|
# 解析只迁移的表
|
||||||
|
only_tables = set()
|
||||||
|
if args.only_tables:
|
||||||
|
only_tables = {t.strip() for t in args.only_tables.split(",") if t.strip()}
|
||||||
|
logger.info("将只迁移以下表: %s", ", ".join(only_tables))
|
||||||
|
|
||||||
migrator = DatabaseMigrator(
|
migrator = DatabaseMigrator(
|
||||||
source_type=args.source,
|
source_type=args.source,
|
||||||
target_type=args.target,
|
target_type=args.target,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
source_config=source_config,
|
source_config=source_config,
|
||||||
target_config=target_config,
|
target_config=target_config,
|
||||||
|
skip_tables=skip_tables,
|
||||||
|
only_tables=only_tables,
|
||||||
|
no_create_tables=args.no_create_tables,
|
||||||
)
|
)
|
||||||
|
|
||||||
stats = migrator.run()
|
stats = migrator.run()
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""重置 PostgreSQL 序列值
|
|
||||||
|
|
||||||
迁移数据后,PostgreSQL 的序列(用于自增主键)可能没有更新到正确的值,
|
|
||||||
导致插入新记录时出现主键冲突。此脚本会自动检测并重置所有序列。
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
python scripts/reset_pg_sequences.py --host localhost --port 5432 --database maibot --user postgres --password your_password
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import psycopg
|
|
||||||
|
|
||||||
|
|
||||||
def reset_sequences(host: str, port: int, database: str, user: str, password: str):
|
|
||||||
"""重置所有序列值"""
|
|
||||||
conn_str = f"host={host} port={port} dbname={database} user={user} password={password}"
|
|
||||||
|
|
||||||
print(f"连接到 PostgreSQL: {host}:{port}/{database}")
|
|
||||||
conn = psycopg.connect(conn_str)
|
|
||||||
conn.autocommit = True
|
|
||||||
|
|
||||||
# 查询所有序列及其关联的表和列
|
|
||||||
query = """
|
|
||||||
SELECT
|
|
||||||
t.relname AS table_name,
|
|
||||||
a.attname AS column_name,
|
|
||||||
s.relname AS sequence_name
|
|
||||||
FROM pg_class s
|
|
||||||
JOIN pg_depend d ON d.objid = s.oid
|
|
||||||
JOIN pg_class t ON d.refobjid = t.oid
|
|
||||||
JOIN pg_attribute a ON (d.refobjid, d.refobjsubid) = (a.attrelid, a.attnum)
|
|
||||||
WHERE s.relkind = 'S'
|
|
||||||
"""
|
|
||||||
|
|
||||||
cursor = conn.execute(query)
|
|
||||||
sequences = cursor.fetchall()
|
|
||||||
|
|
||||||
print(f"发现 {len(sequences)} 个序列")
|
|
||||||
|
|
||||||
reset_count = 0
|
|
||||||
for table_name, col_name, seq_name in sequences:
|
|
||||||
try:
|
|
||||||
# 获取当前最大 ID
|
|
||||||
max_result = conn.execute(f'SELECT MAX("{col_name}") FROM "{table_name}"')
|
|
||||||
max_id = max_result.fetchone()[0]
|
|
||||||
|
|
||||||
if max_id is not None:
|
|
||||||
# 重置序列
|
|
||||||
conn.execute(f"SELECT setval('{seq_name}', {max_id}, true)")
|
|
||||||
print(f" ✓ {seq_name} -> {max_id}")
|
|
||||||
reset_count += 1
|
|
||||||
else:
|
|
||||||
print(f" - {seq_name}: 表为空,跳过")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ✗ {table_name}.{col_name}: {e}")
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
print(f"\n✅ 重置完成!共重置 {reset_count} 个序列")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="重置 PostgreSQL 序列值")
|
|
||||||
parser.add_argument("--host", default="localhost", help="PostgreSQL 主机")
|
|
||||||
parser.add_argument("--port", type=int, default=5432, help="PostgreSQL 端口")
|
|
||||||
parser.add_argument("--database", default="maibot", help="数据库名")
|
|
||||||
parser.add_argument("--user", default="postgres", help="用户名")
|
|
||||||
parser.add_argument("--password", required=True, help="密码")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
reset_sequences(args.host, args.port, args.database, args.user, args.password)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Reference in New Issue
Block a user