re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -1,46 +1,48 @@
import os
import orjson
import sys # 新增系统模块导入
# import time
import pickle
import sys # 新增系统模块导入
from pathlib import Path
import orjson
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from typing import Dict, Any, List, Optional, Type
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from peewee import Field, IntegrityError, Model
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from peewee import Model, Field, IntegrityError
# Rich 进度条和显示组件
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
TimeElapsedColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich.panel import Panel
# from rich.text import Text
# from rich.text import Text
from src.common.database.database import db
from src.common.database.sqlalchemy_models import (
ChatStreams,
Emoji,
Messages,
Images,
ImageDescriptions,
PersonInfo,
Knowledges,
ThinkingLog,
GraphNodes,
GraphEdges,
GraphNodes,
ImageDescriptions,
Images,
Knowledges,
Messages,
PersonInfo,
ThinkingLog,
)
from src.common.logger import get_logger
@@ -54,12 +56,12 @@ class MigrationConfig:
"""迁移配置类"""
mongo_collection: str
target_model: Type[Model]
field_mapping: Dict[str, str]
target_model: type[Model]
field_mapping: dict[str, str]
batch_size: int = 500
enable_validation: bool = True
skip_duplicates: bool = True
unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
# 数据验证相关类已移除 - 用户要求不要数据验证
@@ -73,7 +75,7 @@ class MigrationCheckpoint:
processed_count: int
last_processed_id: Any
timestamp: datetime
batch_errors: List[Dict[str, Any]] = field(default_factory=list)
batch_errors: list[dict[str, Any]] = field(default_factory=list)
@dataclass
@@ -88,11 +90,11 @@ class MigrationStats:
duplicate_count: int = 0
validation_errors: int = 0
batch_insert_count: int = 0
errors: List[Dict[str, Any]] = field(default_factory=list)
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
errors: list[dict[str, Any]] = field(default_factory=list)
start_time: datetime | None = None
end_time: datetime | None = None
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
"""添加错误记录"""
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
@@ -108,10 +110,10 @@ class MigrationStats:
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
self.mongo_uri = mongo_uri or self._build_mongo_uri()
self.mongo_client: Optional[MongoClient] = None
self.mongo_client: MongoClient | None = None
self.mongo_db = None
# 迁移配置
@@ -142,7 +144,7 @@ class MongoToSQLiteMigrator:
else:
return f"mongodb://{host}:{port}/{self.database_name}"
def _initialize_migration_configs(self) -> List[MigrationConfig]:
def _initialize_migration_configs(self) -> list[MigrationConfig]:
"""初始化迁移配置"""
return [ # 表情包迁移配置
MigrationConfig(
@@ -306,7 +308,7 @@ class MongoToSQLiteMigrator:
),
]
def _initialize_validation_rules(self) -> Dict[str, Any]:
def _initialize_validation_rules(self) -> dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
@@ -337,7 +339,7 @@ class MongoToSQLiteMigrator:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
@@ -434,7 +436,7 @@ class MongoToSQLiteMigrator:
return None
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
@@ -454,7 +456,7 @@ class MongoToSQLiteMigrator:
except Exception as e:
logger.warning(f"保存断点失败: {e}")
def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
@@ -467,7 +469,7 @@ class MongoToSQLiteMigrator:
logger.warning(f"加载断点失败: {e}")
return None
def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
"""批量插入数据"""
if not data_list:
return 0
@@ -494,7 +496,7 @@ class MongoToSQLiteMigrator:
return success_count
def _check_duplicate_by_unique_fields(
self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
@@ -512,7 +514,7 @@ class MongoToSQLiteMigrator:
logger.debug(f"重复检查失败: {e}")
return False
def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
"""使用ORM创建模型实例"""
try:
# 过滤掉不存在的字段
@@ -669,7 +671,7 @@ class MongoToSQLiteMigrator:
return stats
def migrate_all(self) -> Dict[str, MigrationStats]:
def migrate_all(self) -> dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
@@ -730,7 +732,7 @@ class MongoToSQLiteMigrator:
self._print_migration_summary(all_stats)
return all_stats
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
"""使用Rich打印美观的迁移汇总信息"""
# 计算总体统计
total_processed = sum(stats.processed_count for stats in all_stats.values())
@@ -857,7 +859,7 @@ class MongoToSQLiteMigrator:
"""添加新的迁移配置"""
self.migration_configs.append(config)
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
"""迁移单个指定的集合"""
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
if not config:
@@ -875,7 +877,7 @@ class MongoToSQLiteMigrator:
finally:
self.disconnect_mongodb()
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
"timestamp": datetime.now().isoformat(),