re-style: 格式化代码
This commit is contained in:
@@ -1,46 +1,48 @@
|
||||
import os
|
||||
import orjson
|
||||
import sys # 新增系统模块导入
|
||||
|
||||
# import time
|
||||
import pickle
|
||||
import sys # 新增系统模块导入
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from typing import Dict, Any, List, Optional, Type
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from peewee import Field, IntegrityError, Model
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConnectionFailure
|
||||
from peewee import Model, Field, IntegrityError
|
||||
|
||||
# Rich 进度条和显示组件
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
TimeRemainingColumn,
|
||||
TimeElapsedColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
# from rich.text import Text
|
||||
|
||||
# from rich.text import Text
|
||||
from src.common.database.database import db
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
GraphNodes,
|
||||
ImageDescriptions,
|
||||
Images,
|
||||
Knowledges,
|
||||
Messages,
|
||||
PersonInfo,
|
||||
ThinkingLog,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -54,12 +56,12 @@ class MigrationConfig:
|
||||
"""迁移配置类"""
|
||||
|
||||
mongo_collection: str
|
||||
target_model: Type[Model]
|
||||
field_mapping: Dict[str, str]
|
||||
target_model: type[Model]
|
||||
field_mapping: dict[str, str]
|
||||
batch_size: int = 500
|
||||
enable_validation: bool = True
|
||||
skip_duplicates: bool = True
|
||||
unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
|
||||
unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
|
||||
|
||||
|
||||
# 数据验证相关类已移除 - 用户要求不要数据验证
|
||||
@@ -73,7 +75,7 @@ class MigrationCheckpoint:
|
||||
processed_count: int
|
||||
last_processed_id: Any
|
||||
timestamp: datetime
|
||||
batch_errors: List[Dict[str, Any]] = field(default_factory=list)
|
||||
batch_errors: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -88,11 +90,11 @@ class MigrationStats:
|
||||
duplicate_count: int = 0
|
||||
validation_errors: int = 0
|
||||
batch_insert_count: int = 0
|
||||
errors: List[Dict[str, Any]] = field(default_factory=list)
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
errors: list[dict[str, Any]] = field(default_factory=list)
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
|
||||
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
|
||||
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
|
||||
"""添加错误记录"""
|
||||
self.errors.append(
|
||||
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
|
||||
@@ -108,10 +110,10 @@ class MigrationStats:
|
||||
class MongoToSQLiteMigrator:
|
||||
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
|
||||
|
||||
def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
|
||||
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
|
||||
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
|
||||
self.mongo_uri = mongo_uri or self._build_mongo_uri()
|
||||
self.mongo_client: Optional[MongoClient] = None
|
||||
self.mongo_client: MongoClient | None = None
|
||||
self.mongo_db = None
|
||||
|
||||
# 迁移配置
|
||||
@@ -142,7 +144,7 @@ class MongoToSQLiteMigrator:
|
||||
else:
|
||||
return f"mongodb://{host}:{port}/{self.database_name}"
|
||||
|
||||
def _initialize_migration_configs(self) -> List[MigrationConfig]:
|
||||
def _initialize_migration_configs(self) -> list[MigrationConfig]:
|
||||
"""初始化迁移配置"""
|
||||
return [ # 表情包迁移配置
|
||||
MigrationConfig(
|
||||
@@ -306,7 +308,7 @@ class MongoToSQLiteMigrator:
|
||||
),
|
||||
]
|
||||
|
||||
def _initialize_validation_rules(self) -> Dict[str, Any]:
|
||||
def _initialize_validation_rules(self) -> dict[str, Any]:
|
||||
"""数据验证已禁用 - 返回空字典"""
|
||||
return {}
|
||||
|
||||
@@ -337,7 +339,7 @@ class MongoToSQLiteMigrator:
|
||||
self.mongo_client.close()
|
||||
logger.info("MongoDB连接已关闭")
|
||||
|
||||
def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
|
||||
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
|
||||
"""获取嵌套字段的值"""
|
||||
if "." not in field_path:
|
||||
return document.get(field_path)
|
||||
@@ -434,7 +436,7 @@ class MongoToSQLiteMigrator:
|
||||
|
||||
return None
|
||||
|
||||
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
|
||||
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
|
||||
"""数据验证已禁用 - 始终返回True"""
|
||||
return True
|
||||
|
||||
@@ -454,7 +456,7 @@ class MongoToSQLiteMigrator:
|
||||
except Exception as e:
|
||||
logger.warning(f"保存断点失败: {e}")
|
||||
|
||||
def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
|
||||
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
|
||||
"""加载迁移断点"""
|
||||
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
|
||||
if not checkpoint_file.exists():
|
||||
@@ -467,7 +469,7 @@ class MongoToSQLiteMigrator:
|
||||
logger.warning(f"加载断点失败: {e}")
|
||||
return None
|
||||
|
||||
def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
|
||||
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
|
||||
"""批量插入数据"""
|
||||
if not data_list:
|
||||
return 0
|
||||
@@ -494,7 +496,7 @@ class MongoToSQLiteMigrator:
|
||||
return success_count
|
||||
|
||||
def _check_duplicate_by_unique_fields(
|
||||
self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
|
||||
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
|
||||
) -> bool:
|
||||
"""根据唯一字段检查重复"""
|
||||
if not unique_fields:
|
||||
@@ -512,7 +514,7 @@ class MongoToSQLiteMigrator:
|
||||
logger.debug(f"重复检查失败: {e}")
|
||||
return False
|
||||
|
||||
def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
|
||||
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
|
||||
"""使用ORM创建模型实例"""
|
||||
try:
|
||||
# 过滤掉不存在的字段
|
||||
@@ -669,7 +671,7 @@ class MongoToSQLiteMigrator:
|
||||
|
||||
return stats
|
||||
|
||||
def migrate_all(self) -> Dict[str, MigrationStats]:
|
||||
def migrate_all(self) -> dict[str, MigrationStats]:
|
||||
"""执行所有迁移任务"""
|
||||
logger.info("开始执行数据库迁移...")
|
||||
|
||||
@@ -730,7 +732,7 @@ class MongoToSQLiteMigrator:
|
||||
self._print_migration_summary(all_stats)
|
||||
return all_stats
|
||||
|
||||
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
|
||||
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
|
||||
"""使用Rich打印美观的迁移汇总信息"""
|
||||
# 计算总体统计
|
||||
total_processed = sum(stats.processed_count for stats in all_stats.values())
|
||||
@@ -857,7 +859,7 @@ class MongoToSQLiteMigrator:
|
||||
"""添加新的迁移配置"""
|
||||
self.migration_configs.append(config)
|
||||
|
||||
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
|
||||
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
|
||||
"""迁移单个指定的集合"""
|
||||
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
|
||||
if not config:
|
||||
@@ -875,7 +877,7 @@ class MongoToSQLiteMigrator:
|
||||
finally:
|
||||
self.disconnect_mongodb()
|
||||
|
||||
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
|
||||
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
|
||||
"""导出错误报告"""
|
||||
error_report = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
|
||||
Reference in New Issue
Block a user