#!/usr/bin/env python3 """提取models.py中的模型定义""" import re # 读取原始文件 with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f: content = f.read() # 找到get_string_field函数的开始和结束 get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数") get_string_field_end = content.find("\n\nclass ChatStreams(Base):") get_string_field = content[get_string_field_start:get_string_field_end] # 找到第一个class定义开始 first_class_pos = content.find("class ChatStreams(Base):") # 找到所有class定义,直到遇到非class的def # 简单策略:找到所有以"class "开头且继承Base的类 classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)" matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL)) if matches: # 取最后一个匹配的结束位置 models_content = content[first_class_pos:first_class_pos + matches[-1].end()] else: # 备用方案:从第一个class到文件的85%位置 models_end = int(len(content) * 0.85) models_content = content[first_class_pos:models_end] # 创建新文件内容 header = '''"""SQLAlchemy数据库模型定义 本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。 引擎和会话管理已移至core/engine.py和core/session.py。 所有模型使用统一的类型注解风格: field_name: Mapped[PyType] = mapped_column(Type, ...) 这样IDE/Pylance能正确推断实例属性类型。 """ import datetime import time from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Mapped, mapped_column # 创建基类 Base = declarative_base() ''' new_content = header + get_string_field + "\n\n" + models_content # 写入新文件 with open("src/common/database/core/models.py", "w", encoding="utf-8") as f: f.write(new_content) print("✅ Models file rewritten successfully") print(f"File size: {len(new_content)} characters") pattern = r"^class \w+\(Base\):" model_count = len(re.findall(pattern, models_content, re.MULTILINE)) print(f"Number of model classes: {model_count}")