Files
Mofox-Core/scripts/update_database_imports.py
明天好像没什么 ff6dc542e1 rufffffff
2025-11-19 23:31:37 +08:00

186 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""批量更新数据库导入语句的脚本
将旧的数据库导入路径更新为新的重构后的路径:
- sqlalchemy_models -> core, core.models
- sqlalchemy_database_api -> compatibility
- database.database -> core
"""
import re
from pathlib import Path
# Import rewrite rules: regex pattern -> replacement.
#
# Each value is either a ``re`` replacement template (may use ``\1`` etc.) or a
# callable that receives the match object and returns the new text.
#
# ORDER MATTERS: rules are applied one after another over the whole file (dicts
# keep insertion order), and the two catch-all rules at the bottom rewrite every
# remaining import from the old modules. Specific rules must therefore come
# first; if the catch-alls run first, the specific rules can never match.
#
# Single-name rules end with ``_END_OF_IMPORT`` (multiline mode), so they only
# fire when that name is the entire import list; an optional trailing comment
# is allowed. Lines such as ``import get_db_session_factory`` or
# ``import get_db_session, db_query`` fall through to the catch-alls instead of
# being partially rewritten.
_END_OF_IMPORT = r"(?=[ \t]*(?:#.*)?$)"

IMPORT_MAPPINGS = {
    # --- Single-name imports with a dedicated target (must run first) ---
    # get_db_session imported alone from sqlalchemy_database_api -> core
    r"(?m)from src\.common\.database\.sqlalchemy_database_api import get_db_session"
    + _END_OF_IMPORT:
        r"from src.common.database.core import get_db_session",
    # initialize_database -> check_and_migrate_database, aliased so call sites keep working
    r"(?m)from src\.common\.database\.sqlalchemy_models import initialize_database"
    + _END_OF_IMPORT:
        r"from src.common.database.core import check_and_migrate_database as initialize_database",
    # Helpers from the old database.py module
    r"(?m)from src\.common\.database\.database import stop_database" + _END_OF_IMPORT:
        r"from src.common.database.core import close_engine as stop_database",
    r"(?m)from src\.common\.database\.database import initialize_sql_database" + _END_OF_IMPORT:
        r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
    # --- sqlalchemy_models lines mentioning session/engine helpers -> core ---
    # NOTE(review): the whole import line is moved to core, so any model classes
    # imported on the same line must also be importable from core (or be split
    # out manually) -- confirm after running.
    r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
        lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}",
    r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
        lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",
    # Base stays with the models
    r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
        lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",
    # --- Catch-alls (must run last) ---
    # Any other model import -> core.models
    r"from src\.common\.database\.sqlalchemy_models import (.+)":
        r"from src.common.database.core.models import \1",
    # Any other API import -> compatibility layer
    r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
        r"from src.common.database.compatibility import \1",
}
# Glob patterns (checked with ``Path.match``) for files this script must not modify.
# NOTE: ``Path.match`` treats ``**`` like a single-component ``*`` and matches
# from the right, so "**/old/**" only covers files sitting directly inside an
# ``old`` directory, not deeper descendants.
EXCLUDE_PATTERNS = [
    # Documentation describing the database refactor
    "**/database_refactoring_plan.md",
    # Legacy files kept in an "old" directory
    "**/old/**",
    # The legacy sqlalchemy_* database modules themselves
    "**/sqlalchemy_*.py",
    # The legacy database.py module
    "**/database.py",
    # Legacy db_* modules
    "**/db_*.py",
]
def should_exclude(file_path: Path) -> bool:
    """Tell whether *file_path* matches any glob in ``EXCLUDE_PATTERNS``.

    Matching uses ``Path.match``, which compares relative patterns from the
    right-hand end of the path. It does not treat ``**`` as a recursive
    wildcard; ``**`` behaves like ``*`` and matches exactly one component.

    Args:
        file_path: Candidate file path.

    Returns:
        True if at least one exclusion pattern matches, otherwise False.
    """
    return any(file_path.match(glob) for glob in EXCLUDE_PATTERNS)
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
    """Rewrite old database import statements in a single file.

    Every rule in ``IMPORT_MAPPINGS`` is applied, in order, to the file's text.
    Each match is replaced at the exact position where it was found. A
    ``-``/``+`` pair of lines is recorded for each match whose replacement
    differs from the matched text.

    Args:
        file_path: File to process (read and written as UTF-8).
        dry_run: When True (the default), only compute the changes and leave
            the file untouched. When False, write the rewritten text back.

    Returns:
        ``(change_count, change_lines)``. ``change_lines`` holds two entries per
        change (``" - <old>"`` followed by ``" + <new>"``). Returns ``(0, [])``
        when nothing changed or when the file could not be processed; in the
        latter case the error is printed.
    """
    try:
        original_content = file_path.read_text(encoding="utf-8")
        content = original_content
        changes: list[str] = []

        def _rewrite(match: re.Match, replacement) -> str:
            """Compute the replacement for one match and record it if it changes the text."""
            old_line = match.group(0)
            if callable(replacement):
                result = replacement(match)
                # A callable rule returning a non-string means "leave this match as is".
                new_line = result if isinstance(result, str) else old_line
            else:
                # Expand backreferences such as \1 against this very match.
                new_line = match.expand(replacement)
            if new_line != old_line:
                changes.append(f" - {old_line}")
                changes.append(f" + {new_line}")
            return new_line

        for pattern, replacement in IMPORT_MAPPINGS.items():
            # Substitute in place at each match position. Replacing the first
            # textual occurrence instead could hit an earlier, different line
            # that merely starts with the same text (e.g. "import x_y" vs "import x").
            content = re.sub(pattern, lambda m, rep=replacement: _rewrite(m, rep), content)

        if content == original_content:
            return 0, []
        if not dry_run:
            file_path.write_text(content, encoding="utf-8")
        return len(changes) // 2, changes
    except Exception as e:
        # Best-effort per file: an unreadable or undecodable file is reported
        # and skipped rather than aborting the whole run.
        print(f"❌ 处理文件 {file_path} 时出错: {e}")
        return 0, []
def main():
    """Interactively migrate old database imports across the project.

    Scans every ``*.py`` file under the directory two levels above this script,
    skipping files matched by ``should_exclude``. It first runs
    ``update_imports_in_file`` in dry-run mode to preview the rewrites, then
    asks for confirmation, and only on an explicit "yes" applies the changes.
    """
    print("🔍 搜索需要更新导入的文件...")
    # Project root: the parent of the directory that holds this script.
    root_dir = Path(__file__).parent.parent
    # Collect every Python file below the root, minus the excluded ones.
    all_python_files = list(root_dir.rglob("*.py"))
    target_files = [f for f in all_python_files if not should_exclude(f)]
    print(f"📊 找到 {len(target_files)} 个Python文件需要检查")
    print("\n" + "="*80)

    # Pass 1: dry run, to find out which files would change.
    print("\n🔍 预览模式 - 检查需要更新的文件...\n")
    files_to_update = []
    for file_path in target_files:
        count, changes = update_imports_in_file(file_path, dry_run=True)
        if count > 0:
            files_to_update.append((file_path, count, changes))
    if not files_to_update:
        print("✅ 没有文件需要更新!")
        return

    print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n")
    # Each change occupies two preview lines ("-" old, "+" new), so 10 lines
    # shows at most 5 changes per file.
    preview_lines = 10
    total_changes = 0
    for file_path, count, changes in files_to_update:
        rel_path = file_path.relative_to(root_dir)
        print(f"\n📄 {rel_path} ({count} 处修改)")
        for change in changes[:preview_lines]:
            print(change)
        if len(changes) > preview_lines:
            # Report hidden *changes* (line pairs), not raw preview lines.
            hidden_changes = (len(changes) - preview_lines) // 2
            print(f" ... 还有 {hidden_changes} 处修改")
        total_changes += count
    print("\n" + "="*80)
    print("\n📊 统计:")
    print(f" - 需要更新的文件: {len(files_to_update)}")
    print(f" - 总修改次数: {total_changes}")

    # Ask for confirmation; anything other than "yes" cancels.
    print("\n" + "="*80)
    try:
        response = input("\n是否执行更新?(yes/no): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # Non-interactive stdin or Ctrl+C at the prompt: treat as "no".
        print()
        response = ""
    if response != "yes":
        print("❌ 已取消更新")
        return

    # Pass 2: apply the changes for real.
    print("\n✨ 开始更新文件...\n")
    success_count = 0
    for file_path, _, _ in files_to_update:
        count, _ = update_imports_in_file(file_path, dry_run=False)
        if count > 0:
            rel_path = file_path.relative_to(root_dir)
            print(f"{rel_path} ({count} 处修改)")
            success_count += 1
    print("\n" + "="*80)
    print(f"\n🎉 完成!成功更新 {success_count} 个文件")


if __name__ == "__main__":
    # Run the interactive migration only when executed as a script.
    main()