refactor: 清理项目结构并修复类型注解问题
修复 SQLAlchemy 模型的类型注解,使用 Mapped 类型避免类型检查器错误 - 修正异步数据库操作中缺少 await 的问题 - 优化反注入统计系统的数值字段处理逻辑 - 添加缺失的导入语句修复模块依赖问题
This commit is contained in:
220
scripts/convert_sqlalchemy_models.py
Normal file
220
scripts/convert_sqlalchemy_models.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""批量将经典 SQLAlchemy 模型字段写法
|
||||
|
||||
field = Column(Integer, nullable=False, default=0)
|
||||
|
||||
转换为 2.0 推荐的带类型注解写法:
|
||||
|
||||
field: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
脚本特点:
|
||||
1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。
|
||||
2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。
|
||||
3. 已经是 Mapped 写法的行会跳过。
|
||||
4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。
|
||||
5. nullable=True 时自动添加 "| None"。
|
||||
6. 保留 Column(...) 内的原始参数顺序与内容。
|
||||
7. 生成 .bak 备份文件,确保可回滚。
|
||||
8. 支持 --dry-run 查看差异,不写回文件。
|
||||
|
||||
局限/注意:
|
||||
- 简单基于正则/括号计数,不解析完整 AST;非常规写法(例如变量中构造 Column 再赋值)不会处理。
|
||||
- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。
|
||||
- 不自动添加 from __future__ import annotations;如需 Python 3.10 以下更先进类型表达式,请自行处理。
|
||||
|
||||
使用方式: (在项目根目录执行)
|
||||
|
||||
python scripts/convert_sqlalchemy_models.py \
|
||||
--file src/common/database/sqlalchemy_models.py --dry-run
|
||||
|
||||
确认无误后去掉 --dry-run 真实写入。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
TYPE_MAP = {
|
||||
"Integer": "int",
|
||||
"Float": "float",
|
||||
"Boolean": "bool",
|
||||
"Text": "str",
|
||||
"String": "str",
|
||||
"DateTime": "datetime.datetime",
|
||||
# 自定义帮助函数 get_string_field(...) 也返回字符串类型
|
||||
"get_string_field": "str",
|
||||
}
|
||||
|
||||
|
||||
COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(")
|
||||
ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[")
|
||||
|
||||
|
||||
def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None:
|
||||
"""检测从 start_index 开始的 Column(...) 语句跨越的行范围 (包含结束行)。
|
||||
|
||||
使用括号计数法处理多行。
|
||||
返回 (start, end) 行号 (包含 end)。"""
|
||||
line = lines[start_index]
|
||||
if "Column(" not in line:
|
||||
return None
|
||||
open_parens = line.count("(") - line.count(")")
|
||||
i = start_index
|
||||
while open_parens > 0 and i + 1 < len(lines):
|
||||
i += 1
|
||||
l2 = lines[i]
|
||||
open_parens += l2.count("(") - l2.count(")")
|
||||
return (start_index, i)
|
||||
|
||||
|
||||
def extract_column_body(block_lines: list[str]) -> str:
|
||||
"""提取 Column(...) 内部参数文本 (去掉首尾 Column( 和 最后一个 ) )。"""
|
||||
joined = "\n".join(block_lines)
|
||||
# 找到第一次出现 Column(
|
||||
start_pos = joined.find("Column(")
|
||||
if start_pos == -1:
|
||||
return ""
|
||||
inner = joined[start_pos + len("Column(") :]
|
||||
# 去掉最后一个 ) —— 简单方式: 找到最后一个 ) 并截断
|
||||
last_paren = inner.rfind(")")
|
||||
if last_paren != -1:
|
||||
inner = inner[:last_paren]
|
||||
return inner.strip()
|
||||
|
||||
|
||||
def guess_python_type(column_body: str) -> str:
|
||||
# 简单取第一个类型标识符 (去掉前导装饰/空格)
|
||||
# 可能形式: Integer, Text, get_string_field(50), DateTime, Boolean
|
||||
# 利用正则抓取第一个标识符
|
||||
m = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body)
|
||||
if not m:
|
||||
return "Any"
|
||||
type_token = m.group(1)
|
||||
py_type = TYPE_MAP.get(type_token, "Any")
|
||||
# nullable 检测
|
||||
if "nullable=True" in column_body or "nullable = True" in column_body:
|
||||
# 避免重复 Optional
|
||||
if py_type != "Any" and not py_type.endswith(" | None"):
|
||||
py_type = f"{py_type} | None"
|
||||
elif py_type == "Any":
|
||||
py_type = "Any | None"
|
||||
return py_type
|
||||
|
||||
|
||||
def convert_block(block_lines: list[str]) -> list[str]:
|
||||
first_line = block_lines[0]
|
||||
m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line)
|
||||
if not m_name:
|
||||
return block_lines
|
||||
indent = m_name.group("indent")
|
||||
name = m_name.group("name")
|
||||
body = extract_column_body(block_lines)
|
||||
py_type = guess_python_type(body)
|
||||
# 构造新的多行 mapped_column 写法
|
||||
# 保留内部参数的换行缩进: 重新缩进为 indent + 4 空格 (延续原风格: 在 indent 基础上再加 4 空格)
|
||||
inner_lines = body.split("\n")
|
||||
if len(inner_lines) == 1:
|
||||
new_line = f"{indent}{name}: Mapped[{py_type}] = mapped_column({inner_lines[0].strip()})\n"
|
||||
return [new_line]
|
||||
else:
|
||||
# 多行情况
|
||||
ind2 = indent + " "
|
||||
rebuilt = [f"{indent}{name}: Mapped[{py_type}] = mapped_column(",]
|
||||
for il in inner_lines:
|
||||
if il.strip():
|
||||
rebuilt.append(f"{ind2}{il.rstrip()}")
|
||||
rebuilt.append(f"{indent})\n")
|
||||
return [l + ("\n" if not l.endswith("\n") else "") for l in rebuilt]
|
||||
|
||||
|
||||
def ensure_imports(content: str) -> str:
|
||||
if "Mapped," in content or "Mapped[" in content:
|
||||
# 已经可能存在导入
|
||||
if "from sqlalchemy.orm import Mapped, mapped_column" not in content:
|
||||
# 简单插到第一个 import sqlalchemy 之后
|
||||
lines = content.splitlines()
|
||||
for i, line in enumerate(lines):
|
||||
if "sqlalchemy" in line and line.startswith("from sqlalchemy"):
|
||||
lines.insert(i + 1, "from sqlalchemy.orm import Mapped, mapped_column")
|
||||
return "\n".join(lines)
|
||||
return content
|
||||
|
||||
|
||||
def process_file(path: Path) -> tuple[str, str]:
|
||||
original = path.read_text(encoding="utf-8")
|
||||
lines = original.splitlines(keepends=True)
|
||||
i = 0
|
||||
out: list[str] = []
|
||||
changed = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
# 跳过已是 Mapped 风格
|
||||
if ALREADY_MAPPED_RE.match(line):
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
if "= Column(" in line and re.match(r"^\s+[A-Za-z_][A-Za-z0-9_]*\s*=", line):
|
||||
start, end = detect_column_block(lines, i) or (i, i)
|
||||
block = lines[start : end + 1]
|
||||
converted = convert_block(block)
|
||||
out.extend(converted)
|
||||
i = end + 1
|
||||
# 如果转换结果与原始不同,计数
|
||||
if "".join(converted) != "".join(block):
|
||||
changed += 1
|
||||
else:
|
||||
out.append(line)
|
||||
i += 1
|
||||
new_content = "".join(out)
|
||||
new_content = ensure_imports(new_content)
|
||||
# 在文件末尾或头部预留统计信息打印(不写入文件,只返回)
|
||||
return original, new_content if changed else original
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法")
|
||||
parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件")
|
||||
parser.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回")
|
||||
parser.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)")
|
||||
args = parser.parse_args()
|
||||
|
||||
target = Path(args.file)
|
||||
if not target.exists():
|
||||
raise SystemExit(f"文件不存在: {target}")
|
||||
|
||||
original, new_content = process_file(target)
|
||||
|
||||
if original == new_content:
|
||||
print("[INFO] 没有需要转换的内容或转换后无差异。")
|
||||
return
|
||||
|
||||
# 简单差异输出 (行对比)
|
||||
if args.dry_run or not args.write:
|
||||
print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):")
|
||||
import difflib
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
original.splitlines(), new_content.splitlines(), fromfile="original", tofile="converted", lineterm=""
|
||||
)
|
||||
count = 0
|
||||
for d in diff:
|
||||
print(d)
|
||||
count += 1
|
||||
if count == 0:
|
||||
print("[INFO] 差异为空 (可能未匹配到 Column 定义)。")
|
||||
if not args.write:
|
||||
print("\n未写回。若确认无误,添加 --write 执行替换。")
|
||||
return
|
||||
|
||||
backup = target.with_suffix(target.suffix + ".bak")
|
||||
shutil.copyfile(target, backup)
|
||||
target.write_text(new_content, encoding="utf-8")
|
||||
print(f"[DONE] 已写回: {target},备份文件: {backup.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
main()
|
||||
Reference in New Issue
Block a user