refactor(database): 重构数据库初始化和字段检查逻辑

- 添加对 MySQL 数据库的支持
- 优化字段检查和添加逻辑,处理 NOT NULL 字段和默认值
- 改进错误处理和日志记录
- 调整表和字段操作的 SQL 语句以适应不同数据库类型
This commit is contained in:
cuckoo711
2025-08-07 11:22:01 +08:00
parent b6f5831785
commit 939f17890a

View File

@@ -405,75 +405,137 @@ def initialize_database():
ThinkingLog,
GraphNodes,
GraphEdges,
ActionRecords, # 添加 ActionRecords 到初始化列表
ActionRecords,
]
del_extra = False # 是否删除多余字段
try:
with db: # 管理 table_exists 检查的连接
for model in models:
table_name = model._meta.table_name
if not db.table_exists(model):
logger.warning(f"'{table_name}' 未找到,正在创建...")
db.create_tables([model])
logger.info(f"'{table_name}' 创建成功")
continue
# 保持 del_extra 为 False以避免在生产环境中意外删除数据。
# 如果需要删除多余字段,请谨慎设置为 True。
del_extra = False
# 检查字段
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
existing_columns = {row[1] for row in cursor.fetchall()}
model_fields = set(model._meta.fields.keys())
if missing_fields := model_fields - existing_columns:
logger.warning(f"'{table_name}' 缺失字段: {missing_fields}")
for field_name, field_obj in model._meta.fields.items():
if field_name not in existing_columns:
logger.info(f"'{table_name}' 缺失字段 '{field_name}',正在添加...")
field_type = field_obj.__class__.__name__
sql_type = {
# 辅助函数:根据字段对象和数据库类型获取对应的 SQL 类型字符串
def get_sql_type(field_obj, db_type):
field_type_name = field_obj.__class__.__name__
if db_type == "sqlite":
return {
"TextField": "TEXT",
"IntegerField": "INTEGER",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "INTEGER",
"DateTimeField": "DATETIME",
}.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
alter_sql += " NULL" if field_obj.null else " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None:
# 正确处理不同类型的默认值跳过lambda函数
}.get(field_type_name, "TEXT")
elif db_type == "mysql":
# CharField 的 max_length 将在主循环中单独处理
return {
"TextField": "LONGTEXT", # MySQL TEXT 类型长度有限LONGTEXT 更安全
"IntegerField": "INT",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "TINYINT(1)", # MySQL 布尔值存储为 TINYINT(1)
"DateTimeField": "DATETIME",
}.get(field_type_name, "TEXT")
logger.error(f"不支持的数据库类型: {db_type}")
return "TEXT" # 默认回退类型
# 辅助函数:将 Peewee 字段的默认值转换为 SQL 语句中的 DEFAULT 子句
def get_sql_default_value(field_obj):
if field_obj.default is None:
return "" # 没有定义默认值
# 可调用默认值(如 datetime.datetime.now无法直接转换为 SQL DDL 的 DEFAULT 子句
# 因此,对于这类情况,我们不生成 DEFAULT 子句,并依赖 Peewee 在应用层处理
# 如果字段为 NOT NULL 且无法提供字面默认值,则需要在 ADD COLUMN 时临时设为 NULLABLE
if callable(field_obj.default):
return ""
default_value = field_obj.default
if callable(default_value):
# 跳过lambda函数或其他可调用对象这些无法在SQL中表示
pass
elif isinstance(default_value, str):
alter_sql += f" DEFAULT '{default_value}'"
if isinstance(default_value, str):
# 字符串默认值需要用单引号括起来,并对内部的单引号进行转义
escaped_value = str(default_value).replace("'", "''")
return f" DEFAULT '{escaped_value}'"
elif isinstance(default_value, bool):
alter_sql += f" DEFAULT {int(default_value)}"
return f" DEFAULT {int(default_value)}" # 布尔值转换为 0 或 1
elif isinstance(default_value, (int, float)):
return f" DEFAULT {default_value}"
return "" # 其他无法直接转换为 SQL 字面值的类型
try:
with db:
for model in models:
table_name = model._meta.table_name
if not db.table_exists(model):
logger.warning(f"'{table_name}' 未找到,正在创建...")
db.create_tables([model])
logger.info(f"'{table_name}' 创建成功")
# 表刚创建,无需检查字段
continue
# 获取现有列
db_type = global_config.data_base.db_type
if db_type == "sqlite":
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
existing_columns = {row[1] for row in cursor.fetchall()}
elif db_type == "mysql":
cursor = db.execute_sql(f"SHOW COLUMNS FROM {table_name}")
existing_columns = {row[0] for row in cursor.fetchall()}
else:
alter_sql += f" DEFAULT {default_value}"
logger.error(f"不支持的数据库类型 '{db_type}',跳过表 '{table_name}' 的字段检查。")
continue
model_fields = set(model._meta.fields.keys())
# 识别并添加缺失字段
missing_fields = model_fields - existing_columns
if missing_fields:
logger.warning(f"'{table_name}' 缺失字段: {missing_fields}")
for field_name, field_obj in model._meta.fields.items():
if field_name not in existing_columns:
logger.info(f"'{table_name}' 缺失字段 '{field_name}',正在尝试添加...")
sql_type = get_sql_type(field_obj, db_type)
# 特殊处理 MySQL 的 CharField需要 max_length
if isinstance(field_obj, CharField) and db_type == "mysql":
sql_type = f"VARCHAR({field_obj.max_length})"
null_clause = " NULL" if field_obj.null else " NOT NULL"
default_clause = get_sql_default_value(field_obj)
# 如果字段定义为 NOT NULL 且无法在 SQL DDL 中提供字面默认值 (如可调用默认值)
# 为了避免在有数据的表中添加列时失败,暂时将其添加为 NULLABLE。
# 这是一种务实的兼容性处理,后续可能需要手动回填数据并修改为 NOT NULL。
if not field_obj.null and not default_clause:
logger.warning(
f"'{table_name}' 的字段 '{field_name}' 为 NOT NULL 但无法生成SQL默认值"
f"将暂时添加为 NULLABLE 以避免现有数据行错误。"
)
null_clause = " NULL" # 强制设为 NULLABLE
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}{null_clause}{default_clause}"
try:
db.execute_sql(alter_sql)
logger.info(f"字段 '{field_name}' 添加成功")
except Exception as e:
logger.error(f"添加字段 '{field_name}' 失败: {e}")
logger.error(f"添加字段 '{field_name}' 失败: {e}. SQL 语句: {alter_sql}")
# 检查并删除多余字段(新增逻辑
if not del_extra:
continue
# 检查并删除多余字段(根据 del_extra 旗标决定
if del_extra:
extra_fields = existing_columns - model_fields
if extra_fields:
logger.warning(f"'{table_name}' 存在多余字段: {extra_fields}")
logger.warning(f"'{table_name}' 存在模型中未定义的字段: {extra_fields}")
for field_name in extra_fields:
try:
logger.warning(f"'{table_name}' 存在多余字段 '{field_name}',正在尝试删除...")
logger.warning(f"'{table_name}' 正在尝试删除多余字段 '{field_name}'...")
db.execute_sql(f"ALTER TABLE {table_name} DROP COLUMN {field_name}")
logger.info(f"字段 '{field_name}' 删除成功")
except Exception as e:
logger.error(f"删除字段 '{field_name}' 失败: {e}")
except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
logger.exception(f"数据库初始化过程中发生异常: {e}")
# 如果初始化失败(例如数据库不可用),则退出
return
logger.info("数据库初始化完成")