923 lines
37 KiB
Python
923 lines
37 KiB
Python
import os
|
||
|
||
# import time
|
||
import pickle
|
||
import sys # 新增系统模块导入
|
||
from pathlib import Path
|
||
|
||
import orjson
|
||
|
||
# Put this script's parent directory on the import path so the `src.` package imports resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
from peewee import Field, IntegrityError, Model
|
||
from pymongo import MongoClient
|
||
from pymongo.errors import ConnectionFailure
|
||
|
||
# Rich 进度条和显示组件
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.progress import (
|
||
BarColumn,
|
||
Progress,
|
||
SpinnerColumn,
|
||
TaskProgressColumn,
|
||
TextColumn,
|
||
TimeElapsedColumn,
|
||
TimeRemainingColumn,
|
||
)
|
||
from rich.table import Table
|
||
|
||
# from rich.text import Text
|
||
from src.common.database.database import db
|
||
from src.common.database.sqlalchemy_models import (
|
||
ChatStreams,
|
||
Emoji,
|
||
GraphEdges,
|
||
GraphNodes,
|
||
ImageDescriptions,
|
||
Images,
|
||
Knowledges,
|
||
Messages,
|
||
PersonInfo,
|
||
ThinkingLog,
|
||
)
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("mongodb_to_sqlite")

# Absolute path of this script's parent directory (same directory added to sys.path above the imports).
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
||
|
||
|
||
@dataclass
class MigrationConfig:
    """How one MongoDB collection is copied into one ORM table."""

    # Name of the source MongoDB collection (may contain dots, e.g. "graph_data.nodes").
    mongo_collection: str
    # ORM model class that rows are inserted into.
    target_model: type[Model]
    # Source field path (dotted paths address nested values) -> target model attribute name.
    field_mapping: dict[str, str]
    # Cursor batch size, and number of buffered rows that triggers a bulk insert.
    batch_size: int = field(default=500)
    # Not consulted by the migrator: data validation is disabled.
    enable_validation: bool = field(default=True)
    # When True, rows that match an existing row on `unique_fields` are skipped.
    skip_duplicates: bool = field(default=True)
    # Target attribute names used for the duplicate check.
    unique_fields: list[str] = field(default_factory=list)
|
||
|
||
|
||
# 数据验证相关类已移除 - 用户要求不要数据验证
|
||
|
||
|
||
@dataclass
class MigrationCheckpoint:
    """Resume point for an interrupted collection migration (pickled to disk by the migrator)."""

    # Source MongoDB collection this checkpoint belongs to.
    collection_name: str
    # Value of the migrator's processed counter when the checkpoint was written.
    processed_count: int
    # `_id` of the last document read; a resumed run queries `_id > last_processed_id`.
    last_processed_id: Any
    # When the checkpoint was created.
    timestamp: datetime
    # Never filled in by this script; each instance gets its own list.
    batch_errors: list[dict[str, Any]] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
|
||
class MigrationStats:
|
||
"""迁移统计信息"""
|
||
|
||
total_documents: int = 0
|
||
processed_count: int = 0
|
||
success_count: int = 0
|
||
error_count: int = 0
|
||
skipped_count: int = 0
|
||
duplicate_count: int = 0
|
||
validation_errors: int = 0
|
||
batch_insert_count: int = 0
|
||
errors: list[dict[str, Any]] = field(default_factory=list)
|
||
start_time: datetime | None = None
|
||
end_time: datetime | None = None
|
||
|
||
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
|
||
"""添加错误记录"""
|
||
self.errors.append(
|
||
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
|
||
)
|
||
self.error_count += 1
|
||
|
||
def add_validation_error(self, doc_id: Any, field: str, error: str):
|
||
"""添加验证错误"""
|
||
self.add_error(doc_id, f"验证失败 - {field}: {error}")
|
||
self.validation_errors += 1
|
||
|
||
|
||
class MongoToSQLiteMigrator:
    """Copy MongoDB collections into the SQLite database through ORM models.

    Each :class:`MigrationConfig` in ``migration_configs`` describes one
    collection -> table mapping. Progress is shown with Rich. After every bulk
    insert, a pickle checkpoint records the last processed ``_id``, so an
    interrupted run can resume from there.
    """

    def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
        """Prepare the migrator without connecting to MongoDB yet.

        Args:
            mongo_uri: Full MongoDB URI. When omitted, it is built from environment
                variables by :meth:`_build_mongo_uri`.
            database_name: MongoDB database name. Defaults to ``$DATABASE_NAME``,
                falling back to ``"MegBot"``.
        """
        self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
        self.mongo_uri = mongo_uri or self._build_mongo_uri()
        self.mongo_client: MongoClient | None = None
        self.mongo_db = None

        # Migration plan: one config per source collection, processed in list order.
        self.migration_configs = self._initialize_migration_configs()

        # Console shared by progress bars, panels and summary tables.
        self.console = Console()

        # Checkpoint directory. parents=True so a missing ROOT_PATH/data
        # directory does not make construction fail.
        self.checkpoint_dir = Path(ROOT_PATH) / "data" / "checkpoints"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Validation is disabled; the attribute is kept for interface compatibility.
        self.validation_rules = self._initialize_validation_rules()

    def _build_mongo_uri(self) -> str:
        """Build the MongoDB connection URI from environment variables.

        ``MONGODB_URI`` is used verbatim when set. Otherwise the URI is assembled
        from ``MONGODB_HOST``/``MONGODB_PORT`` (default ``localhost:27017``). When
        both ``MONGODB_USER`` and ``MONGODB_PASS`` are set, the credentials and
        ``MONGODB_AUTH_SOURCE`` (default ``admin``) are added.
        """
        from urllib.parse import quote_plus

        if mongo_uri := os.getenv("MONGODB_URI"):
            return mongo_uri

        user = os.getenv("MONGODB_USER")
        password = os.getenv("MONGODB_PASS")
        host = os.getenv("MONGODB_HOST", "localhost")
        port = os.getenv("MONGODB_PORT", "27017")
        auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")

        if user and password:
            # Credentials must be percent-escaped. Otherwise characters such as
            # '@', ':' or '/' in the password corrupt the URI.
            return (
                f"mongodb://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/"
                f"{self.database_name}?authSource={auth_source}"
            )
        return f"mongodb://{host}:{port}/{self.database_name}"

    def _initialize_migration_configs(self) -> list[MigrationConfig]:
        """Return the default collection -> table migration plan."""
        return [
            # Emoji (sticker) migration.
            # NOTE(review): record_time is not mapped, so the record_time stamping in
            # _convert_field_value does not run for emoji. Confirm how that column is filled.
            MigrationConfig(
                mongo_collection="emoji",
                target_model=Emoji,
                field_mapping={
                    "full_path": "full_path",
                    "format": "format",
                    "hash": "emoji_hash",
                    "description": "description",
                    "emotion": "emotion",
                    "usage_count": "usage_count",
                    "last_used_time": "last_used_time",
                },
                enable_validation=False,  # validation disabled
                unique_fields=["full_path", "emoji_hash"],
            ),
            # Chat stream migration.
            MigrationConfig(
                mongo_collection="chat_streams",
                target_model=ChatStreams,
                field_mapping={
                    "stream_id": "stream_id",
                    "create_time": "create_time",
                    # MongoDB stores group_info as null for private chats, but the new
                    # schema does not allow null, so private-chat streams cannot be migrated yet.
                    "group_info.platform": "group_platform",
                    "group_info.group_id": "group_id",  # same caveat as above
                    "group_info.group_name": "group_name",  # same caveat as above
                    "last_active_time": "last_active_time",
                    "platform": "platform",
                    "user_info.platform": "user_platform",
                    "user_info.user_id": "user_id",
                    "user_info.user_nickname": "user_nickname",
                    "user_info.user_cardname": "user_cardname",
                },
                enable_validation=False,  # validation disabled
                unique_fields=["stream_id"],
            ),
            # Message migration.
            MigrationConfig(
                mongo_collection="messages",
                target_model=Messages,
                field_mapping={
                    "message_id": "message_id",
                    "time": "time",
                    "chat_id": "chat_id",
                    "chat_info.stream_id": "chat_info_stream_id",
                    "chat_info.platform": "chat_info_platform",
                    "chat_info.user_info.platform": "chat_info_user_platform",
                    "chat_info.user_info.user_id": "chat_info_user_id",
                    "chat_info.user_info.user_nickname": "chat_info_user_nickname",
                    "chat_info.user_info.user_cardname": "chat_info_user_cardname",
                    "chat_info.group_info.platform": "chat_info_group_platform",
                    "chat_info.group_info.group_id": "chat_info_group_id",
                    "chat_info.group_info.group_name": "chat_info_group_name",
                    "chat_info.create_time": "chat_info_create_time",
                    "chat_info.last_active_time": "chat_info_last_active_time",
                    "user_info.platform": "user_platform",
                    "user_info.user_id": "user_id",
                    "user_info.user_nickname": "user_nickname",
                    "user_info.user_cardname": "user_cardname",
                    "processed_plain_text": "processed_plain_text",
                    "memorized_times": "memorized_times",
                },
                enable_validation=False,  # validation disabled
                unique_fields=["message_id"],
            ),
            # Image migration.
            MigrationConfig(
                mongo_collection="images",
                target_model=Images,
                field_mapping={
                    "hash": "emoji_hash",
                    "description": "description",
                    "path": "path",
                    "timestamp": "timestamp",
                    "type": "type",
                },
                unique_fields=["path"],
            ),
            # Image description migration.
            MigrationConfig(
                mongo_collection="image_descriptions",
                target_model=ImageDescriptions,
                field_mapping={
                    "type": "type",
                    "hash": "image_description_hash",
                    "description": "description",
                    "timestamp": "timestamp",
                },
                unique_fields=["image_description_hash", "type"],
            ),
            # Person info migration.
            MigrationConfig(
                mongo_collection="person_info",
                target_model=PersonInfo,
                field_mapping={
                    "person_id": "person_id",
                    "person_name": "person_name",
                    "name_reason": "name_reason",
                    "platform": "platform",
                    "user_id": "user_id",
                    "nickname": "nickname",
                    "relationship_value": "relationship_value",
                    # Source key is spelled "konw_time"; presumably this matches the
                    # legacy MongoDB field name. Verify before "fixing" the spelling.
                    "konw_time": "know_time",
                },
                unique_fields=["person_id"],
            ),
            # Knowledge base migration.
            MigrationConfig(
                mongo_collection="knowledges",
                target_model=Knowledges,
                field_mapping={"content": "content", "embedding": "embedding"},
                unique_fields=["content"],  # assumes content is unique
            ),
            # Thinking log migration.
            MigrationConfig(
                mongo_collection="thinking_log",
                target_model=ThinkingLog,
                field_mapping={
                    "chat_id": "chat_id",
                    "trigger_text": "trigger_text",
                    "response_text": "response_text",
                    "trigger_info": "trigger_info_json",
                    "response_info": "response_info_json",
                    "timing_results": "timing_results_json",
                    "chat_history": "chat_history_json",
                    "chat_history_in_thinking": "chat_history_in_thinking_json",
                    "chat_history_after_response": "chat_history_after_response_json",
                    "heartflow_data": "heartflow_data_json",
                    "reasoning_data": "reasoning_data_json",
                },
                unique_fields=["chat_id", "trigger_text"],
            ),
            # Graph node migration.
            MigrationConfig(
                mongo_collection="graph_data.nodes",
                target_model=GraphNodes,
                field_mapping={
                    "concept": "concept",
                    "memory_items": "memory_items",
                    "hash": "hash",
                    "created_time": "created_time",
                    "last_modified": "last_modified",
                },
                unique_fields=["concept"],
            ),
            # Graph edge migration.
            MigrationConfig(
                mongo_collection="graph_data.edges",
                target_model=GraphEdges,
                field_mapping={
                    "source": "source",
                    "target": "target",
                    "strength": "strength",
                    "hash": "hash",
                    "created_time": "created_time",
                    "last_modified": "last_modified",
                },
                unique_fields=["source", "target"],  # composite uniqueness
            ),
        ]

    def _initialize_validation_rules(self) -> dict[str, Any]:
        """Data validation is disabled; always returns an empty dict."""
        return {}

    def connect_mongodb(self) -> bool:
        """Connect to MongoDB and verify the connection with a ``ping``.

        On success, sets ``mongo_client`` and ``mongo_db`` and returns True. On any
        failure, logs the error, closes the half-opened client and returns False.
        """
        client = None
        try:
            client = MongoClient(
                self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
            )
            # Test the connection.
            client.admin.command("ping")
            mongo_db = client[self.database_name]
        except ConnectionFailure as e:
            logger.error(f"MongoDB连接失败: {e}")
        except Exception as e:
            logger.error(f"MongoDB连接异常: {e}")
        else:
            self.mongo_client = client
            self.mongo_db = mongo_db
            logger.info(f"成功连接到MongoDB: {self.database_name}")
            return True

        # Release the connection pool of a client that failed its ping.
        if client is not None:
            client.close()
        return False

    def disconnect_mongodb(self):
        """Close the MongoDB client, if open, and forget the connection state."""
        if self.mongo_client is not None:
            self.mongo_client.close()
            self.mongo_client = None
            self.mongo_db = None
            logger.info("MongoDB连接已关闭")

    def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
        """Return the value at a dotted ``field_path`` in ``document``.

        Returns None as soon as a path segment is missing or lands on a
        non-dict value (e.g. ``group_info`` being null).
        """
        value: Any = document
        for part in field_path.split("."):
            if not isinstance(value, dict):
                return None
            value = value.get(part)
        return value

    def _convert_field_value(self, value: Any, target_field: Field) -> Any:
        """Convert a MongoDB value to what ``target_field`` expects.

        Dispatch is on the field class name (``CharField``, ``IntegerField``, ...).
        A ``DateTimeField`` named ``record_time`` always receives the current time.
        Conversion failures are logged and replaced by
        :meth:`_get_default_value_for_field`.

        Returns:
            The converted value, or None when ``value`` is None (except for
            ``record_time``).
        """
        field_type = target_field.__class__.__name__

        # record_time is stamped with the migration time whatever the source value,
        # so this check must come before the None short-circuit.
        if target_field.name == "record_time" and field_type == "DateTimeField":
            return datetime.now()

        if value is None:
            return None

        try:
            if field_type in ("CharField", "TextField"):
                if isinstance(value, (list, dict)):
                    # orjson has no ensure_ascii/indent kwargs. It always emits UTF-8
                    # bytes without ASCII escaping, so decode to store text.
                    # default=str stringifies BSON types such as ObjectId.
                    return orjson.dumps(value, default=str).decode("utf-8")
                return str(value)

            if field_type == "IntegerField":
                if isinstance(value, str):
                    # Numeric strings such as "12", "-3" or "4.0"; anything else becomes 0.
                    try:
                        return int(float(value.strip()))
                    except (ValueError, OverflowError):
                        return 0
                return int(value)

            if field_type in ("FloatField", "DoubleField"):
                return float(value)

            if field_type == "BooleanField":
                if isinstance(value, str):
                    return value.lower() in ("true", "1", "yes", "on")
                return bool(value)

            if field_type == "DateTimeField":
                return self._to_datetime(value)

            return value

        except (ValueError, TypeError, OverflowError, OSError) as e:
            # OverflowError/OSError come from int(inf) and out-of-range timestamps.
            logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
            return self._get_default_value_for_field(target_field)

    @staticmethod
    def _to_datetime(value: Any) -> datetime:
        """Best-effort conversion of a MongoDB value to ``datetime``.

        ``datetime`` objects (pymongo decodes BSON dates to them) are kept as-is.
        Numbers are treated as Unix timestamps. Strings are tried as ISO-8601
        first, then as numeric timestamps. Anything else yields ``datetime.now()``.
        May raise OverflowError/OSError/ValueError for out-of-range timestamps.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, (int, float)):
            return datetime.fromtimestamp(value)
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.replace("Z", "+00:00"))
            except ValueError:
                pass
            try:
                return datetime.fromtimestamp(float(value))
            except ValueError:
                return datetime.now()
        return datetime.now()

    def _get_default_value_for_field(self, field: Field) -> Any:
        """Fallback value for ``field`` when a source value cannot be converted.

        Order: the field's declared default (called when it is callable), then
        None for nullable fields, then a zero value for the field's type.
        """
        default = getattr(field, "default", None)
        if default is not None:
            return default() if callable(default) else default

        if field.null:
            return None

        field_type = field.__class__.__name__
        if field_type in ("CharField", "TextField"):
            return ""
        if field_type == "IntegerField":
            return 0
        if field_type in ("FloatField", "DoubleField"):
            return 0.0
        if field_type == "BooleanField":
            return False
        if field_type == "DateTimeField":
            return datetime.now()
        return None

    def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
        """Data validation is disabled; always returns True."""
        return True

    def _checkpoint_path(self, collection_name: str) -> Path:
        """Return the pickle file that holds the checkpoint for ``collection_name``."""
        return self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"

    def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
        """Persist a :class:`MigrationCheckpoint`; failures are logged, not raised.

        The pickle is written to a temporary file and then renamed over the real one,
        so an interruption mid-write cannot leave a truncated checkpoint.
        """
        checkpoint = MigrationCheckpoint(
            collection_name=collection_name,
            processed_count=processed_count,
            last_processed_id=last_id,
            timestamp=datetime.now(),
        )

        checkpoint_file = self._checkpoint_path(collection_name)
        tmp_file = checkpoint_file.parent / (checkpoint_file.name + ".tmp")
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(checkpoint, f)
            os.replace(tmp_file, checkpoint_file)
        except Exception as e:
            logger.warning(f"保存断点失败: {e}")

    def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
        """Load the saved checkpoint for ``collection_name``, or None if absent/unreadable."""
        checkpoint_file = self._checkpoint_path(collection_name)
        if not checkpoint_file.exists():
            return None

        try:
            # NOTE: unpickling can execute arbitrary code. Only load checkpoints this
            # script wrote into its own checkpoint directory.
            with open(checkpoint_file, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            logger.warning(f"加载断点失败: {e}")
            return None

    def _remove_checkpoint(self, collection_name: str) -> None:
        """Delete the checkpoint file for ``collection_name`` if it exists."""
        self._checkpoint_path(collection_name).unlink(missing_ok=True)

    def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
        """Insert ``data_list`` into ``model``'s table; return the number of rows inserted.

        Rows are written in chunks of 100 inside one transaction. If any chunk fails,
        the transaction is rolled back and every row is retried individually, so one
        bad row does not discard the rest.
        """
        if not data_list:
            return 0

        chunk_size = 100  # keeps each INSERT statement reasonably short
        try:
            with db.atomic():
                for start in range(0, len(data_list), chunk_size):
                    model.insert_many(data_list[start : start + chunk_size]).execute()
            return len(data_list)
        except Exception as e:
            logger.error(f"批量插入失败: {e}")

        # The atomic block rolled back, so nothing from the bulk attempt persisted.
        # Count only the rows the per-row retry actually inserts.
        success_count = 0
        for data in data_list:
            try:
                model.create(**data)
                success_count += 1
            except Exception as row_error:
                # Usually a unique-constraint conflict; recorded at debug level.
                logger.debug(f"单条插入失败: {row_error}")
        return success_count

    def _check_duplicate_by_unique_fields(
        self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
    ) -> bool:
        """Return True if a row matching ``data`` on ``unique_fields`` already exists.

        Only unique fields with a non-None value in ``data`` take part. If none do,
        the record is treated as new. Query errors are logged at debug level and
        treated as "not a duplicate".
        """
        if not unique_fields:
            return False

        try:
            conditions = [
                getattr(model, field_name) == data[field_name]
                for field_name in unique_fields
                if data.get(field_name) is not None
            ]
            if not conditions:
                # An unfiltered SELECT would match any existing row and wrongly
                # flag every record lacking these keys as a duplicate.
                return False

            query = model.select()
            for condition in conditions:
                query = query.where(condition)
            return query.exists()
        except Exception as e:
            logger.debug(f"重复检查失败: {e}")
            return False

    def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
        """Create one ``model`` row from ``data``, ignoring keys the model does not define.

        Returns the created instance, or None on an integrity error or any other
        failure (both are logged).
        """
        try:
            valid_data = {}
            for field_name, value in data.items():
                if hasattr(model, field_name):
                    valid_data[field_name] = value
                else:
                    logger.debug(f"跳过未知字段: {field_name}")

            return model.create(**valid_data)

        except IntegrityError as e:
            # Unique-constraint conflicts and similar integrity errors.
            logger.debug(f"完整性约束冲突: {e}")
            return None
        except Exception as e:
            logger.error(f"创建模型实例失败: {e}")
            return None

    def _build_target_data(self, config: MigrationConfig, mongo_doc: dict[str, Any]) -> dict[str, Any]:
        """Map one MongoDB document to converted target-model values per ``config.field_mapping``.

        Mapped attributes that the target model does not define are skipped.
        """
        target_data: dict[str, Any] = {}
        for mongo_field, sqlite_field in config.field_mapping.items():
            if not hasattr(config.target_model, sqlite_field):
                continue
            value = self._get_nested_value(mongo_doc, mongo_field)
            field_obj = getattr(config.target_model, sqlite_field)
            target_data[sqlite_field] = self._convert_field_value(value, field_obj)
        return target_data

    def _flush_batch(self, config: MigrationConfig, batch_data: list[dict[str, Any]], stats: MigrationStats) -> None:
        """Bulk-insert the buffered rows, update ``stats`` and empty the buffer."""
        stats.success_count += self._batch_insert(config.target_model, batch_data)
        stats.batch_insert_count += 1
        batch_data.clear()

    def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
        """Migrate one collection according to ``config`` and return its statistics.

        Documents are read in ascending ``_id`` order. That lets a checkpoint (the
        last ``_id`` flushed to SQLite) resume with an ``_id > last`` query.
        Duplicates per ``config.unique_fields`` are skipped when
        ``config.skip_duplicates`` is set. Rows are flushed every
        ``config.batch_size`` queued documents.

        Per-document failures are recorded in the returned stats. A
        collection-level failure is recorded as ``collection_error`` and leaves the
        checkpoint file in place for a later resume.
        """
        stats = MigrationStats()
        stats.start_time = datetime.now()

        # Resume from a checkpoint if one exists.
        checkpoint = self._load_checkpoint(config.mongo_collection)
        start_from_id = checkpoint.last_processed_id if checkpoint else None
        if checkpoint:
            stats.processed_count = checkpoint.processed_count
            logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")

        logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")

        try:
            mongo_collection = self.mongo_db[config.mongo_collection]

            # Resume strictly after the last checkpointed _id.
            query: dict[str, Any] = {}
            if start_from_id is not None:
                query = {"_id": {"$gt": start_from_id}}

            stats.total_documents = mongo_collection.count_documents(query)

            if stats.total_documents == 0:
                logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
                return stats

            logger.info(f"待迁移文档数量: {stats.total_documents}")

            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeElapsedColumn(),
                TimeRemainingColumn(),
                console=self.console,
                refresh_per_second=10,
            ) as progress:
                task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)
                batch_data: list[dict[str, Any]] = []
                last_processed_id = None

                # Explicit _id sort: without it MongoDB returns documents in natural
                # order, and "_id > checkpoint" resumption could skip or repeat documents.
                cursor = mongo_collection.find(query).sort("_id", 1).batch_size(config.batch_size)
                for mongo_doc in cursor:
                    doc_id = mongo_doc.get("_id", "unknown")
                    try:
                        last_processed_id = doc_id
                        target_data = self._build_target_data(config, mongo_doc)

                        # Data validation is disabled; config.enable_validation is not consulted.

                        if config.skip_duplicates and self._check_duplicate_by_unique_fields(
                            config.target_model, target_data, config.unique_fields
                        ):
                            stats.duplicate_count += 1
                            stats.skipped_count += 1
                            logger.debug(f"跳过重复记录: {doc_id}")
                            continue

                        batch_data.append(target_data)
                        stats.processed_count += 1

                        if len(batch_data) >= config.batch_size:
                            self._flush_batch(config, batch_data, stats)
                            # Everything up to last_processed_id is now in SQLite.
                            self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)

                    except Exception as e:
                        stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
                        logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
                    finally:
                        # One tick per document read, whether queued, skipped or failed.
                        progress.advance(task)

                # Flush the rows left over after the last full batch.
                if batch_data:
                    self._flush_batch(config, batch_data, stats)

                # The collection may have changed since count_documents(); snap to 100%.
                progress.update(task, completed=stats.total_documents)

            stats.end_time = datetime.now()
            duration = stats.end_time - stats.start_time

            logger.info(
                f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
                f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
                f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
                f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
            )

            # The collection completed, so the checkpoint is no longer needed.
            self._remove_checkpoint(config.mongo_collection)

        except Exception as e:
            logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
            stats.add_error("collection_error", str(e))
            stats.end_time = datetime.now()

        return stats

    @staticmethod
    def _rate_color(rate: float) -> str:
        """Rich color for a success-rate percentage: >=95 bright_green, >=80 yellow, else red."""
        if rate >= 95:
            return "bright_green"
        if rate >= 80:
            return "yellow"
        return "red"

    def _print_collection_result(self, stats: MigrationStats) -> None:
        """Print one collection's success line and, if needed, a high-error-rate warning."""
        if stats.processed_count <= 0:
            return

        success_rate = stats.success_count / stats.processed_count * 100
        status_color = self._rate_color(success_rate)
        status_emoji = {"bright_green": "✅", "yellow": "⚠️"}.get(status_color, "❌")
        self.console.print(
            f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
            f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
        )

        error_rate = stats.error_count / stats.processed_count
        if error_rate > 0.1:  # error rate above 10%
            self.console.print(
                f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
                f"({stats.error_count}/{stats.processed_count})[/red]"
            )

    def migrate_all(self) -> dict[str, MigrationStats]:
        """Connect, migrate every configured collection in order, then print a summary.

        Returns an empty dict when MongoDB cannot be reached. Otherwise, returns a
        mapping from collection name to its :class:`MigrationStats`. The MongoDB
        connection is closed even if a migration raises.
        """
        logger.info("开始执行数据库迁移...")

        if not self.connect_mongodb():
            logger.error("无法连接到MongoDB,迁移终止")
            return {}

        all_stats: dict[str, MigrationStats] = {}

        try:
            total_collections = len(self.migration_configs)
            self.console.print(
                Panel(
                    f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
                    f"[yellow]总集合数: {total_collections}[/yellow]",
                    title="迁移开始",
                    expand=False,
                )
            )
            for idx, config in enumerate(self.migration_configs, 1):
                self.console.print(
                    f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
                )
                stats = self.migrate_collection(config)
                all_stats[config.mongo_collection] = stats
                self._print_collection_result(stats)
        finally:
            self.disconnect_mongodb()

        self._print_migration_summary(all_stats)
        return all_stats

    def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
        """Print a Rich table of per-collection statistics plus status and performance panels."""
        stats_list = list(all_stats.values())
        total_documents = sum(s.total_documents for s in stats_list)
        total_processed = sum(s.processed_count for s in stats_list)
        total_success = sum(s.success_count for s in stats_list)
        total_errors = sum(s.error_count for s in stats_list)
        total_skipped = sum(s.skipped_count for s in stats_list)
        total_duplicates = sum(s.duplicate_count for s in stats_list)
        total_validation_errors = sum(s.validation_errors for s in stats_list)
        total_batch_inserts = sum(s.batch_insert_count for s in stats_list)
        total_duration_seconds = sum(
            (s.end_time - s.start_time).total_seconds() for s in stats_list if s.start_time and s.end_time
        )

        table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
        table.add_column("集合名称", style="cyan", width=20)
        table.add_column("文档总数", justify="right", style="blue")
        table.add_column("处理数量", justify="right", style="green")
        table.add_column("成功数量", justify="right", style="green")
        table.add_column("错误数量", justify="right", style="red")
        table.add_column("跳过数量", justify="right", style="yellow")
        table.add_column("重复数量", justify="right", style="bright_yellow")
        table.add_column("验证错误", justify="right", style="red")
        table.add_column("批次数", justify="right", style="purple")
        table.add_column("成功率", justify="right", style="bright_green")
        table.add_column("耗时(秒)", justify="right", style="blue")

        for collection_name, stats in all_stats.items():
            success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
            duration = (
                (stats.end_time - stats.start_time).total_seconds() if stats.start_time and stats.end_time else 0
            )
            rate_color = self._rate_color(success_rate)

            table.add_row(
                collection_name,
                str(stats.total_documents),
                str(stats.processed_count),
                str(stats.success_count),
                f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0",
                f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0",
                f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0",
                f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
                str(stats.batch_insert_count),
                f"[{rate_color}]{success_rate:.1f}%[/{rate_color}]",
                f"{duration:.2f}",
            )

        # Totals row.
        total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
        total_color = self._rate_color(total_success_rate)

        table.add_section()
        table.add_row(
            "[bold]总计[/bold]",
            f"[bold]{total_documents}[/bold]",
            f"[bold]{total_processed}[/bold]",
            f"[bold]{total_success}[/bold]",
            f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
            f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
            f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
            if total_duplicates > 0
            else "[bold]0[/bold]",
            f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
            f"[bold]{total_batch_inserts}[/bold]",
            f"[bold][{total_color}]{total_success_rate:.1f}%[/{total_color}][/bold]",
            f"[bold]{total_duration_seconds:.2f}[/bold]",
        )

        self.console.print(table)

        # Status panel.
        status_items = []
        if total_errors > 0:
            status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
        if total_validation_errors > 0:
            status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
        if total_duplicates > 0:
            status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")

        if total_success_rate >= 95:
            status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
        elif total_success_rate >= 80:
            status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
        else:
            status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")

        if status_items:
            status_panel = Panel(
                "\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
            )
            self.console.print(status_panel)

        # Performance panel.
        avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
        performance_info = (
            f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n"
            f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
            f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
        )
        performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
        self.console.print(performance_panel)

    def add_migration_config(self, config: MigrationConfig):
        """Append ``config`` to the migration plan."""
        self.migration_configs.append(config)

    def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
        """Migrate only the configured collection named ``collection_name``.

        Returns its stats, or None when no config matches or MongoDB is unreachable.
        """
        config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
        if not config:
            logger.error(f"未找到集合 {collection_name} 的迁移配置")
            return None

        if not self.connect_mongodb():
            logger.error("无法连接到MongoDB")
            return None

        try:
            stats = self.migrate_collection(config)
            self._print_migration_summary({collection_name: stats})
            return stats
        finally:
            self.disconnect_mongodb()

    def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
        """Write a JSON report of per-collection counters and recorded errors to ``filepath``.

        Error entries may embed raw MongoDB documents. Non-JSON types such as
        ``ObjectId`` are written as strings. Failures are logged, not raised.
        """
        error_report = {
            "timestamp": datetime.now().isoformat(),
            "summary": {
                collection: {
                    "total": stats.total_documents,
                    "processed": stats.processed_count,
                    "success": stats.success_count,
                    "errors": stats.error_count,
                    "skipped": stats.skipped_count,
                    "duplicates": stats.duplicate_count,
                }
                for collection, stats in all_stats.items()
            },
            "errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
        }

        try:
            # orjson.dumps returns UTF-8 bytes (non-ASCII kept as-is) and does not take
            # a file object or ensure_ascii/indent arguments. Serialize first so a
            # failure does not leave an empty report file behind.
            payload = orjson.dumps(error_report, default=str, option=orjson.OPT_INDENT_2)
            with open(filepath, "wb") as f:
                f.write(payload)
            logger.info(f"错误报告已导出到: {filepath}")
        except Exception as e:
            logger.error(f"导出错误报告失败: {e}")
||
|
||
|
||
def main():
    """Script entry point: run every configured migration and export an error report if needed."""
    migrator = MongoToSQLiteMigrator()
    results = migrator.migrate_all()

    # A report is only written when at least one collection recorded an error.
    has_errors = any(collection_stats.error_count > 0 for collection_stats in results.values())
    if has_errors:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        migrator.export_error_report(results, f"migration_errors_{stamp}.json")

    logger.info("数据迁移完成!")


if __name__ == "__main__":
    main()
|