数据库异步支持

仅仅支持还有107处待迁移
This commit is contained in:
雅诺狐
2025-09-19 20:20:20 +08:00
parent 5a0a63464a
commit 4dbc651d74
6 changed files with 169 additions and 136 deletions

View File

@@ -1,4 +1,6 @@
sqlalchemy sqlalchemy
aiosqlite
aiomysql
APScheduler APScheduler
aiohttp aiohttp
aiohttp-cors aiohttp-cors

0
rust_image/Cargo.toml Normal file
View File

View File

@@ -7,71 +7,81 @@ from src.common.logger import get_logger
logger = get_logger("db_migration") logger = get_logger("db_migration")
def check_and_migrate_database(): async def check_and_migrate_database():
""" """
检查数据库结构并自动迁移(添加缺失的表和列)。 异步检查数据库结构并自动迁移(添加缺失的表和列)。
""" """
logger.info("正在检查数据库结构并执行自动迁移...") logger.info("正在检查数据库结构并执行自动迁移...")
engine = get_engine() engine = await get_engine()
inspector = inspect(engine)
# 使用异步引擎获取inspector
async with engine.connect() as connection:
# 在同步上下文中运行inspector操作
inspector = await connection.run_sync(lambda sync_conn: inspect(sync_conn))
# 1. 获取数据库中所有已存在的表名
db_table_names = await connection.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))
# 1. 获取数据库中所有已存在的表名 # 2. 遍历所有在代码中定义的模型
db_table_names = set(inspector.get_table_names()) for table_name, table in Base.metadata.tables.items():
logger.debug(f"正在检查表: {table_name}")
# 2. 遍历所有在代码中定义的模型 # 3. 如果表不存在,则创建它
for table_name, table in Base.metadata.tables.items(): if table_name not in db_table_names:
logger.debug(f"正在检查表: {table_name}") logger.info(f"'{table_name}' 不存在,正在创建...")
try:
await connection.run_sync(lambda sync_conn: table.create(sync_conn))
logger.info(f"'{table_name}' 创建成功。")
except Exception as e:
logger.error(f"创建表 '{table_name}' 失败: {e}")
continue
# 3. 如果表存在,则创建它 # 4. 如果表存在,则检查并添加缺失的列
if table_name not in db_table_names: db_columns = await connection.run_sync(
logger.info(f"'{table_name}' 不存在,正在创建...") lambda sync_conn: {col["name"] for col in inspect(sync_conn).get_columns(table_name)}
try: )
table.create(engine) model_columns = {col.name for col in table.c}
logger.info(f"'{table_name}' 创建成功。")
except Exception as e:
logger.error(f"创建表 '{table_name}' 失败: {e}")
continue
# 4. 如果表已存在,则检查并添加缺失的列 missing_columns = model_columns - db_columns
db_columns = {col["name"] for col in inspector.get_columns(table_name)} if not missing_columns:
model_columns = {col.name for col in table.c} logger.debug(f"'{table_name}' 结构一致,无需修改。")
continue
missing_columns = model_columns - db_columns logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}")
if not missing_columns:
logger.debug(f"'{table_name}' 结构一致,无需修改。") # 开始事务来添加缺失的列
continue async with connection.begin() as trans:
try:
for column_name in missing_columns:
column = table.c[column_name]
logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") # 构造并执行 ALTER TABLE 语句
with engine.connect() as connection: try:
trans = connection.begin() # 在同步上下文中编译列类型
try: column_type = await connection.run_sync(
for column_name in missing_columns: lambda sync_conn: column.type.compile(sync_conn.dialect)
column = table.c[column_name] )
sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
# 构造并执行 ALTER TABLE 语句 # 添加默认值和非空约束的处理
try: if column.default is not None:
column_type = column.type.compile(engine.dialect) default_value = column.default.arg
sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}" if isinstance(default_value, str):
sql += f" DEFAULT '{default_value}'"
else:
sql += f" DEFAULT {default_value}"
# 添加默认值和非空约束的处理 if not column.nullable:
if column.default is not None: sql += " NOT NULL"
default_value = column.default.arg
if isinstance(default_value, str):
sql += f" DEFAULT '{default_value}'"
else:
sql += f" DEFAULT {default_value}"
if not column.nullable: await connection.execute(text(sql))
sql += " NOT NULL" logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'")
except Exception as e:
logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}")
connection.execute(text(sql)) except Exception as e:
logger.info(f"成功向'{table_name}' 添加列 '{column_name}'") logger.error(f"'{table_name}' 添加列时发生错误,事务已回滚: {e}")
except Exception as e: await trans.rollback()
logger.error(f"向表 '{table_name}' 添加列 '{column_name}' 失败: {e}") raise
trans.commit()
except Exception as e:
logger.error(f"在表 '{table_name}' 添加列时发生错误,事务已回滚: {e}")
trans.rollback()
logger.info("数据库结构检查与自动迁移完成。") logger.info("数据库结构检查与自动迁移完成。")

View File

@@ -6,9 +6,11 @@
import traceback import traceback
import time import time
import asyncio
from typing import Dict, List, Any, Union, Type, Optional from typing import Dict, List, Any, Union, Type, Optional
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import desc, asc, func, and_ from sqlalchemy import desc, asc, func, and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import ( from src.common.database.sqlalchemy_models import (
Base, Base,
@@ -56,7 +58,7 @@ MODEL_MAPPING = {
} }
def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]): async def build_filters(model_class, filters: Dict[str, Any]):
"""构建查询过滤条件""" """构建查询过滤条件"""
conditions = [] conditions = []
@@ -94,7 +96,7 @@ def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]):
async def db_query( async def db_query(
model_class: Type[Base], model_class,
data: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get", query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None, filters: Optional[Dict[str, Any]] = None,
@@ -102,7 +104,7 @@ async def db_query(
order_by: Optional[List[str]] = None, order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False, single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作 """执行异步数据库查询操作
Args: Args:
model_class: SQLAlchemy模型类 model_class: SQLAlchemy模型类
@@ -120,15 +122,15 @@ async def db_query(
if query_type not in ["get", "create", "update", "delete", "count"]: if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'") raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
with get_db_session() as session: async with get_db_session() as session:
if query_type == "get": if query_type == "get":
query = session.query(model_class) query = select(model_class)
# 应用过滤条件 # 应用过滤条件
if filters: if filters:
conditions = build_filters(session, model_class, filters) conditions = await build_filters(model_class, filters)
if conditions: if conditions:
query = query.filter(and_(*conditions)) query = query.where(and_(*conditions))
# 应用排序 # 应用排序
if order_by: if order_by:
@@ -146,14 +148,15 @@ async def db_query(
query = query.limit(limit) query = query.limit(limit)
# 执行查询 # 执行查询
results = query.all() result = await session.execute(query)
results = result.scalars().all()
# 转换为字典格式 # 转换为字典格式
result_dicts = [] result_dicts = []
for result in results: for result_obj in results:
result_dict = {} result_dict = {}
for column in result.__table__.columns: for column in result_obj.__table__.columns:
result_dict[column.name] = getattr(result, column.name) result_dict[column.name] = getattr(result_obj, column.name)
result_dicts.append(result_dict) result_dicts.append(result_dict)
if single_result: if single_result:
@@ -167,7 +170,7 @@ async def db_query(
# 创建新记录 # 创建新记录
new_record = model_class(**data) new_record = model_class(**data)
session.add(new_record) session.add(new_record)
session.flush() # 获取自动生成的ID await session.flush() # 获取自动生成的ID
# 转换为字典格式返回 # 转换为字典格式返回
result_dict = {} result_dict = {}
@@ -179,43 +182,60 @@ async def db_query(
if not data: if not data:
raise ValueError("更新记录需要提供data参数") raise ValueError("更新记录需要提供data参数")
query = session.query(model_class) query = select(model_class)
# 应用过滤条件 # 应用过滤条件
if filters: if filters:
conditions = build_filters(session, model_class, filters) conditions = await build_filters(model_class, filters)
if conditions: if conditions:
query = query.filter(and_(*conditions)) query = query.where(and_(*conditions))
# 执行更新 # 首先获取要更新的记录
affected_rows = query.update(data) result = await session.execute(query)
records_to_update = result.scalars().all()
# 更新每个记录
affected_rows = 0
for record in records_to_update:
for field, value in data.items():
if hasattr(record, field):
setattr(record, field, value)
affected_rows += 1
return affected_rows return affected_rows
elif query_type == "delete": elif query_type == "delete":
query = session.query(model_class) query = select(model_class)
# 应用过滤条件 # 应用过滤条件
if filters: if filters:
conditions = build_filters(session, model_class, filters) conditions = await build_filters(model_class, filters)
if conditions: if conditions:
query = query.filter(and_(*conditions)) query = query.where(and_(*conditions))
# 执行删除 # 首先获取要删除的记录
affected_rows = query.delete() result = await session.execute(query)
records_to_delete = result.scalars().all()
# 删除记录
affected_rows = 0
for record in records_to_delete:
session.delete(record)
affected_rows += 1
return affected_rows return affected_rows
elif query_type == "count": elif query_type == "count":
query = session.query(func.count(model_class.id)) query = select(func.count(model_class.id))
# 应用过滤条件 # 应用过滤条件
if filters: if filters:
base_query = session.query(model_class) conditions = await build_filters(model_class, filters)
conditions = build_filters(session, model_class, filters)
if conditions: if conditions:
base_query = base_query.filter(and_(*conditions)) query = query.where(and_(*conditions))
query = session.query(func.count()).select_from(base_query.subquery())
return query.scalar() result = await session.execute(query)
return result.scalar()
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}") logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
@@ -238,9 +258,9 @@ async def db_query(
async def db_save( async def db_save(
model_class: Type[Base], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""保存数据到数据库(创建或更新) """异步保存数据到数据库(创建或更新)
Args: Args:
model_class: SQLAlchemy模型类 model_class: SQLAlchemy模型类
@@ -252,13 +272,13 @@ async def db_save(
保存后的记录数据或None 保存后的记录数据或None
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
# 如果提供了key_field和key_value尝试更新现有记录 # 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None: if key_field and key_value is not None:
if hasattr(model_class, key_field): if hasattr(model_class, key_field):
existing_record = ( query = select(model_class).where(getattr(model_class, key_field) == key_value)
session.query(model_class).filter(getattr(model_class, key_field) == key_value).first() result = await session.execute(query)
) existing_record = result.scalars().first()
if existing_record: if existing_record:
# 更新现有记录 # 更新现有记录
@@ -266,7 +286,7 @@ async def db_save(
if hasattr(existing_record, field): if hasattr(existing_record, field):
setattr(existing_record, field, value) setattr(existing_record, field, value)
session.flush() await session.flush()
# 转换为字典格式返回 # 转换为字典格式返回
result_dict = {} result_dict = {}
@@ -277,8 +297,7 @@ async def db_save(
# 创建新记录 # 创建新记录
new_record = model_class(**data) new_record = model_class(**data)
session.add(new_record) session.add(new_record)
session.commit() await session.flush()
session.flush()
# 转换为字典格式返回 # 转换为字典格式返回
result_dict = {} result_dict = {}
@@ -297,13 +316,13 @@ async def db_save(
async def db_get( async def db_get(
model_class: Type[Base], model_class,
filters: Optional[Dict[str, Any]] = None, filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
order_by: Optional[str] = None, order_by: Optional[str] = None,
single_result: Optional[bool] = False, single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录 """异步从数据库获取记录
Args: Args:
model_class: SQLAlchemy模型类 model_class: SQLAlchemy模型类
@@ -335,7 +354,7 @@ async def store_action_info(
action_data: Optional[dict] = None, action_data: Optional[dict] = None,
action_name: str = "", action_name: str = "",
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库 """异步存储动作信息到数据库
Args: Args:
chat_stream: 聊天流对象 chat_stream: 聊天流对象

View File

@@ -1,7 +1,7 @@
"""SQLAlchemy数据库初始化模块 """SQLAlchemy数据库初始化模块
替换Peewee的数据库初始化逻辑 替换Peewee的数据库初始化逻辑
提供统一的数据库初始化接口 提供统一的异步数据库初始化接口
""" """
from typing import Optional from typing import Optional
@@ -12,25 +12,25 @@ from src.common.database.sqlalchemy_models import Base, get_engine, initialize_d
logger = get_logger("sqlalchemy_init") logger = get_logger("sqlalchemy_init")
def initialize_sqlalchemy_database() -> bool: async def initialize_sqlalchemy_database() -> bool:
""" """
初始化SQLAlchemy数据库 初始化SQLAlchemy异步数据库
创建所有表结构 创建所有表结构
Returns: Returns:
bool: 初始化是否成功 bool: 初始化是否成功
""" """
try: try:
logger.info("开始初始化SQLAlchemy数据库...") logger.info("开始初始化SQLAlchemy异步数据库...")
# 初始化数据库引擎和会话 # 初始化数据库引擎和会话
engine, session_local = initialize_database() engine, session_local = await initialize_database()
if engine is None: if engine is None:
logger.error("数据库引擎初始化失败") logger.error("数据库引擎初始化失败")
return False return False
logger.info("SQLAlchemy数据库初始化成功") logger.info("SQLAlchemy异步数据库初始化成功")
return True return True
except SQLAlchemyError as e: except SQLAlchemyError as e:
@@ -41,9 +41,9 @@ def initialize_sqlalchemy_database() -> bool:
return False return False
def create_all_tables() -> bool: async def create_all_tables() -> bool:
""" """
创建所有数据库表 异步创建所有数据库表
Returns: Returns:
bool: 创建是否成功 bool: 创建是否成功
@@ -51,13 +51,14 @@ def create_all_tables() -> bool:
try: try:
logger.info("开始创建数据库表...") logger.info("开始创建数据库表...")
engine = get_engine() engine = await get_engine()
if engine is None: if engine is None:
logger.error("无法获取数据库引擎") logger.error("无法获取数据库引擎")
return False return False
# 创建所有表 # 异步创建所有表
Base.metadata.create_all(bind=engine) async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("数据库表创建成功") logger.info("数据库表创建成功")
return True return True
@@ -70,15 +71,15 @@ def create_all_tables() -> bool:
return False return False
def get_database_info() -> Optional[dict]: async def get_database_info() -> Optional[dict]:
""" """
获取数据库信息 异步获取数据库信息
Returns: Returns:
dict: 数据库信息字典,包含引擎信息等 dict: 数据库信息字典,包含引擎信息等
""" """
try: try:
engine = get_engine() engine = await get_engine()
if engine is None: if engine is None:
return None return None
@@ -100,9 +101,9 @@ def get_database_info() -> Optional[dict]:
_database_initialized = False _database_initialized = False
def initialize_database_compat() -> bool: async def initialize_database_compat() -> bool:
""" """
兼容性数据库初始化函数 兼容性异步数据库初始化函数
用于替换原有的Peewee初始化代码 用于替换原有的Peewee初始化代码
Returns: Returns:
@@ -113,9 +114,9 @@ def initialize_database_compat() -> bool:
if _database_initialized: if _database_initialized:
return True return True
success = initialize_sqlalchemy_database() success = await initialize_sqlalchemy_database()
if success: if success:
success = create_all_tables() success = await create_all_tables()
if success: if success:
_database_initialized = True _database_initialized = True

View File

@@ -3,16 +3,18 @@
替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力 替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力
""" """
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
import os import os
import datetime import datetime
import time import time
from typing import Iterator, Optional, Any, Dict from typing import Iterator, Optional, Any, Dict, AsyncGenerator
from src.common.logger import get_logger from src.common.logger import get_logger
from contextlib import contextmanager from contextlib import asynccontextmanager
import asyncio
logger = get_logger("sqlalchemy_models") logger = get_logger("sqlalchemy_models")
@@ -575,14 +577,14 @@ def get_database_url():
# 使用Unix socket连接 # 使用Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket) encoded_socket = quote_plus(config.mysql_unix_socket)
return ( return (
f"mysql+pymysql://{encoded_user}:{encoded_password}" f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@/{config.mysql_database}" f"@/{config.mysql_database}"
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
) )
else: else:
# 使用标准TCP连接 # 使用标准TCP连接
return ( return (
f"mysql+pymysql://{encoded_user}:{encoded_password}" f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}" f"?charset={config.mysql_charset}"
) )
@@ -597,11 +599,11 @@ def get_database_url():
# 确保数据库目录存在 # 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True) os.makedirs(os.path.dirname(db_path), exist_ok=True)
return f"sqlite:///{db_path}" return f"sqlite+aiosqlite:///{db_path}"
def initialize_database(): async def initialize_database():
"""初始化数据库引擎和会话""" """初始化异步数据库引擎和会话"""
global _engine, _SessionLocal global _engine, _SessionLocal
if _engine is not None: if _engine is not None:
@@ -654,41 +656,40 @@ def initialize_database():
} }
) )
_engine = create_engine(database_url, **engine_kwargs) _engine = create_async_engine(database_url, **engine_kwargs)
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine) _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 调用新的迁移函数,它会处理表的创建和列的添加 # 调用新的迁移函数,它会处理表的创建和列的添加
from src.common.database.db_migration import check_and_migrate_database from src.common.database.db_migration import check_and_migrate_database
check_and_migrate_database() await check_and_migrate_database()
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}") logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal return _engine, _SessionLocal
@contextmanager @asynccontextmanager
def get_db_session() -> Iterator[Session]: async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()""" """异步数据库会话上下文管理器"""
session: Optional[Session] = None session: Optional[AsyncSession] = None
try: try:
engine, SessionLocal = initialize_database() engine, SessionLocal = await initialize_database()
if not SessionLocal: if not SessionLocal:
raise RuntimeError("Database session not initialized") raise RuntimeError("Database session not initialized")
session = SessionLocal() session = SessionLocal()
yield session yield session
# session.commit()
except Exception: except Exception:
if session: if session:
session.rollback() await session.rollback()
raise raise
finally: finally:
if session: if session:
session.close() await session.close()
def get_engine(): async def get_engine():
"""获取数据库引擎""" """获取异步数据库引擎"""
engine, _ = initialize_database() engine, _ = await initialize_database()
return engine return engine