From 4dbc651d74f2ae43fecd4186a40661f78277b70f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Fri, 19 Sep 2025 20:20:20 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 仅仅支持还有107处待迁移 --- requirements.txt | 2 + rust_image/Cargo.toml | 0 src/common/database/db_migration.py | 116 ++++++++++-------- .../database/sqlalchemy_database_api.py | 101 ++++++++------- src/common/database/sqlalchemy_init.py | 37 +++--- src/common/database/sqlalchemy_models.py | 49 ++++---- 6 files changed, 169 insertions(+), 136 deletions(-) create mode 100644 rust_image/Cargo.toml diff --git a/requirements.txt b/requirements.txt index edc7b9cb8..757c1e09a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ sqlalchemy +aiosqlite +aiomysql APScheduler aiohttp aiohttp-cors diff --git a/rust_image/Cargo.toml b/rust_image/Cargo.toml new file mode 100644 index 000000000..e69de29bb diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py index cae3cbd29..9d2be9e5b 100644 --- a/src/common/database/db_migration.py +++ b/src/common/database/db_migration.py @@ -7,71 +7,81 @@ from src.common.logger import get_logger logger = get_logger("db_migration") -def check_and_migrate_database(): +async def check_and_migrate_database(): """ - 检查数据库结构并自动迁移(添加缺失的表和列)。 + 异步检查数据库结构并自动迁移(添加缺失的表和列)。 """ logger.info("正在检查数据库结构并执行自动迁移...") - engine = get_engine() - inspector = inspect(engine) + engine = await get_engine() + + # 使用异步引擎获取inspector + async with engine.connect() as connection: + # 在同步上下文中运行inspector操作 + inspector = await connection.run_sync(lambda sync_conn: inspect(sync_conn)) + + # 1. 获取数据库中所有已存在的表名 + db_table_names = await connection.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names())) - # 1. 获取数据库中所有已存在的表名 - db_table_names = set(inspector.get_table_names()) + # 2. 遍历所有在代码中定义的模型 + for table_name, table in Base.metadata.tables.items(): + logger.debug(f"正在检查表: {table_name}") - # 2. 遍历所有在代码中定义的模型 - for table_name, table in Base.metadata.tables.items(): - logger.debug(f"正在检查表: {table_name}") + # 3. 如果表不存在,则创建它 + if table_name not in db_table_names: + logger.info(f"表 '{table_name}' 不存在,正在创建...") + try: + await connection.run_sync(lambda sync_conn: table.create(sync_conn)) + logger.info(f"表 '{table_name}' 创建成功。") + except Exception as e: + logger.error(f"创建表 '{table_name}' 失败: {e}") + continue - # 3. 如果表不存在,则创建它 - if table_name not in db_table_names: - logger.info(f"表 '{table_name}' 不存在,正在创建...") - try: - table.create(engine) - logger.info(f"表 '{table_name}' 创建成功。") - except Exception as e: - logger.error(f"创建表 '{table_name}' 失败: {e}") - continue + # 4. 如果表已存在,则检查并添加缺失的列 + db_columns = await connection.run_sync( + lambda sync_conn: {col["name"] for col in inspect(sync_conn).get_columns(table_name)} + ) + model_columns = {col.name for col in table.c} - # 4. 如果表已存在,则检查并添加缺失的列 - db_columns = {col["name"] for col in inspector.get_columns(table_name)} - model_columns = {col.name for col in table.c} + missing_columns = model_columns - db_columns + if not missing_columns: + logger.debug(f"表 '{table_name}' 结构一致,无需修改。") + continue - missing_columns = model_columns - db_columns - if not missing_columns: - logger.debug(f"表 '{table_name}' 结构一致,无需修改。") - continue + logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") + + # 开始事务来添加缺失的列 + async with connection.begin() as trans: + try: + for column_name in missing_columns: + column = table.c[column_name] - logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") - with engine.connect() as connection: - trans = connection.begin() - try: - for column_name in missing_columns: - column = table.c[column_name] + # 构造并执行 ALTER TABLE 语句 + try: + # 在同步上下文中编译列类型 + column_type = await connection.run_sync( + lambda sync_conn: column.type.compile(sync_conn.dialect) + ) + sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" - # 构造并执行 ALTER TABLE 语句 - try: - column_type = column.type.compile(engine.dialect) - sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" + # 添加默认值和非空约束的处理 + if column.default is not None: + default_value = column.default.arg + if isinstance(default_value, str): + sql += f" DEFAULT '{default_value}'" + else: + sql += f" DEFAULT {default_value}" - # 添加默认值和非空约束的处理 - if column.default is not None: - default_value = column.default.arg - if isinstance(default_value, str): - sql += f" DEFAULT '{default_value}'" - else: - sql += f" DEFAULT {default_value}" + if not column.nullable: + sql += " NOT NULL" - if not column.nullable: - sql += " NOT NULL" + await connection.execute(text(sql)) + logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") + except Exception as e: + logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}") - connection.execute(text(sql)) - logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") - except Exception as e: - logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}") - - trans.commit() - except Exception as e: - logger.error(f"在表 '{table_name}' 添加列时发生错误,事务已回滚: {e}") - trans.rollback() + except Exception as e: + logger.error(f"在表 '{table_name}' 添加列时发生错误,事务已回滚: {e}") + await trans.rollback() + raise logger.info("数据库结构检查与自动迁移完成。") diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 1643f5838..4f0258c2b 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -6,9 +6,11 @@ import traceback import time +import asyncio from typing import Dict, List, Any, Union, Type, Optional from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy import desc, asc, func, and_ +from sqlalchemy import desc, asc, func, and_, select +from sqlalchemy.ext.asyncio import AsyncSession from src.common.logger import get_logger from src.common.database.sqlalchemy_models import ( Base, @@ -56,7 +58,7 @@ MODEL_MAPPING = { } -def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]): +async def build_filters(model_class, filters: Dict[str, Any]): """构建查询过滤条件""" conditions = [] @@ -94,7 +96,7 @@ def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]): async def db_query( - model_class: Type[Base], + model_class, data: Optional[Dict[str, Any]] = None, query_type: Optional[str] = "get", filters: Optional[Dict[str, Any]] = None, @@ -102,7 +104,7 @@ async def db_query( order_by: Optional[List[str]] = None, single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """执行数据库查询操作 + """执行异步数据库查询操作 Args: model_class: SQLAlchemy模型类 @@ -120,15 +122,15 @@ async def db_query( if query_type not in ["get", "create", "update", "delete", "count"]: raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'") - with get_db_session() as session: + async with get_db_session() as session: if query_type == "get": - query = session.query(model_class) + query = select(model_class) # 应用过滤条件 if filters: - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - query = query.filter(and_(*conditions)) + query = query.where(and_(*conditions)) # 应用排序 if order_by: @@ -146,14 +148,15 @@ async def db_query( query = query.limit(limit) # 执行查询 - results = query.all() + result = await session.execute(query) + results = result.scalars().all() # 转换为字典格式 result_dicts = [] - for result in results: + for result_obj in results: result_dict = {} - for column in result.__table__.columns: - result_dict[column.name] = getattr(result, column.name) + for column in result_obj.__table__.columns: + result_dict[column.name] = getattr(result_obj, column.name) result_dicts.append(result_dict) if single_result: @@ -167,7 +170,7 @@ async def db_query( # 创建新记录 new_record = model_class(**data) session.add(new_record) - session.flush() # 获取自动生成的ID + await session.flush() # 获取自动生成的ID # 转换为字典格式返回 result_dict = {} @@ -179,43 +182,60 @@ async def db_query( if not data: raise ValueError("更新记录需要提供data参数") - query = session.query(model_class) + query = select(model_class) # 应用过滤条件 if filters: - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - query = query.filter(and_(*conditions)) + query = query.where(and_(*conditions)) - # 执行更新 - affected_rows = query.update(data) + # 首先获取要更新的记录 + result = await session.execute(query) + records_to_update = result.scalars().all() + + # 更新每个记录 + affected_rows = 0 + for record in records_to_update: + for field, value in data.items(): + if hasattr(record, field): + setattr(record, field, value) + affected_rows += 1 + return affected_rows elif query_type == "delete": - query = session.query(model_class) + query = select(model_class) # 应用过滤条件 if filters: - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - query = query.filter(and_(*conditions)) + query = query.where(and_(*conditions)) - # 执行删除 - affected_rows = query.delete() + # 首先获取要删除的记录 + result = await session.execute(query) + records_to_delete = result.scalars().all() + + # 删除记录 + affected_rows = 0 + for record in records_to_delete: + session.delete(record) + affected_rows += 1 + return affected_rows elif query_type == "count": - query = session.query(func.count(model_class.id)) + query = select(func.count(model_class.id)) # 应用过滤条件 if filters: - base_query = session.query(model_class) - conditions = build_filters(session, model_class, filters) + conditions = await build_filters(model_class, filters) if conditions: - base_query = base_query.filter(and_(*conditions)) - query = session.query(func.count()).select_from(base_query.subquery()) + query = query.where(and_(*conditions)) - return query.scalar() + result = await session.execute(query) + return result.scalar() except SQLAlchemyError as e: logger.error(f"[SQLAlchemy] 数据库操作出错: {e}") @@ -238,9 +258,9 @@ async def db_query( async def db_save( - model_class: Type[Base], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None + model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None ) -> Optional[Dict[str, Any]]: - """保存数据到数据库(创建或更新) + """异步保存数据到数据库(创建或更新) Args: model_class: SQLAlchemy模型类 @@ -252,13 +272,13 @@ async def db_save( 保存后的记录数据或None """ try: - with get_db_session() as session: + async with get_db_session() as session: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: if hasattr(model_class, key_field): - existing_record = ( - session.query(model_class).filter(getattr(model_class, key_field) == key_value).first() - ) + query = select(model_class).where(getattr(model_class, key_field) == key_value) + result = await session.execute(query) + existing_record = result.scalars().first() if existing_record: # 更新现有记录 @@ -266,7 +286,7 @@ async def db_save( if hasattr(existing_record, field): setattr(existing_record, field, value) - session.flush() + await session.flush() # 转换为字典格式返回 result_dict = {} @@ -277,8 +297,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) session.add(new_record) - session.commit() - session.flush() + await session.flush() # 转换为字典格式返回 result_dict = {} @@ -297,13 +316,13 @@ async def db_save( async def db_get( - model_class: Type[Base], + model_class, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None, order_by: Optional[str] = None, single_result: Optional[bool] = False, ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """从数据库获取记录 + """异步从数据库获取记录 Args: model_class: SQLAlchemy模型类 @@ -335,7 +354,7 @@ async def store_action_info( action_data: Optional[dict] = None, action_name: str = "", ) -> Optional[Dict[str, Any]]: - """存储动作信息到数据库 + """异步存储动作信息到数据库 Args: chat_stream: 聊天流对象 diff --git a/src/common/database/sqlalchemy_init.py b/src/common/database/sqlalchemy_init.py index fa7a864eb..7d3f97136 100644 --- a/src/common/database/sqlalchemy_init.py +++ b/src/common/database/sqlalchemy_init.py @@ -1,7 +1,7 @@ """SQLAlchemy数据库初始化模块 替换Peewee的数据库初始化逻辑 -提供统一的数据库初始化接口 +提供统一的异步数据库初始化接口 """ from typing import Optional @@ -12,25 +12,25 @@ from src.common.database.sqlalchemy_models import Base, get_engine, initialize_d logger = get_logger("sqlalchemy_init") -def initialize_sqlalchemy_database() -> bool: +async def initialize_sqlalchemy_database() -> bool: """ - 初始化SQLAlchemy数据库 + 初始化SQLAlchemy异步数据库 创建所有表结构 Returns: bool: 初始化是否成功 """ try: - logger.info("开始初始化SQLAlchemy数据库...") + logger.info("开始初始化SQLAlchemy异步数据库...") # 初始化数据库引擎和会话 - engine, session_local = initialize_database() + engine, session_local = await initialize_database() if engine is None: logger.error("数据库引擎初始化失败") return False - logger.info("SQLAlchemy数据库初始化成功") + logger.info("SQLAlchemy异步数据库初始化成功") return True except SQLAlchemyError as e: @@ -41,9 +41,9 @@ def initialize_sqlalchemy_database() -> bool: return False -def create_all_tables() -> bool: +async def create_all_tables() -> bool: """ - 创建所有数据库表 + 异步创建所有数据库表 Returns: bool: 创建是否成功 @@ -51,13 +51,14 @@ def create_all_tables() -> bool: try: logger.info("开始创建数据库表...") - engine = get_engine() + engine = await get_engine() if engine is None: logger.error("无法获取数据库引擎") return False - # 创建所有表 - Base.metadata.create_all(bind=engine) + # 异步创建所有表 + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) logger.info("数据库表创建成功") return True @@ -70,15 +71,15 @@ def create_all_tables() -> bool: return False -def get_database_info() -> Optional[dict]: +async def get_database_info() -> Optional[dict]: """ - 获取数据库信息 + 异步获取数据库信息 Returns: dict: 数据库信息字典,包含引擎信息等 """ try: - engine = get_engine() + engine = await get_engine() if engine is None: return None @@ -100,9 +101,9 @@ def get_database_info() -> Optional[dict]: _database_initialized = False -def initialize_database_compat() -> bool: +async def initialize_database_compat() -> bool: """ - 兼容性数据库初始化函数 + 兼容性异步数据库初始化函数 用于替换原有的Peewee初始化代码 Returns: @@ -113,9 +114,9 @@ def initialize_database_compat() -> bool: if _database_initialized: return True - success = initialize_sqlalchemy_database() + success = await initialize_sqlalchemy_database() if success: - success = create_all_tables() + success = await create_all_tables() if success: _database_initialized = True diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 464b38e9f..f9b0ef68d 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -3,16 +3,18 @@ 替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 """ -from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime +from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.pool import QueuePool import os import datetime import time -from typing import Iterator, Optional, Any, Dict +from typing import Iterator, Optional, Any, Dict, AsyncGenerator from src.common.logger import get_logger -from contextlib import contextmanager +from contextlib import asynccontextmanager +import asyncio logger = get_logger("sqlalchemy_models") @@ -575,14 +577,14 @@ def get_database_url(): # 使用Unix socket连接 encoded_socket = quote_plus(config.mysql_unix_socket) return ( - f"mysql+pymysql://{encoded_user}:{encoded_password}" + f"mysql+aiomysql://{encoded_user}:{encoded_password}" f"@/{config.mysql_database}" f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" ) else: # 使用标准TCP连接 return ( - f"mysql+pymysql://{encoded_user}:{encoded_password}" + f"mysql+aiomysql://{encoded_user}:{encoded_password}" f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" f"?charset={config.mysql_charset}" ) @@ -597,11 +599,11 @@ def get_database_url(): # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) - return f"sqlite:///{db_path}" + return f"sqlite+aiosqlite:///{db_path}" -def initialize_database(): - """初始化数据库引擎和会话""" +async def initialize_database(): + """初始化异步数据库引擎和会话""" global _engine, _SessionLocal if _engine is not None: @@ -654,41 +656,40 @@ def initialize_database(): } ) - _engine = create_engine(database_url, **engine_kwargs) - _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine) + _engine = create_async_engine(database_url, **engine_kwargs) + _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) # 调用新的迁移函数,它会处理表的创建和列的添加 from src.common.database.db_migration import check_and_migrate_database - check_and_migrate_database() + await check_and_migrate_database() - logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}") + logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") return _engine, _SessionLocal -@contextmanager -def get_db_session() -> Iterator[Session]: - """数据库会话上下文管理器 - 推荐使用这个而不是get_session()""" - session: Optional[Session] = None +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + """异步数据库会话上下文管理器""" + session: Optional[AsyncSession] = None try: - engine, SessionLocal = initialize_database() + engine, SessionLocal = await initialize_database() if not SessionLocal: raise RuntimeError("Database session not initialized") session = SessionLocal() yield session - # session.commit() except Exception: if session: - session.rollback() + await session.rollback() raise finally: if session: - session.close() + await session.close() -def get_engine(): - """获取数据库引擎""" - engine, _ = initialize_database() +async def get_engine(): + """获取异步数据库引擎""" + engine, _ = await initialize_database() return engine