fix(db): 修复数据库迁移中列和索引的创建逻辑
- 增强了添加列时对默认值的处理,以兼容不同数据库方言(例如 SQLite 的布尔值)。 - 切换到更标准的 `index.create()` 方法来创建索引,提高了稳定性。 - 调整了启动顺序,确保数据库在主系统之前完成初始化,以防止竞争条件。
This commit is contained in:
4
bot.py
4
bot.py
@@ -229,10 +229,10 @@ if __name__ == "__main__":
|
|||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行初始化和任务调度
|
|
||||||
loop.run_until_complete(main_system.initialize())
|
|
||||||
# 异步初始化数据库表结构
|
# 异步初始化数据库表结构
|
||||||
loop.run_until_complete(maibot.initialize_database_async())
|
loop.run_until_complete(maibot.initialize_database_async())
|
||||||
|
# 执行初始化和任务调度
|
||||||
|
loop.run_until_complete(main_system.initialize())
|
||||||
initialize_lpmm_knowledge()
|
initialize_lpmm_knowledge()
|
||||||
# Schedule tasks returns a future that runs forever.
|
# Schedule tasks returns a future that runs forever.
|
||||||
# We can run console_input_loop concurrently.
|
# We can run console_input_loop concurrently.
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ class ChatManager:
|
|||||||
model_instance = await _db_find_stream_async(stream_id)
|
model_instance = await _db_find_stream_async(stream_id)
|
||||||
|
|
||||||
if model_instance:
|
if model_instance:
|
||||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
|
||||||
user_info_data = {
|
user_info_data = {
|
||||||
"platform": model_instance.user_platform,
|
"platform": model_instance.user_platform,
|
||||||
"user_id": model_instance.user_id,
|
"user_id": model_instance.user_id,
|
||||||
@@ -382,7 +382,7 @@ class ChatManager:
|
|||||||
await _db_save_stream_async(stream_data_dict)
|
await _db_save_stream_async(stream_data_dict)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True)
|
||||||
|
|
||||||
async def _save_all_streams(self):
|
async def _save_all_streams(self):
|
||||||
"""保存所有聊天流"""
|
"""保存所有聊天流"""
|
||||||
@@ -435,7 +435,7 @@ class ChatManager:
|
|||||||
if stream.stream_id in self.last_messages:
|
if stream.stream_id in self.last_messages:
|
||||||
stream.set_context(self.last_messages[stream.stream_id])
|
stream.set_context(self.last_messages[stream.stream_id])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
chat_manager = None
|
chat_manager = None
|
||||||
|
|||||||
@@ -70,24 +70,32 @@ async def check_and_migrate_database():
|
|||||||
|
|
||||||
def add_columns_sync(conn):
|
def add_columns_sync(conn):
|
||||||
dialect = conn.dialect
|
dialect = conn.dialect
|
||||||
for column_name in missing_columns:
|
|
||||||
column = table.c[column_name]
|
|
||||||
|
|
||||||
# 使用DDLCompiler为特定方言编译列
|
|
||||||
compiler = dialect.ddl_compiler(dialect, None)
|
compiler = dialect.ddl_compiler(dialect, None)
|
||||||
|
|
||||||
# 编译列的数据类型
|
for column_name in missing_columns:
|
||||||
|
column = table.c[column_name]
|
||||||
column_type = compiler.get_column_specification(column)
|
column_type = compiler.get_column_specification(column)
|
||||||
|
|
||||||
# 构建原生SQL
|
|
||||||
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
|
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
|
||||||
|
|
||||||
# 添加默认值(如果存在)
|
|
||||||
if column.default:
|
if column.default:
|
||||||
default_value = compiler.render_literal_value(column.default.arg, column.type)
|
# 手动处理不同方言的默认值
|
||||||
|
default_arg = column.default.arg
|
||||||
|
if dialect.name == "sqlite" and isinstance(default_arg, bool):
|
||||||
|
# SQLite 将布尔值存储为 0 或 1
|
||||||
|
default_value = "1" if default_arg else "0"
|
||||||
|
elif hasattr(compiler, 'render_literal_value'):
|
||||||
|
try:
|
||||||
|
# 尝试使用 render_literal_value
|
||||||
|
default_value = compiler.render_literal_value(default_arg, column.type)
|
||||||
|
except AttributeError:
|
||||||
|
# 如果失败,则回退到简单的字符串转换
|
||||||
|
default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
|
||||||
|
else:
|
||||||
|
# 对于没有 render_literal_value 的旧版或特定方言
|
||||||
|
default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
|
||||||
|
|
||||||
sql += f" DEFAULT {default_value}"
|
sql += f" DEFAULT {default_value}"
|
||||||
|
|
||||||
# 添加非空约束(如果存在)
|
|
||||||
if not column.nullable:
|
if not column.nullable:
|
||||||
sql += " NOT NULL"
|
sql += " NOT NULL"
|
||||||
|
|
||||||
@@ -109,11 +117,10 @@ async def check_and_migrate_database():
|
|||||||
logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}")
|
logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}")
|
||||||
|
|
||||||
def add_indexes_sync(conn):
|
def add_indexes_sync(conn):
|
||||||
with conn.begin():
|
|
||||||
for index_name in missing_indexes:
|
for index_name in missing_indexes:
|
||||||
index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
|
index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
|
||||||
if index_obj is not None:
|
if index_obj is not None:
|
||||||
conn.execute(CreateIndex(index_obj))
|
index_obj.create(conn)
|
||||||
logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")
|
logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")
|
||||||
|
|
||||||
await connection.run_sync(add_indexes_sync)
|
await connection.run_sync(add_indexes_sync)
|
||||||
|
|||||||
Reference in New Issue
Block a user