refactor(database): 阶段二 - 完成核心层重构
- models.py: 迁移25个模型类,使用统一的Mapped类型注解 * 包含: ChatStreams, Messages, PersonInfo, LLMUsage等 * 新增: PermissionNodes, UserPermissions, UserRelationships * 654行纯模型定义代码,无初始化逻辑 - migration.py: 重构数据库迁移逻辑 * check_and_migrate_database: 自动检查和迁移表结构 * create_all_tables: 快速创建所有表 * drop_all_tables: 测试用删除所有表 * 使用新架构的engine和models - __init__.py: 完善导出清单 * 导出所有25个模型类 * 导出迁移函数 * 导出Base和工具函数 - 辅助脚本: * extract_models.py: 自动提取模型定义 * cleanup_models.py: 清理非模型代码 核心层现已完整,下一步进入优化层实现
This commit is contained in:
49
scripts/cleanup_models.py
Normal file
49
scripts/cleanup_models.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
"""清理 core/models.py,只保留模型定义"""
|
||||
|
||||
import os
|
||||
|
||||
# 文件路径
|
||||
models_file = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"src",
|
||||
"common",
|
||||
"database",
|
||||
"core",
|
||||
"models.py"
|
||||
)
|
||||
|
||||
print(f"正在清理文件: {models_file}")
|
||||
|
||||
# 读取文件
|
||||
with open(models_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# 找到最后一个模型类的结束位置(MonthlyPlan的 __table_args__ 结束)
|
||||
# 我们要保留到第593行(包含)
|
||||
keep_lines = []
|
||||
found_end = False
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
keep_lines.append(line)
|
||||
|
||||
# 检查是否到达 MonthlyPlan 的 __table_args__ 结束
|
||||
if i > 580 and line.strip() == ")":
|
||||
# 再检查前一行是否有 Index 相关内容
|
||||
if "idx_monthlyplan" in "".join(lines[max(0, i-5):i]):
|
||||
print(f"找到模型定义结束位置: 第 {i} 行")
|
||||
found_end = True
|
||||
break
|
||||
|
||||
if not found_end:
|
||||
print("❌ 未找到模型定义结束标记")
|
||||
exit(1)
|
||||
|
||||
# 写回文件
|
||||
with open(models_file, "w", encoding="utf-8") as f:
|
||||
f.writelines(keep_lines)
|
||||
|
||||
print(f"✅ 文件清理完成")
|
||||
print(f"保留行数: {len(keep_lines)}")
|
||||
print(f"原始行数: {len(lines)}")
|
||||
print(f"删除行数: {len(lines) - len(keep_lines)}")
|
||||
66
scripts/extract_models.py
Normal file
66
scripts/extract_models.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
"""提取models.py中的模型定义"""
|
||||
|
||||
import re
|
||||
|
||||
# 读取原始文件
|
||||
with open('src/common/database/sqlalchemy_models.py', 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 找到get_string_field函数的开始和结束
|
||||
get_string_field_start = content.find('# MySQL兼容的字段类型辅助函数')
|
||||
get_string_field_end = content.find('\n\nclass ChatStreams(Base):')
|
||||
get_string_field = content[get_string_field_start:get_string_field_end]
|
||||
|
||||
# 找到第一个class定义开始
|
||||
first_class_pos = content.find('class ChatStreams(Base):')
|
||||
|
||||
# 找到所有class定义,直到遇到非class的def
|
||||
# 简单策略:找到所有以"class "开头且继承Base的类
|
||||
classes_pattern = r'class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)'
|
||||
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
|
||||
|
||||
if matches:
|
||||
# 取最后一个匹配的结束位置
|
||||
models_content = content[first_class_pos:first_class_pos + matches[-1].end()]
|
||||
else:
|
||||
# 备用方案:从第一个class到文件的85%位置
|
||||
models_end = int(len(content) * 0.85)
|
||||
models_content = content[first_class_pos:models_end]
|
||||
|
||||
# 创建新文件内容
|
||||
header = '''"""SQLAlchemy数据库模型定义
|
||||
|
||||
本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。
|
||||
引擎和会话管理已移至core/engine.py和core/session.py。
|
||||
|
||||
所有模型使用统一的类型注解风格:
|
||||
field_name: Mapped[PyType] = mapped_column(Type, ...)
|
||||
|
||||
这样IDE/Pylance能正确推断实例属性类型。
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
'''
|
||||
|
||||
new_content = header + get_string_field + '\n\n' + models_content
|
||||
|
||||
# 写入新文件
|
||||
with open('src/common/database/core/models.py', 'w', encoding='utf-8') as f:
|
||||
f.write(new_content)
|
||||
|
||||
print('✅ Models file rewritten successfully')
|
||||
print(f'File size: {len(new_content)} characters')
|
||||
pattern = r"^class \w+\(Base\):"
|
||||
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
|
||||
print(f'Number of model classes: {model_count}')
|
||||
Reference in New Issue
Block a user