🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-05-28 12:51:51 +00:00
parent 460c7fb75a
commit 847bd23a62

View File

@@ -1,6 +1,7 @@
import os
import json
import sys # 新增系统模块导入
# import time
import pickle
from pathlib import Path
@@ -22,7 +23,7 @@ from rich.progress import (
TaskProgressColumn,
TimeRemainingColumn,
TimeElapsedColumn,
SpinnerColumn
SpinnerColumn,
)
from rich.table import Table
from rich.panel import Panel
@@ -30,17 +31,29 @@ from rich.panel import Panel
from src.common.database.database import db
from src.common.database.database_model import (
ChatStreams, LLMUsage, Emoji, Messages, Images, ImageDescriptions,
PersonInfo, Knowledges, ThinkingLog, GraphNodes, GraphEdges
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
PersonInfo,
Knowledges,
ThinkingLog,
GraphNodes,
GraphEdges,
)
from src.common.logger_manager import get_logger
logger = get_logger("mongodb_to_sqlite")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@dataclass
class MigrationConfig:
"""迁移配置类"""
mongo_collection: str
target_model: Type[Model]
field_mapping: Dict[str, str]
@@ -56,6 +69,7 @@ class MigrationConfig:
@dataclass
class MigrationCheckpoint:
"""迁移断点数据"""
collection_name: str
processed_count: int
last_processed_id: Any
@@ -66,6 +80,7 @@ class MigrationCheckpoint:
@dataclass
class MigrationStats:
"""迁移统计信息"""
total_documents: int = 0
processed_count: int = 0
success_count: int = 0
@@ -80,12 +95,9 @@ class MigrationStats:
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
"""添加错误记录"""
self.errors.append({
'doc_id': str(doc_id),
'error': error,
'timestamp': datetime.now().isoformat(),
'doc_data': doc_data
})
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
)
self.error_count += 1
def add_validation_error(self, doc_id: Any, field: str, error: str):
@@ -120,11 +132,11 @@ class MongoToSQLiteMigrator:
if mongo_uri := os.getenv("MONGODB_URI"):
return mongo_uri
user = os.getenv('MONGODB_USER')
password = os.getenv('MONGODB_PASS')
host = os.getenv('MONGODB_HOST', 'localhost')
port = os.getenv('MONGODB_PORT', '27017')
auth_source = os.getenv('MONGODB_AUTH_SOURCE', 'admin')
user = os.getenv("MONGODB_USER")
password = os.getenv("MONGODB_PASS")
host = os.getenv("MONGODB_HOST", "localhost")
port = os.getenv("MONGODB_PORT", "27017")
auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")
if user and password:
return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}"
@@ -144,11 +156,11 @@ class MongoToSQLiteMigrator:
"description": "description",
"emotion": "emotion",
"usage_count": "usage_count",
"last_used_time": "last_used_time"
"last_used_time": "last_used_time",
# record_time字段将在转换时自动设置为当前时间
},
enable_validation=False, # 禁用数据验证
unique_fields=["full_path", "emoji_hash"]
unique_fields=["full_path", "emoji_hash"],
),
# 聊天流迁移配置
MigrationConfig(
@@ -157,7 +169,7 @@ class MongoToSQLiteMigrator:
field_mapping={
"stream_id": "stream_id",
"create_time": "create_time",
"group_info.platform": "group_platform",# 由于Mongodb处理私聊时会让group_info值为null而新的数据库不允许为null所以私聊聊天流是没法迁移的等更新吧。
"group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null而新的数据库不允许为null所以私聊聊天流是没法迁移的等更新吧。
"group_info.group_id": "group_id", # 同上
"group_info.group_name": "group_name", # 同上
"last_active_time": "last_active_time",
@@ -165,10 +177,10 @@ class MongoToSQLiteMigrator:
"user_info.platform": "user_platform",
"user_info.user_id": "user_id",
"user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname"
"user_info.user_cardname": "user_cardname",
},
enable_validation=False, # 禁用数据验证
unique_fields=["stream_id"]
unique_fields=["stream_id"],
),
# LLM使用记录迁移配置
MigrationConfig(
@@ -184,10 +196,10 @@ class MongoToSQLiteMigrator:
"total_tokens": "total_tokens",
"cost": "cost",
"status": "status",
"timestamp": "timestamp"
"timestamp": "timestamp",
},
enable_validation=False, # 禁用数据验证
unique_fields=["user_id", "timestamp"] # 组合唯一性
unique_fields=["user_id", "timestamp"], # 组合唯一性
),
# 消息迁移配置
MigrationConfig(
@@ -214,12 +226,11 @@ class MongoToSQLiteMigrator:
"user_info.user_cardname": "user_cardname",
"processed_plain_text": "processed_plain_text",
"detailed_plain_text": "detailed_plain_text",
"memorized_times": "memorized_times"
"memorized_times": "memorized_times",
},
enable_validation=False, # 禁用数据验证
unique_fields=["message_id"]
unique_fields=["message_id"],
),
# 图片迁移配置
MigrationConfig(
mongo_collection="images",
@@ -229,11 +240,10 @@ class MongoToSQLiteMigrator:
"description": "description",
"path": "path",
"timestamp": "timestamp",
"type": "type"
"type": "type",
},
unique_fields=["path"]
unique_fields=["path"],
),
# 图片描述迁移配置
MigrationConfig(
mongo_collection="image_descriptions",
@@ -242,10 +252,10 @@ class MongoToSQLiteMigrator:
"type": "type",
"hash": "image_description_hash",
"description": "description",
"timestamp": "timestamp"
}, unique_fields=["image_description_hash", "type"]
"timestamp": "timestamp",
},
unique_fields=["image_description_hash", "type"],
),
# 个人信息迁移配置
MigrationConfig(
mongo_collection="person_info",
@@ -260,22 +270,17 @@ class MongoToSQLiteMigrator:
"relationship_value": "relationship_value",
"konw_time": "know_time",
"msg_interval": "msg_interval",
"msg_interval_list": "msg_interval_list"
"msg_interval_list": "msg_interval_list",
},
unique_fields=["person_id"]
unique_fields=["person_id"],
),
# 知识库迁移配置
MigrationConfig(
mongo_collection="knowledges",
target_model=Knowledges,
field_mapping={
"content": "content",
"embedding": "embedding"
},
unique_fields=["content"] # 假设内容唯一
field_mapping={"content": "content", "embedding": "embedding"},
unique_fields=["content"], # 假设内容唯一
),
# 思考日志迁移配置
MigrationConfig(
mongo_collection="thinking_log",
@@ -293,9 +298,8 @@ class MongoToSQLiteMigrator:
"heartflow_data": "heartflow_data_json",
"reasoning_data": "reasoning_data_json",
},
unique_fields=["chat_id", "trigger_text"]
unique_fields=["chat_id", "trigger_text"],
),
# 图节点迁移配置
MigrationConfig(
mongo_collection="graph_data.nodes",
@@ -305,11 +309,10 @@ class MongoToSQLiteMigrator:
"memory_items": "memory_items",
"hash": "hash",
"created_time": "created_time",
"last_modified": "last_modified"
"last_modified": "last_modified",
},
unique_fields=["concept"]
unique_fields=["concept"],
),
# 图边迁移配置
MigrationConfig(
mongo_collection="graph_data.edges",
@@ -320,11 +323,12 @@ class MongoToSQLiteMigrator:
"strength": "strength",
"hash": "hash",
"created_time": "created_time",
"last_modified": "last_modified"
"last_modified": "last_modified",
},
unique_fields=["source", "target"] # 组合唯一性
)
unique_fields=["source", "target"], # 组合唯一性
),
]
def _initialize_validation_rules(self) -> Dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
@@ -333,14 +337,11 @@ class MongoToSQLiteMigrator:
"""连接到MongoDB"""
try:
self.mongo_client = MongoClient(
self.mongo_uri,
serverSelectionTimeoutMS=5000,
connectTimeoutMS=10000,
maxPoolSize=10
self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
)
# 测试连接
self.mongo_client.admin.command('ping')
self.mongo_client.admin.command("ping")
self.mongo_db = self.mongo_client[self.database_name]
logger.info(f"成功连接到MongoDB: {self.database_name}")
@@ -398,7 +399,7 @@ class MongoToSQLiteMigrator:
if isinstance(value, str):
# 处理字符串数字
clean_value = value.strip()
if clean_value.replace('.', '').replace('-', '').isdigit():
if clean_value.replace(".", "").replace("-", "").isdigit():
return int(float(clean_value))
return 0
return int(value) if value is not None else 0
@@ -408,7 +409,7 @@ class MongoToSQLiteMigrator:
elif field_type == "BooleanField":
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes', 'on')
return value.lower() in ("true", "1", "yes", "on")
return bool(value)
elif field_type == "DateTimeField":
@@ -417,7 +418,7 @@ class MongoToSQLiteMigrator:
elif isinstance(value, str):
try:
# 尝试解析ISO格式日期
return datetime.fromisoformat(value.replace('Z', '+00:00'))
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
try:
# 尝试解析时间戳字符串
@@ -436,7 +437,7 @@ class MongoToSQLiteMigrator:
"""获取字段的默认值"""
field_type = field.__class__.__name__
if hasattr(field, 'default') and field.default is not None:
if hasattr(field, "default") and field.default is not None:
return field.default
if field.null:
@@ -455,6 +456,7 @@ class MongoToSQLiteMigrator:
return datetime.now()
return None
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
@@ -465,12 +467,12 @@ class MongoToSQLiteMigrator:
collection_name=collection_name,
processed_count=processed_count,
last_processed_id=last_id,
timestamp=datetime.now()
timestamp=datetime.now(),
)
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
try:
with open(checkpoint_file, 'wb') as f:
with open(checkpoint_file, "wb") as f:
pickle.dump(checkpoint, f)
except Exception as e:
logger.warning(f"保存断点失败: {e}")
@@ -482,7 +484,7 @@ class MongoToSQLiteMigrator:
return None
try:
with open(checkpoint_file, 'rb') as f:
with open(checkpoint_file, "rb") as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"加载断点失败: {e}")
@@ -499,7 +501,7 @@ class MongoToSQLiteMigrator:
# 分批插入避免SQL语句过长
batch_size = 100
for i in range(0, len(data_list), batch_size):
batch = data_list[i:i + batch_size]
batch = data_list[i : i + batch_size]
model.insert_many(batch).execute()
success_count += len(batch)
except Exception as e:
@@ -514,8 +516,9 @@ class MongoToSQLiteMigrator:
return success_count
def _check_duplicate_by_unique_fields(self, model: Type[Model], data: Dict[str, Any],
unique_fields: List[str]) -> bool:
def _check_duplicate_by_unique_fields(
self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
return False
@@ -554,6 +557,7 @@ class MongoToSQLiteMigrator:
except Exception as e:
logger.error(f"创建模型实例失败: {e}")
return None
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
"""迁移单个集合 - 使用优化的批量插入和进度条"""
stats = MigrationStats()
@@ -594,13 +598,9 @@ class MongoToSQLiteMigrator:
TimeElapsedColumn(),
TimeRemainingColumn(),
console=self.console,
refresh_per_second=10
refresh_per_second=10,
) as progress:
task = progress.add_task(
f"迁移 {config.mongo_collection}",
total=stats.total_documents
)
task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)
# 批量处理数据
batch_data = []
batch_count = 0
@@ -608,7 +608,7 @@ class MongoToSQLiteMigrator:
for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size):
try:
doc_id = mongo_doc.get('_id', 'unknown')
doc_id = mongo_doc.get("_id", "unknown")
last_processed_id = doc_id
# 构建目标数据
@@ -657,7 +657,7 @@ class MongoToSQLiteMigrator:
progress.update(task, advance=config.batch_size)
except Exception as e:
doc_id = mongo_doc.get('_id', 'unknown')
doc_id = mongo_doc.get("_id", "unknown")
stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
@@ -691,6 +691,7 @@ class MongoToSQLiteMigrator:
stats.add_error("collection_error", str(e))
return stats
def migrate_all(self) -> Dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
@@ -704,20 +705,24 @@ class MongoToSQLiteMigrator:
try:
# 创建总体进度表格
total_collections = len(self.migration_configs)
self.console.print(Panel(
self.console.print(
Panel(
f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
f"[yellow]总集合数: {total_collections}[/yellow]",
title="迁移开始",
expand=False
))
expand=False,
)
)
for idx, config in enumerate(self.migration_configs, 1):
self.console.print(f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]")
self.console.print(
f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
)
stats = self.migrate_collection(config)
all_stats[config.mongo_collection] = stats
# 显示单个集合的快速统计
if stats.processed_count > 0:
success_rate = (stats.success_count / stats.processed_count * 100)
success_rate = stats.success_count / stats.processed_count * 100
if success_rate >= 95:
status_emoji = ""
status_color = "bright_green"
@@ -805,7 +810,7 @@ class MongoToSQLiteMigrator:
f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
str(stats.batch_insert_count),
f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
f"{duration:.2f}"
f"{duration:.2f}",
)
# 添加总计行
@@ -825,11 +830,13 @@ class MongoToSQLiteMigrator:
f"[bold]{total_success}[/bold]",
f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else "[bold]0[/bold]",
f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
if total_duplicates > 0
else "[bold]0[/bold]",
f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
f"[bold]{total_batch_inserts}[/bold]",
f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
f"[bold]{total_duration_seconds:.2f}[/bold]"
f"[bold]{total_duration_seconds:.2f}[/bold]",
)
self.console.print(table)
@@ -854,9 +861,7 @@ class MongoToSQLiteMigrator:
if status_items:
status_panel = Panel(
"\n".join(status_items),
title="[bold yellow]迁移状态总结[/bold yellow]",
border_style="yellow"
"\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
)
self.console.print(status_panel)
@@ -868,11 +873,7 @@ class MongoToSQLiteMigrator:
f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
)
performance_panel = Panel(
performance_info,
title="[bold green]性能统计[/bold green]",
border_style="green"
)
performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
self.console.print(performance_panel)
def add_migration_config(self, config: MigrationConfig):
@@ -900,27 +901,23 @@ class MongoToSQLiteMigrator:
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
'timestamp': datetime.now().isoformat(),
'summary': {
"timestamp": datetime.now().isoformat(),
"summary": {
collection: {
'total': stats.total_documents,
'processed': stats.processed_count,
'success': stats.success_count,
'errors': stats.error_count,
'skipped': stats.skipped_count,
'duplicates': stats.duplicate_count
"total": stats.total_documents,
"processed": stats.processed_count,
"success": stats.success_count,
"errors": stats.error_count,
"skipped": stats.skipped_count,
"duplicates": stats.duplicate_count,
}
for collection, stats in all_stats.items()
},
'errors': {
collection: stats.errors
for collection, stats in all_stats.items()
if stats.errors
}
"errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
}
try:
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(error_report, f, ensure_ascii=False, indent=2)
logger.info(f"错误报告已导出到: {filepath}")
except Exception as e: