Files
Mofox-Core/scripts/extract_models.py
明天好像没什么 ff6dc542e1 rufffffff
2025-11-19 23:31:37 +08:00

67 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""提取models.py中的模型定义"""
import re
# 读取原始文件
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
content = f.read()
# 找到get_string_field函数的开始和结束
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
get_string_field = content[get_string_field_start:get_string_field_end]
# 找到第一个class定义开始
first_class_pos = content.find("class ChatStreams(Base):")
# 找到所有class定义直到遇到非class的def
# 简单策略:找到所有以"class "开头且继承Base的类
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
if matches:
# 取最后一个匹配的结束位置
models_content = content[first_class_pos:first_class_pos + matches[-1].end()]
else:
# 备用方案从第一个class到文件的85%位置
models_end = int(len(content) * 0.85)
models_content = content[first_class_pos:models_end]
# 创建新文件内容
header = '''"""SQLAlchemy数据库模型定义
本文件只包含纯模型定义使用SQLAlchemy 2.0的Mapped类型注解风格。
引擎和会话管理已移至core/engine.py和core/session.py。
所有模型使用统一的类型注解风格:
field_name: Mapped[PyType] = mapped_column(Type, ...)
这样IDE/Pylance能正确推断实例属性类型。
"""
import datetime
import time
from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
# 创建基类
Base = declarative_base()
'''
new_content = header + get_string_field + "\n\n" + models_content
# 写入新文件
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
f.write(new_content)
print("✅ Models file rewritten successfully")
print(f"File size: {len(new_content)} characters")
pattern = r"^class \w+\(Base\):"
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
print(f"Number of model classes: {model_count}")