diff --git a/bot.py b/bot.py
index f382df1e1..472ee5f08 100644
--- a/bot.py
+++ b/bot.py
@@ -229,10 +229,10 @@ if __name__ == "__main__":
     asyncio.set_event_loop(loop)
 
     try:
-        # 执行初始化和任务调度
-        loop.run_until_complete(main_system.initialize())
         # 异步初始化数据库表结构
         loop.run_until_complete(maibot.initialize_database_async())
+        # 执行初始化和任务调度
+        loop.run_until_complete(main_system.initialize())
         initialize_lpmm_knowledge()
         # Schedule tasks returns a future that runs forever.
         # We can run console_input_loop concurrently.
diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py
index c42654aa3..de2fb62e9 100644
--- a/src/chat/message_receive/chat_stream.py
+++ b/src/chat/message_receive/chat_stream.py
@@ -254,7 +254,7 @@ class ChatManager:
         model_instance = await _db_find_stream_async(stream_id)
 
         if model_instance:
-            # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
+            # 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
             user_info_data = {
                 "platform": model_instance.user_platform,
                 "user_id": model_instance.user_id,
@@ -382,7 +382,7 @@ class ChatManager:
                 await _db_save_stream_async(stream_data_dict)
                 stream.saved = True
             except Exception as e:
-                logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
+                logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True)
 
     async def _save_all_streams(self):
         """保存所有聊天流"""
@@ -435,7 +435,7 @@ class ChatManager:
                     if stream.stream_id in self.last_messages:
                         stream.set_context(self.last_messages[stream.stream_id])
         except Exception as e:
-            logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
+            logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
 
 
 chat_manager = None
diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py
index aedff3676..8f7b1ecd3 100644
--- a/src/common/database/db_migration.py
+++ b/src/common/database/db_migration.py
@@ -70,24 +70,32 @@ async def check_and_migrate_database():
 
             def add_columns_sync(conn):
                 dialect = conn.dialect
+                compiler = dialect.ddl_compiler(dialect, None)
+
                 for column_name in missing_columns:
                     column = table.c[column_name]
-
-                    # 使用DDLCompiler为特定方言编译列
-                    compiler = dialect.ddl_compiler(dialect, None)
-
-                    # 编译列的数据类型
                     column_type = compiler.get_column_specification(column)
-
-                    # 构建原生SQL
                     sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
-
-                    # 添加默认值(如果存在)
+
                     if column.default:
-                        default_value = compiler.render_literal_value(column.default.arg, column.type)
+                        # 手动处理不同方言的默认值
+                        default_arg = column.default.arg
+                        if dialect.name == "sqlite" and isinstance(default_arg, bool):
+                            # SQLite 将布尔值存储为 0 或 1
+                            default_value = "1" if default_arg else "0"
+                        elif hasattr(compiler, 'render_literal_value'):
+                            try:
+                                # 尝试使用 render_literal_value
+                                default_value = compiler.render_literal_value(default_arg, column.type)
+                            except AttributeError:
+                                # 如果失败,则回退到简单的字符串转换
+                                default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
+                        else:
+                            # 对于没有 render_literal_value 的旧版或特定方言
+                            default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
+
                         sql += f" DEFAULT {default_value}"
-
-                    # 添加非空约束(如果存在)
+
                     if not column.nullable:
                         sql += " NOT NULL"
 
@@ -109,12 +117,11 @@ async def check_and_migrate_database():
             logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}")
 
             def add_indexes_sync(conn):
-                with conn.begin():
-                    for index_name in missing_indexes:
-                        index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
-                        if index_obj is not None:
-                            conn.execute(CreateIndex(index_obj))
-                            logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")
+                for index_name in missing_indexes:
+                    index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
+                    if index_obj is not None:
+                        index_obj.create(conn)
+                        logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")
 
             await connection.run_sync(add_indexes_sync)
         else: