refactor(database): 重构数据库初始化和字段检查逻辑
- 添加对 MySQL 数据库的支持 - 优化字段检查和添加逻辑,处理 NOT NULL 字段和默认值 - 改进错误处理和日志记录 - 调整表和字段操作的 SQL 语句以适应不同数据库类型
This commit is contained in:
@@ -405,75 +405,137 @@ def initialize_database():
|
|||||||
ThinkingLog,
|
ThinkingLog,
|
||||||
GraphNodes,
|
GraphNodes,
|
||||||
GraphEdges,
|
GraphEdges,
|
||||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
ActionRecords,
|
||||||
]
|
]
|
||||||
del_extra = False # 是否删除多余字段
|
# 保持 del_extra 为 False,以避免在生产环境中意外删除数据。
|
||||||
try:
|
# 如果需要删除多余字段,请谨慎设置为 True。
|
||||||
with db: # 管理 table_exists 检查的连接
|
del_extra = False
|
||||||
for model in models:
|
|
||||||
table_name = model._meta.table_name
|
|
||||||
if not db.table_exists(model):
|
|
||||||
logger.warning(f"表 '{table_name}' 未找到,正在创建...")
|
|
||||||
db.create_tables([model])
|
|
||||||
logger.info(f"表 '{table_name}' 创建成功")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查字段
|
# 辅助函数:根据字段对象和数据库类型获取对应的 SQL 类型字符串
|
||||||
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
def get_sql_type(field_obj, db_type):
|
||||||
existing_columns = {row[1] for row in cursor.fetchall()}
|
field_type_name = field_obj.__class__.__name__
|
||||||
model_fields = set(model._meta.fields.keys())
|
if db_type == "sqlite":
|
||||||
|
return {
|
||||||
if missing_fields := model_fields - existing_columns:
|
|
||||||
logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}")
|
|
||||||
|
|
||||||
for field_name, field_obj in model._meta.fields.items():
|
|
||||||
if field_name not in existing_columns:
|
|
||||||
logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在添加...")
|
|
||||||
field_type = field_obj.__class__.__name__
|
|
||||||
sql_type = {
|
|
||||||
"TextField": "TEXT",
|
"TextField": "TEXT",
|
||||||
"IntegerField": "INTEGER",
|
"IntegerField": "INTEGER",
|
||||||
"FloatField": "FLOAT",
|
"FloatField": "FLOAT",
|
||||||
"DoubleField": "DOUBLE",
|
"DoubleField": "DOUBLE",
|
||||||
"BooleanField": "INTEGER",
|
"BooleanField": "INTEGER",
|
||||||
"DateTimeField": "DATETIME",
|
"DateTimeField": "DATETIME",
|
||||||
}.get(field_type, "TEXT")
|
}.get(field_type_name, "TEXT")
|
||||||
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
elif db_type == "mysql":
|
||||||
alter_sql += " NULL" if field_obj.null else " NOT NULL"
|
# CharField 的 max_length 将在主循环中单独处理
|
||||||
if hasattr(field_obj, "default") and field_obj.default is not None:
|
return {
|
||||||
# 正确处理不同类型的默认值,跳过lambda函数
|
"TextField": "LONGTEXT", # MySQL TEXT 类型长度有限,LONGTEXT 更安全
|
||||||
|
"IntegerField": "INT",
|
||||||
|
"FloatField": "FLOAT",
|
||||||
|
"DoubleField": "DOUBLE",
|
||||||
|
"BooleanField": "TINYINT(1)", # MySQL 布尔值存储为 TINYINT(1)
|
||||||
|
"DateTimeField": "DATETIME",
|
||||||
|
}.get(field_type_name, "TEXT")
|
||||||
|
logger.error(f"不支持的数据库类型: {db_type}")
|
||||||
|
return "TEXT" # 默认回退类型
|
||||||
|
|
||||||
|
# 辅助函数:将 Peewee 字段的默认值转换为 SQL 语句中的 DEFAULT 子句
|
||||||
|
def get_sql_default_value(field_obj):
|
||||||
|
if field_obj.default is None:
|
||||||
|
return "" # 没有定义默认值
|
||||||
|
|
||||||
|
# 可调用默认值(如 datetime.datetime.now)无法直接转换为 SQL DDL 的 DEFAULT 子句
|
||||||
|
# 因此,对于这类情况,我们不生成 DEFAULT 子句,并依赖 Peewee 在应用层处理
|
||||||
|
# 如果字段为 NOT NULL 且无法提供字面默认值,则需要在 ADD COLUMN 时临时设为 NULLABLE
|
||||||
|
if callable(field_obj.default):
|
||||||
|
return ""
|
||||||
|
|
||||||
default_value = field_obj.default
|
default_value = field_obj.default
|
||||||
if callable(default_value):
|
if isinstance(default_value, str):
|
||||||
# 跳过lambda函数或其他可调用对象,这些无法在SQL中表示
|
# 字符串默认值需要用单引号括起来,并对内部的单引号进行转义
|
||||||
pass
|
escaped_value = str(default_value).replace("'", "''")
|
||||||
elif isinstance(default_value, str):
|
return f" DEFAULT '{escaped_value}'"
|
||||||
alter_sql += f" DEFAULT '{default_value}'"
|
|
||||||
elif isinstance(default_value, bool):
|
elif isinstance(default_value, bool):
|
||||||
alter_sql += f" DEFAULT {int(default_value)}"
|
return f" DEFAULT {int(default_value)}" # 布尔值转换为 0 或 1
|
||||||
|
elif isinstance(default_value, (int, float)):
|
||||||
|
return f" DEFAULT {default_value}"
|
||||||
|
|
||||||
|
return "" # 其他无法直接转换为 SQL 字面值的类型
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db:
|
||||||
|
for model in models:
|
||||||
|
table_name = model._meta.table_name
|
||||||
|
if not db.table_exists(model):
|
||||||
|
logger.warning(f"表 '{table_name}' 未找到,正在创建...")
|
||||||
|
db.create_tables([model])
|
||||||
|
logger.info(f"表 '{table_name}' 创建成功")
|
||||||
|
# 表刚创建,无需检查字段
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取现有列
|
||||||
|
db_type = global_config.data_base.db_type
|
||||||
|
if db_type == "sqlite":
|
||||||
|
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
||||||
|
existing_columns = {row[1] for row in cursor.fetchall()}
|
||||||
|
elif db_type == "mysql":
|
||||||
|
cursor = db.execute_sql(f"SHOW COLUMNS FROM {table_name}")
|
||||||
|
existing_columns = {row[0] for row in cursor.fetchall()}
|
||||||
else:
|
else:
|
||||||
alter_sql += f" DEFAULT {default_value}"
|
logger.error(f"不支持的数据库类型 '{db_type}',跳过表 '{table_name}' 的字段检查。")
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_fields = set(model._meta.fields.keys())
|
||||||
|
|
||||||
|
# 识别并添加缺失字段
|
||||||
|
missing_fields = model_fields - existing_columns
|
||||||
|
if missing_fields:
|
||||||
|
logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}")
|
||||||
|
|
||||||
|
for field_name, field_obj in model._meta.fields.items():
|
||||||
|
if field_name not in existing_columns:
|
||||||
|
logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在尝试添加...")
|
||||||
|
|
||||||
|
sql_type = get_sql_type(field_obj, db_type)
|
||||||
|
|
||||||
|
# 特殊处理 MySQL 的 CharField,需要 max_length
|
||||||
|
if isinstance(field_obj, CharField) and db_type == "mysql":
|
||||||
|
sql_type = f"VARCHAR({field_obj.max_length})"
|
||||||
|
|
||||||
|
null_clause = " NULL" if field_obj.null else " NOT NULL"
|
||||||
|
default_clause = get_sql_default_value(field_obj)
|
||||||
|
|
||||||
|
# 如果字段定义为 NOT NULL 且无法在 SQL DDL 中提供字面默认值 (如可调用默认值),
|
||||||
|
# 为了避免在有数据的表中添加列时失败,暂时将其添加为 NULLABLE。
|
||||||
|
# 这是一种务实的兼容性处理,后续可能需要手动回填数据并修改为 NOT NULL。
|
||||||
|
if not field_obj.null and not default_clause:
|
||||||
|
logger.warning(
|
||||||
|
f"表 '{table_name}' 的字段 '{field_name}' 为 NOT NULL 但无法生成SQL默认值,"
|
||||||
|
f"将暂时添加为 NULLABLE 以避免现有数据行错误。"
|
||||||
|
)
|
||||||
|
null_clause = " NULL" # 强制设为 NULLABLE
|
||||||
|
|
||||||
|
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}{null_clause}{default_clause}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.execute_sql(alter_sql)
|
db.execute_sql(alter_sql)
|
||||||
logger.info(f"字段 '{field_name}' 添加成功")
|
logger.info(f"字段 '{field_name}' 添加成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"添加字段 '{field_name}' 失败: {e}")
|
logger.error(f"添加字段 '{field_name}' 失败: {e}. SQL 语句: {alter_sql}")
|
||||||
|
|
||||||
# 检查并删除多余字段(新增逻辑)
|
# 检查并删除多余字段(根据 del_extra 旗标决定)
|
||||||
if not del_extra:
|
if del_extra:
|
||||||
continue
|
|
||||||
extra_fields = existing_columns - model_fields
|
extra_fields = existing_columns - model_fields
|
||||||
if extra_fields:
|
if extra_fields:
|
||||||
logger.warning(f"表 '{table_name}' 存在多余字段: {extra_fields}")
|
logger.warning(f"表 '{table_name}' 存在模型中未定义的字段: {extra_fields}")
|
||||||
for field_name in extra_fields:
|
for field_name in extra_fields:
|
||||||
try:
|
try:
|
||||||
logger.warning(f"表 '{table_name}' 存在多余字段 '{field_name}',正在尝试删除...")
|
logger.warning(f"表 '{table_name}' 正在尝试删除多余字段 '{field_name}'...")
|
||||||
db.execute_sql(f"ALTER TABLE {table_name} DROP COLUMN {field_name}")
|
db.execute_sql(f"ALTER TABLE {table_name} DROP COLUMN {field_name}")
|
||||||
logger.info(f"字段 '{field_name}' 删除成功")
|
logger.info(f"字段 '{field_name}' 删除成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除字段 '{field_name}' 失败: {e}")
|
logger.error(f"删除字段 '{field_name}' 失败: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"检查表或字段是否存在时出错: {e}")
|
logger.exception(f"数据库初始化过程中发生异常: {e}")
|
||||||
# 如果检查失败(例如数据库不可用),则退出
|
# 如果初始化失败(例如数据库不可用),则退出
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("数据库初始化完成")
|
logger.info("数据库初始化完成")
|
||||||
|
|||||||
Reference in New Issue
Block a user