diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index 5ff346af7..609906fa4 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -1,6 +1,9 @@ import os import json import sys # 新增系统模块导入 +# import time +import pickle +from pathlib import Path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from typing import Dict, Any, List, Optional, Type @@ -10,16 +13,31 @@ from pymongo import MongoClient from pymongo.errors import ConnectionFailure from peewee import Model, Field, IntegrityError +# Rich 进度条和显示组件 +from rich.console import Console +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeRemainingColumn, + TimeElapsedColumn, + SpinnerColumn +) +from rich.table import Table +from rich.panel import Panel +# from rich.text import Text + from src.common.database.database import db from src.common.database.database_model import ( ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions, - OnlineTime, PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges + PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges ) from src.common.logger_manager import get_logger logger = get_logger("mongodb_to_sqlite") - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @dataclass class MigrationConfig: """迁移配置类""" @@ -32,6 +50,19 @@ class MigrationConfig: unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段 +# 数据验证相关类已移除 - 用户要求不要数据验证 + + +@dataclass +class MigrationCheckpoint: + """迁移断点数据""" + collection_name: str + processed_count: int + last_processed_id: Any + timestamp: datetime + batch_errors: List[Dict[str, Any]] = field(default_factory=list) + + @dataclass class MigrationStats: """迁移统计信息""" @@ -41,7 +72,11 @@ class MigrationStats: error_count: int = 0 skipped_count: int = 0 duplicate_count: int = 0 + validation_errors: int = 0 + batch_insert_count: int = 0 errors: List[Dict[str, Any]] = field(default_factory=list) + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): """添加错误记录""" @@ -52,6 +87,11 @@ class MigrationStats: 'doc_data': doc_data }) self.error_count += 1 + + def add_validation_error(self, doc_id: Any, field: str, error: str): + """添加验证错误""" + self.add_error(doc_id, f"验证失败 - {field}: {error}") + self.validation_errors += 1 class MongoToSQLiteMigrator: @@ -65,6 +105,15 @@ class MongoToSQLiteMigrator: # 迁移配置 self.migration_configs = self._initialize_migration_configs() + + # 进度条控制台 + self.console = Console() + # 检查点目录 + self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints")) + self.checkpoint_dir.mkdir(exist_ok=True) + + # 验证规则已禁用 + self.validation_rules = self._initialize_validation_rules() def _build_mongo_uri(self) -> str: """构建MongoDB连接URI""" @@ -84,8 +133,7 @@ class MongoToSQLiteMigrator: def _initialize_migration_configs(self) -> List[MigrationConfig]: """初始化迁移配置""" - return [ - # 表情包迁移配置 + return [ # 表情包迁移配置 MigrationConfig( mongo_collection="emoji", target_model=Emoji, @@ -96,13 +144,13 @@ class MongoToSQLiteMigrator: "description": "description", "emotion": "emotion", "usage_count": "usage_count", - "last_used_time": "last_used_time", - "last_used_time": "record_time" # 这个纯粹是为了应付整体映射格式,实际上直接用当前时间戳填了record_time + "last_used_time": "last_used_time" + # record_time字段将在转换时自动设置为当前时间 }, + enable_validation=False, # 禁用数据验证 unique_fields=["full_path", "emoji_hash"] ), - - # 聊天流迁移配置 + # 聊天流迁移配置 MigrationConfig( mongo_collection="chat_streams", target_model=ChatStreams, @@ -119,10 +167,10 @@ class MongoToSQLiteMigrator: "user_info.user_nickname": "user_nickname", "user_info.user_cardname": "user_cardname" }, + enable_validation=False, # 禁用数据验证 unique_fields=["stream_id"] ), - - # LLM使用记录迁移配置 + # LLM使用记录迁移配置 MigrationConfig( mongo_collection="llm_usage", target_model=LLMUsage, @@ -138,10 +186,10 @@ class MongoToSQLiteMigrator: "status": "status", "timestamp": "timestamp" }, + enable_validation=False, # 禁用数据验证 unique_fields=["user_id", "timestamp"] # 组合唯一性 ), - - # 消息迁移配置 + # 消息迁移配置 MigrationConfig( mongo_collection="messages", target_model=Messages, @@ -168,6 +216,7 @@ class MongoToSQLiteMigrator: "detailed_plain_text": "detailed_plain_text", "memorized_times": "memorized_times" }, + enable_validation=False, # 禁用数据验证 unique_fields=["message_id"] ), @@ -194,21 +243,7 @@ class MongoToSQLiteMigrator: "hash": "image_description_hash", "description": "description", "timestamp": "timestamp" - }, - unique_fields=["image_description_hash", "type"] - ), - - # 在线时长迁移配置 - MigrationConfig( - mongo_collection="online_time", - target_model=OnlineTime, - field_mapping={ - "timestamp": "timestamp", - "duration": "duration", - "start_timestamp": "start_timestamp", - "end_timestamp": "end_timestamp" - }, - unique_fields=["start_timestamp", "end_timestamp"] + }, unique_fields=["image_description_hash", "type"] ), # 个人信息迁移配置 @@ -290,6 +325,9 @@ class MongoToSQLiteMigrator: unique_fields=["source", "target"] # 组合唯一性 ) ] + def _initialize_validation_rules(self) -> Dict[str, Any]: + """数据验证已禁用 - 返回空字典""" + return {} def connect_mongodb(self) -> bool: """连接到MongoDB""" @@ -412,11 +450,69 @@ class MongoToSQLiteMigrator: elif field_type in ["FloatField", "DoubleField"]: return 0.0 elif field_type == "BooleanField": - return False + return False elif field_type == "DateTimeField": return datetime.now() return None + def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: + """数据验证已禁用 - 始终返回True""" + return True + + def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any): + """保存迁移断点""" + checkpoint = MigrationCheckpoint( + collection_name=collection_name, + processed_count=processed_count, + last_processed_id=last_id, + timestamp=datetime.now() + ) + + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" + try: + with open(checkpoint_file, 'wb') as f: + pickle.dump(checkpoint, f) + except Exception as e: + logger.warning(f"保存断点失败: {e}") + + def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]: + """加载迁移断点""" + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" + if not checkpoint_file.exists(): + return None + + try: + with open(checkpoint_file, 'rb') as f: + return pickle.load(f) + except Exception as e: + logger.warning(f"加载断点失败: {e}") + return None + + def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int: + """批量插入数据""" + if not data_list: + return 0 + + success_count = 0 + try: + with db.atomic(): + # 分批插入,避免SQL语句过长 + batch_size = 100 + for i in range(0, len(data_list), batch_size): + batch = data_list[i:i + batch_size] + model.insert_many(batch).execute() + success_count += len(batch) + except Exception as e: + logger.error(f"批量插入失败: {e}") + # 如果批量插入失败,尝试逐个插入 + for data in data_list: + try: + model.create(**data) + success_count += 1 + except Exception: + pass # 忽略单个插入失败 + + return success_count def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]) -> bool: @@ -458,17 +554,30 @@ class MongoToSQLiteMigrator: except Exception as e: logger.error(f"创建模型实例失败: {e}") return None - def migrate_collection(self, config: MigrationConfig) -> MigrationStats: - """迁移单个集合 - 使用ORM方式""" + """迁移单个集合 - 使用优化的批量插入和进度条""" stats = MigrationStats() + stats.start_time = datetime.now() + + # 检查是否有断点 + checkpoint = self._load_checkpoint(config.mongo_collection) + start_from_id = checkpoint.last_processed_id if checkpoint else None + if checkpoint: + stats.processed_count = checkpoint.processed_count + logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录") logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}") try: # 获取MongoDB集合 mongo_collection = self.mongo_db[config.mongo_collection] - stats.total_documents = mongo_collection.count_documents({}) + + # 构建查询条件(用于断点恢复) + query = {} + if start_from_id: + query = {"_id": {"$gt": start_from_id}} + + stats.total_documents = mongo_collection.count_documents(query) if stats.total_documents == 0: logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移") @@ -476,69 +585,112 @@ class MongoToSQLiteMigrator: logger.info(f"待迁移文档数量: {stats.total_documents}") - # 逐个处理文档 - batch_count = 0 - for mongo_doc in mongo_collection.find().batch_size(config.batch_size): - try: - stats.processed_count += 1 - doc_id = mongo_doc.get('_id', 'unknown') - - # 构建目标数据 - target_data = {} - for mongo_field, sqlite_field in config.field_mapping.items(): - value = self._get_nested_value(mongo_doc, mongo_field) - - # 获取目标字段对象并转换类型 - if hasattr(config.target_model, sqlite_field): - field_obj = getattr(config.target_model, sqlite_field) - converted_value = self._convert_field_value(value, field_obj) - target_data[sqlite_field] = converted_value - - # 重复检查 - if config.skip_duplicates and self._check_duplicate_by_unique_fields( - config.target_model, target_data, config.unique_fields - ): - stats.duplicate_count += 1 - stats.skipped_count += 1 - logger.debug(f"跳过重复记录: {doc_id}") - continue - - # 使用ORM创建实例 - with db.atomic(): # 每个实例的事务保护 - instance = self._create_model_instance(config.target_model, target_data) - - if instance: - stats.success_count += 1 - else: - stats.add_error(doc_id, "ORM创建实例失败", target_data) + # 创建Rich进度条 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=self.console, + refresh_per_second=10 + ) as progress: - except Exception as e: - doc_id = mongo_doc.get('_id', 'unknown') - stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) - logger.error(f"处理文档失败 (ID: {doc_id}): {e}") + task = progress.add_task( + f"迁移 {config.mongo_collection}", + total=stats.total_documents + ) + # 批量处理数据 + batch_data = [] + batch_count = 0 + last_processed_id = None - # 进度报告 - batch_count += 1 - if batch_count % config.batch_size == 0: - progress = (stats.processed_count / stats.total_documents) * 100 - logger.info( - f"迁移进度: {stats.processed_count}/{stats.total_documents} " - f"({progress:.1f}%) - 成功: {stats.success_count}, " - f"错误: {stats.error_count}, 跳过: {stats.skipped_count}" - ) + for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size): + try: + doc_id = mongo_doc.get('_id', 'unknown') + last_processed_id = doc_id + + # 构建目标数据 + target_data = {} + for mongo_field, sqlite_field in config.field_mapping.items(): + value = self._get_nested_value(mongo_doc, mongo_field) + + # 获取目标字段对象并转换类型 + if hasattr(config.target_model, sqlite_field): + field_obj = getattr(config.target_model, sqlite_field) + converted_value = self._convert_field_value(value, field_obj) + target_data[sqlite_field] = converted_value + + # 数据验证已禁用 + # if config.enable_validation: + # if not self._validate_data(config.mongo_collection, target_data, doc_id, stats): + # stats.skipped_count += 1 + # continue + + # 重复检查 + if config.skip_duplicates and self._check_duplicate_by_unique_fields( + config.target_model, target_data, config.unique_fields + ): + stats.duplicate_count += 1 + stats.skipped_count += 1 + logger.debug(f"跳过重复记录: {doc_id}") + continue + + # 添加到批量数据 + batch_data.append(target_data) + stats.processed_count += 1 + + # 执行批量插入 + if len(batch_data) >= config.batch_size: + success_count = self._batch_insert(config.target_model, batch_data) + stats.success_count += success_count + stats.batch_insert_count += 1 + + # 保存断点 + self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id) + + batch_data.clear() + batch_count += 1 + + # 更新进度条 + progress.update(task, advance=config.batch_size) + + except Exception as e: + doc_id = mongo_doc.get('_id', 'unknown') + stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) + logger.error(f"处理文档失败 (ID: {doc_id}): {e}") + + # 处理剩余的批量数据 + if batch_data: + success_count = self._batch_insert(config.target_model, batch_data) + stats.success_count += success_count + stats.batch_insert_count += 1 + progress.update(task, advance=len(batch_data)) + + # 完成进度条 + progress.update(task, completed=stats.total_documents) + + stats.end_time = datetime.now() + duration = stats.end_time - stats.start_time logger.info( f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n" f"总计: {stats.total_documents}, 成功: {stats.success_count}, " - f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}" + f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n" + f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}" ) + # 清理断点文件 + checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl" + if checkpoint_file.exists(): + checkpoint_file.unlink() + except Exception as e: logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}") stats.add_error("collection_error", str(e)) return stats - def migrate_all(self) -> Dict[str, MigrationStats]: """执行所有迁移任务""" logger.info("开始执行数据库迁移...") @@ -550,18 +702,44 @@ class MongoToSQLiteMigrator: all_stats = {} try: - for config in self.migration_configs: - logger.info(f"\n开始处理集合: {config.mongo_collection}") + # 创建总体进度表格 + total_collections = len(self.migration_configs) + self.console.print(Panel( + f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n" + f"[yellow]总集合数: {total_collections}[/yellow]", + title="迁移开始", + expand=False + )) + for idx, config in enumerate(self.migration_configs, 1): + self.console.print(f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]") stats = self.migrate_collection(config) all_stats[config.mongo_collection] = stats + # 显示单个集合的快速统计 + if stats.processed_count > 0: + success_rate = (stats.success_count / stats.processed_count * 100) + if success_rate >= 95: + status_emoji = "✅" + status_color = "bright_green" + elif success_rate >= 80: + status_emoji = "⚠️" + status_color = "yellow" + else: + status_emoji = "❌" + status_color = "red" + + self.console.print( + f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} " + f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]" + ) + # 错误率检查 if stats.processed_count > 0: error_rate = stats.error_count / stats.processed_count if error_rate > 0.1: # 错误率超过10% - logger.warning( - f"集合 {config.mongo_collection} 错误率较高: {error_rate:.1%} " - f"({stats.error_count}/{stats.processed_count})" + self.console.print( + f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} " + f"({stats.error_count}/{stats.processed_count})[/red]" ) finally: @@ -571,52 +749,131 @@ class MongoToSQLiteMigrator: return all_stats def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): - """打印迁移汇总信息""" - logger.info("\n" + "="*60) - logger.info("数据迁移汇总报告") - logger.info("="*60) - + """使用Rich打印美观的迁移汇总信息""" + # 计算总体统计 total_processed = sum(stats.processed_count for stats in all_stats.values()) total_success = sum(stats.success_count for stats in all_stats.values()) total_errors = sum(stats.error_count for stats in all_stats.values()) total_skipped = sum(stats.skipped_count for stats in all_stats.values()) total_duplicates = sum(stats.duplicate_count for stats in all_stats.values()) + total_validation_errors = sum(stats.validation_errors for stats in all_stats.values()) + total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values()) - # 表头 - logger.info(f"{'集合名称':<20} | {'处理':<6} | {'成功':<6} | {'错误':<6} | {'跳过':<6} | {'重复':<6} | {'成功率':<8}") - logger.info("-" * 75) + # 计算总耗时 + total_duration_seconds = 0 + for stats in all_stats.values(): + if stats.start_time and stats.end_time: + duration = stats.end_time - stats.start_time + total_duration_seconds += duration.total_seconds() + + # 创建详细统计表格 + table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta") + table.add_column("集合名称", style="cyan", width=20) + table.add_column("文档总数", justify="right", style="blue") + table.add_column("处理数量", justify="right", style="green") + table.add_column("成功数量", justify="right", style="green") + table.add_column("错误数量", justify="right", style="red") + table.add_column("跳过数量", justify="right", style="yellow") + table.add_column("重复数量", justify="right", style="bright_yellow") + table.add_column("验证错误", justify="right", style="red") + table.add_column("批次数", justify="right", style="purple") + table.add_column("成功率", justify="right", style="bright_green") + table.add_column("耗时(秒)", justify="right", style="blue") for collection_name, stats in all_stats.items(): success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0 - logger.info( - f"{collection_name:<20} | " - f"{stats.processed_count:<6} | " - f"{stats.success_count:<6} | " - f"{stats.error_count:<6} | " - f"{stats.skipped_count:<6} | " - f"{stats.duplicate_count:<6} | " - f"{success_rate:<7.1f}%" + duration = 0 + if stats.start_time and stats.end_time: + duration = (stats.end_time - stats.start_time).total_seconds() + + # 根据成功率设置颜色 + if success_rate >= 95: + success_rate_style = "[bright_green]" + elif success_rate >= 80: + success_rate_style = "[yellow]" + else: + success_rate_style = "[red]" + + table.add_row( + collection_name, + str(stats.total_documents), + str(stats.processed_count), + str(stats.success_count), + f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0", + f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0", + f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0", + f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0", + str(stats.batch_insert_count), + f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}", + f"{duration:.2f}" ) - logger.info("-" * 75) + # 添加总计行 total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0 - logger.info( - f"{'总计':<20} | " - f"{total_processed:<6} | " - f"{total_success:<6} | " - f"{total_errors:<6} | " - f"{total_skipped:<6} | " - f"{total_duplicates:<6} | " - f"{total_success_rate:<7.1f}%" + if total_success_rate >= 95: + total_rate_style = "[bright_green]" + elif total_success_rate >= 80: + total_rate_style = "[yellow]" + else: + total_rate_style = "[red]" + + table.add_section() + table.add_row( + "[bold]总计[/bold]", + f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]", + f"[bold]{total_processed}[/bold]", + f"[bold]{total_success}[/bold]", + f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]", + f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]", + f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else "[bold]0[/bold]", + f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]", + f"[bold]{total_batch_inserts}[/bold]", + f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]", + f"[bold]{total_duration_seconds:.2f}[/bold]" ) + self.console.print(table) + + # 创建状态面板 + status_items = [] if total_errors > 0: - logger.warning(f"\n⚠️ 存在 {total_errors} 个错误,请检查日志详情") + status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]") + + if total_validation_errors > 0: + status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]") if total_duplicates > 0: - logger.info(f"ℹ️ 跳过了 {total_duplicates} 个重复记录") + status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]") - logger.info("="*60) + if total_success_rate >= 95: + status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]") + elif total_success_rate >= 80: + status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]") + else: + status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]") + + if status_items: + status_panel = Panel( + "\n".join(status_items), + title="[bold yellow]迁移状态总结[/bold yellow]", + border_style="yellow" + ) + self.console.print(status_panel) + + # 性能统计面板 + avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0 + performance_info = ( + f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n" + f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n" + f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作" + ) + + performance_panel = Panel( + performance_info, + title="[bold green]性能统计[/bold green]", + border_style="green" + ) + self.console.print(performance_panel) def add_migration_config(self, config: MigrationConfig): """添加新的迁移配置""" diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index ccf789649..bd2646371 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -44,9 +44,9 @@ class ChatStreams(BaseModel): # platform: "qq" # group_id: "941657197" # group_name: "测试" - group_platform = TextField() - group_id = TextField() - group_name = TextField() + group_platform = TextField(null=True) # 群聊信息可能不存在 + group_id = TextField(null=True) + group_name = TextField(null=True) # last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位) last_active_time = DoubleField()