fix(db): 修复数据库迁移中列和索引的创建逻辑

- 增强了添加列时对默认值的处理,以兼容不同数据库方言(例如 SQLite 的布尔值)。
- 切换到更标准的 `index.create()` 方法来创建索引,提高了稳定性。
- 调整了启动顺序,确保数据库在主系统之前完成初始化,以防止竞争条件。
This commit is contained in:
minecraft1024a
2025-09-24 13:46:44 +08:00
parent ae738ef8cb
commit 8ff4687670
3 changed files with 30 additions and 23 deletions

4
bot.py
View File

@@ -229,10 +229,10 @@ if __name__ == "__main__":
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
# 执行初始化和任务调度
loop.run_until_complete(main_system.initialize())
# 异步初始化数据库表结构 # 异步初始化数据库表结构
loop.run_until_complete(maibot.initialize_database_async()) loop.run_until_complete(maibot.initialize_database_async())
# 执行初始化和任务调度
loop.run_until_complete(main_system.initialize())
initialize_lpmm_knowledge() initialize_lpmm_knowledge()
# Schedule tasks returns a future that runs forever. # Schedule tasks returns a future that runs forever.
# We can run console_input_loop concurrently. # We can run console_input_loop concurrently.

View File

@@ -254,7 +254,7 @@ class ChatManager:
model_instance = await _db_find_stream_async(stream_id) model_instance = await _db_find_stream_async(stream_id)
if model_instance: if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 # 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = { user_info_data = {
"platform": model_instance.user_platform, "platform": model_instance.user_platform,
"user_id": model_instance.user_id, "user_id": model_instance.user_id,
@@ -382,7 +382,7 @@ class ChatManager:
await _db_save_stream_async(stream_data_dict) await _db_save_stream_async(stream_data_dict)
stream.saved = True stream.saved = True
except Exception as e: except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True)
async def _save_all_streams(self): async def _save_all_streams(self):
"""保存所有聊天流""" """保存所有聊天流"""
@@ -435,7 +435,7 @@ class ChatManager:
if stream.stream_id in self.last_messages: if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id]) stream.set_context(self.last_messages[stream.stream_id])
except Exception as e: except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
chat_manager = None chat_manager = None

View File

@@ -70,24 +70,32 @@ async def check_and_migrate_database():
def add_columns_sync(conn): def add_columns_sync(conn):
dialect = conn.dialect dialect = conn.dialect
for column_name in missing_columns:
column = table.c[column_name]
# 使用DDLCompiler为特定方言编译列
compiler = dialect.ddl_compiler(dialect, None) compiler = dialect.ddl_compiler(dialect, None)
# 编译列的数据类型 for column_name in missing_columns:
column = table.c[column_name]
column_type = compiler.get_column_specification(column) column_type = compiler.get_column_specification(column)
# 构建原生SQL
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}" sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
# 添加默认值(如果存在)
if column.default: if column.default:
default_value = compiler.render_literal_value(column.default.arg, column.type) # 手动处理不同方言的默认值
default_arg = column.default.arg
if dialect.name == "sqlite" and isinstance(default_arg, bool):
# SQLite 将布尔值存储为 0 或 1
default_value = "1" if default_arg else "0"
elif hasattr(compiler, 'render_literal_value'):
try:
# 尝试使用 render_literal_value
default_value = compiler.render_literal_value(default_arg, column.type)
except AttributeError:
# 如果失败,则回退到简单的字符串转换
default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
else:
# 对于没有 render_literal_value 的旧版或特定方言
default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
sql += f" DEFAULT {default_value}" sql += f" DEFAULT {default_value}"
# 添加非空约束(如果存在)
if not column.nullable: if not column.nullable:
sql += " NOT NULL" sql += " NOT NULL"
@@ -109,11 +117,10 @@ async def check_and_migrate_database():
logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}") logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}")
def add_indexes_sync(conn): def add_indexes_sync(conn):
with conn.begin():
for index_name in missing_indexes: for index_name in missing_indexes:
index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
if index_obj is not None: if index_obj is not None:
conn.execute(CreateIndex(index_obj)) index_obj.create(conn)
logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'") logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'")
await connection.run_sync(add_indexes_sync) await connection.run_sync(add_indexes_sync)