From 939f17890a850b02740832aef1c68982ab51ca4c Mon Sep 17 00:00:00 2001 From: cuckoo711 <3038604221@qq.com> Date: Thu, 7 Aug 2025 11:22:01 +0800 Subject: [PATCH] =?UTF-8?q?refactor(database):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=88=9D=E5=A7=8B=E5=8C=96=E5=92=8C?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=E6=A3=80=E6=9F=A5=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加对 MySQL 数据库的支持 - 优化字段检查和添加逻辑,处理 NOT NULL 字段和默认值 - 改进错误处理和日志记录 - 调整表和字段操作的 SQL 语句以适应不同数据库类型 --- src/common/database/database_model.py | 156 ++++++++++++++++++-------- 1 file changed, 109 insertions(+), 47 deletions(-) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 4d4675435..609d303b0 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -405,75 +405,137 @@ def initialize_database(): ThinkingLog, GraphNodes, GraphEdges, - ActionRecords, # 添加 ActionRecords 到初始化列表 + ActionRecords, ] - del_extra = False # 是否删除多余字段 + # 保持 del_extra 为 False,以避免在生产环境中意外删除数据。 + # 如果需要删除多余字段,请谨慎设置为 True。 + del_extra = False + + # 辅助函数:根据字段对象和数据库类型获取对应的 SQL 类型字符串 + def get_sql_type(field_obj, db_type): + field_type_name = field_obj.__class__.__name__ + if db_type == "sqlite": + return { + "TextField": "TEXT", + "IntegerField": "INTEGER", + "FloatField": "FLOAT", + "DoubleField": "DOUBLE", + "BooleanField": "INTEGER", + "DateTimeField": "DATETIME", + }.get(field_type_name, "TEXT") + elif db_type == "mysql": + # CharField 的 max_length 将在主循环中单独处理 + return { + "TextField": "LONGTEXT", # MySQL TEXT 类型长度有限,LONGTEXT 更安全 + "IntegerField": "INT", + "FloatField": "FLOAT", + "DoubleField": "DOUBLE", + "BooleanField": "TINYINT(1)", # MySQL 布尔值存储为 TINYINT(1) + "DateTimeField": "DATETIME", + }.get(field_type_name, "TEXT") + logger.error(f"不支持的数据库类型: {db_type}") + return "TEXT" # 默认回退类型 + + # 辅助函数:将 Peewee 字段的默认值转换为 SQL 语句中的 DEFAULT 子句 + def get_sql_default_value(field_obj): + if field_obj.default is None: + return "" # 没有定义默认值 + + # 可调用默认值(如 datetime.datetime.now)无法直接转换为 SQL DDL 的 DEFAULT 子句 + # 因此,对于这类情况,我们不生成 DEFAULT 子句,并依赖 Peewee 在应用层处理 + # 如果字段为 NOT NULL 且无法提供字面默认值,则需要在 ADD COLUMN 时临时设为 NULLABLE + if callable(field_obj.default): + return "" + + default_value = field_obj.default + if isinstance(default_value, str): + # 字符串默认值需要用单引号括起来,并对内部的单引号进行转义 + escaped_value = str(default_value).replace("'", "''") + return f" DEFAULT '{escaped_value}'" + elif isinstance(default_value, bool): + return f" DEFAULT {int(default_value)}" # 布尔值转换为 0 或 1 + elif isinstance(default_value, (int, float)): + return f" DEFAULT {default_value}" + + return "" # 其他无法直接转换为 SQL 字面值的类型 + try: - with db: # 管理 table_exists 检查的连接 + with db: for model in models: table_name = model._meta.table_name if not db.table_exists(model): logger.warning(f"表 '{table_name}' 未找到,正在创建...") db.create_tables([model]) logger.info(f"表 '{table_name}' 创建成功") + # 表刚创建,无需检查字段 + continue + + # 获取现有列 + db_type = global_config.data_base.db_type + if db_type == "sqlite": + cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") + existing_columns = {row[1] for row in cursor.fetchall()} + elif db_type == "mysql": + cursor = db.execute_sql(f"SHOW COLUMNS FROM {table_name}") + existing_columns = {row[0] for row in cursor.fetchall()} + else: + logger.error(f"不支持的数据库类型 '{db_type}',跳过表 '{table_name}' 的字段检查。") continue - # 检查字段 - cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") - existing_columns = {row[1] for row in cursor.fetchall()} model_fields = set(model._meta.fields.keys()) - if missing_fields := model_fields - existing_columns: + # 识别并添加缺失字段 + missing_fields = model_fields - existing_columns + if missing_fields: logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}") for field_name, field_obj in model._meta.fields.items(): if field_name not in existing_columns: - logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在添加...") - field_type = field_obj.__class__.__name__ - sql_type = { - "TextField": "TEXT", - "IntegerField": "INTEGER", - "FloatField": "FLOAT", - "DoubleField": "DOUBLE", - "BooleanField": "INTEGER", - "DateTimeField": "DATETIME", - }.get(field_type, "TEXT") - alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}" - alter_sql += " NULL" if field_obj.null else " NOT NULL" - if hasattr(field_obj, "default") and field_obj.default is not None: - # 正确处理不同类型的默认值,跳过lambda函数 - default_value = field_obj.default - if callable(default_value): - # 跳过lambda函数或其他可调用对象,这些无法在SQL中表示 - pass - elif isinstance(default_value, str): - alter_sql += f" DEFAULT '{default_value}'" - elif isinstance(default_value, bool): - alter_sql += f" DEFAULT {int(default_value)}" - else: - alter_sql += f" DEFAULT {default_value}" + logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在尝试添加...") + + sql_type = get_sql_type(field_obj, db_type) + + # 特殊处理 MySQL 的 CharField,需要 max_length + if isinstance(field_obj, CharField) and db_type == "mysql": + sql_type = f"VARCHAR({field_obj.max_length})" + + null_clause = " NULL" if field_obj.null else " NOT NULL" + default_clause = get_sql_default_value(field_obj) + + # 如果字段定义为 NOT NULL 且无法在 SQL DDL 中提供字面默认值 (如可调用默认值), + # 为了避免在有数据的表中添加列时失败,暂时将其添加为 NULLABLE。 + # 这是一种务实的兼容性处理,后续可能需要手动回填数据并修改为 NOT NULL。 + if not field_obj.null and not default_clause: + logger.warning( + f"表 '{table_name}' 的字段 '{field_name}' 为 NOT NULL 但无法生成SQL默认值," + f"将暂时添加为 NULLABLE 以避免现有数据行错误。" + ) + null_clause = " NULL" # 强制设为 NULLABLE + + alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}{null_clause}{default_clause}" + try: db.execute_sql(alter_sql) logger.info(f"字段 '{field_name}' 添加成功") except Exception as e: - logger.error(f"添加字段 '{field_name}' 失败: {e}") + logger.error(f"添加字段 '{field_name}' 失败: {e}. SQL 语句: {alter_sql}") + + # 检查并删除多余字段(根据 del_extra 旗标决定) + if del_extra: + extra_fields = existing_columns - model_fields + if extra_fields: + logger.warning(f"表 '{table_name}' 存在模型中未定义的字段: {extra_fields}") + for field_name in extra_fields: + try: + logger.warning(f"表 '{table_name}' 正在尝试删除多余字段 '{field_name}'...") + db.execute_sql(f"ALTER TABLE {table_name} DROP COLUMN {field_name}") + logger.info(f"字段 '{field_name}' 删除成功") + except Exception as e: + logger.error(f"删除字段 '{field_name}' 失败: {e}") - # 检查并删除多余字段(新增逻辑) - if not del_extra: - continue - extra_fields = existing_columns - model_fields - if extra_fields: - logger.warning(f"表 '{table_name}' 存在多余字段: {extra_fields}") - for field_name in extra_fields: - try: - logger.warning(f"表 '{table_name}' 存在多余字段 '{field_name}',正在尝试删除...") - db.execute_sql(f"ALTER TABLE {table_name} DROP COLUMN {field_name}") - logger.info(f"字段 '{field_name}' 删除成功") - except Exception as e: - logger.error(f"删除字段 '{field_name}' 失败: {e}") except Exception as e: - logger.exception(f"检查表或字段是否存在时出错: {e}") - # 如果检查失败(例如数据库不可用),则退出 + logger.exception(f"数据库初始化过程中发生异常: {e}") + # 如果初始化失败(例如数据库不可用),则退出 return logger.info("数据库初始化完成")