fix: 更新数据库会话管理,确保事务在正常退出时自动提交,并在异常时安全回滚

This commit is contained in:
Windpicker-owo
2025-11-28 13:24:41 +08:00
parent ac017986fd
commit 2bd7e93af7
4 changed files with 54 additions and 21 deletions

View File

@@ -135,6 +135,7 @@ class MessageHandler:
handler=self._handle_adapter_response_route, handler=self._handle_adapter_response_route,
name="adapter_response_handler", name="adapter_response_handler",
message_type="adapter_response", message_type="adapter_response",
priority=100
) )
# 注册 notice 消息处理器(处理通知消息,如戳一戳、禁言等) # 注册 notice 消息处理器(处理通知消息,如戳一戳、禁言等)
@@ -153,6 +154,7 @@ class MessageHandler:
handler=self._handle_notice_message, handler=self._handle_notice_message,
name="notice_message_handler", name="notice_message_handler",
message_type="notice", message_type="notice",
priority=90
) )
# 注册默认消息处理器(处理所有其他消息) # 注册默认消息处理器(处理所有其他消息)
@@ -160,6 +162,7 @@ class MessageHandler:
predicate=lambda _: True, # 匹配所有消息 predicate=lambda _: True, # 匹配所有消息
handler=self._handle_normal_message, handler=self._handle_normal_message,
name="default_message_handler", name="default_message_handler",
priority=50
) )
logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子") logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子")

View File

@@ -126,6 +126,12 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
用于特殊场景,如需要完全独立的连接时。 用于特殊场景,如需要完全独立的连接时。
一般情况下应使用 get_db_session()。 一般情况下应使用 get_db_session()。
事务管理说明:
- 正常退出时自动提交事务
- 发生异常时自动回滚事务
- 如果用户代码已手动调用 commit/rollback再次调用是安全的
- 适用于所有数据库类型SQLite, MySQL, PostgreSQL
Yields: Yields:
AsyncSession: SQLAlchemy异步会话对象 AsyncSession: SQLAlchemy异步会话对象
""" """
@@ -139,8 +145,16 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
await _apply_session_settings(session, global_config.database.database_type) await _apply_session_settings(session, global_config.database.database_type)
yield session yield session
# 正常退出时提交事务
# 这对所有数据库都很重要,因为 SQLAlchemy 默认不是 autocommit 模式
# 检查事务是否活动,避免在已回滚的事务上提交
if session.is_active:
await session.commit()
except Exception: except Exception:
await session.rollback() # 检查是否需要回滚(事务是否活动)
if session.is_active:
await session.rollback()
raise raise
finally: finally:
await session.close() await session.close()

View File

@@ -17,7 +17,7 @@ from typing import Any, TypeVar
from sqlalchemy import delete, insert, select, update from sqlalchemy import delete, insert, select, update
from src.common.database.core.session import get_db_session from src.common.database.core.session import get_db_session_direct
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.memory_utils import estimate_size_smart from src.common.memory_utils import estimate_size_smart
@@ -330,7 +330,7 @@ class AdaptiveBatchScheduler:
operations: list[BatchOperation], operations: list[BatchOperation],
) -> None: ) -> None:
"""批量执行查询操作""" """批量执行查询操作"""
async with get_db_session() as session: async with get_db_session_direct() as session:
for op in operations: for op in operations:
try: try:
# 构建查询 # 构建查询
@@ -371,7 +371,7 @@ class AdaptiveBatchScheduler:
operations: list[BatchOperation], operations: list[BatchOperation],
) -> None: ) -> None:
"""批量执行插入操作""" """批量执行插入操作"""
async with get_db_session() as session: async with get_db_session_direct() as session:
try: try:
# 收集数据,并过滤掉 id=None 的情况(让数据库自动生成) # 收集数据,并过滤掉 id=None 的情况(让数据库自动生成)
all_data = [] all_data = []
@@ -387,7 +387,7 @@ class AdaptiveBatchScheduler:
# 批量插入 # 批量插入
stmt = insert(operations[0].model_class).values(all_data) stmt = insert(operations[0].model_class).values(all_data)
await session.execute(stmt) await session.execute(stmt)
await session.commit() # 注意commit 由 get_db_session_direct 上下文管理器自动处理
# 设置结果 # 设置结果
for op in operations: for op in operations:
@@ -402,20 +402,21 @@ class AdaptiveBatchScheduler:
except Exception as e: except Exception as e:
logger.error(f"批量插入失败: {e}") logger.error(f"批量插入失败: {e}")
await session.rollback() # 注意rollback 由 get_db_session_direct 上下文管理器自动处理
for op in operations: for op in operations:
if op.future and not op.future.done(): if op.future and not op.future.done():
op.future.set_exception(e) op.future.set_exception(e)
raise # 重新抛出异常以触发 rollback
async def _execute_update_batch( async def _execute_update_batch(
self, self,
operations: list[BatchOperation], operations: list[BatchOperation],
) -> None: ) -> None:
"""批量执行更新操作""" """批量执行更新操作"""
async with get_db_session() as session: async with get_db_session_direct() as session:
results = [] results = []
try: try:
# 🔧 修复:收集所有操作后一次性commit而不是循环中多次commit # 🔧 收集所有操作后一次性commit而不是循环中多次commit
for op in operations: for op in operations:
# 构建更新语句 # 构建更新语句
stmt = update(op.model_class) stmt = update(op.model_class)
@@ -430,8 +431,7 @@ class AdaptiveBatchScheduler:
result = await session.execute(stmt) result = await session.execute(stmt)
results.append((op, result.rowcount)) results.append((op, result.rowcount))
# 所有操作成功后一次性commit # 注意commit 由 get_db_session_direct 上下文管理器自动处理
await session.commit()
# 设置所有操作的结果 # 设置所有操作的结果
for op, rowcount in results: for op, rowcount in results:
@@ -446,21 +446,22 @@ class AdaptiveBatchScheduler:
except Exception as e: except Exception as e:
logger.error(f"批量更新失败: {e}") logger.error(f"批量更新失败: {e}")
await session.rollback() # 注意rollback 由 get_db_session_direct 上下文管理器自动处理
# 所有操作都失败 # 所有操作都失败
for op in operations: for op in operations:
if op.future and not op.future.done(): if op.future and not op.future.done():
op.future.set_exception(e) op.future.set_exception(e)
raise # 重新抛出异常以触发 rollback
async def _execute_delete_batch( async def _execute_delete_batch(
self, self,
operations: list[BatchOperation], operations: list[BatchOperation],
) -> None: ) -> None:
"""批量执行删除操作""" """批量执行删除操作"""
async with get_db_session() as session: async with get_db_session_direct() as session:
results = [] results = []
try: try:
# 🔧 修复:收集所有操作后一次性commit而不是循环中多次commit # 🔧 收集所有操作后一次性commit而不是循环中多次commit
for op in operations: for op in operations:
# 构建删除语句 # 构建删除语句
stmt = delete(op.model_class) stmt = delete(op.model_class)
@@ -472,8 +473,7 @@ class AdaptiveBatchScheduler:
result = await session.execute(stmt) result = await session.execute(stmt)
results.append((op, result.rowcount)) results.append((op, result.rowcount))
# 所有操作成功后一次性commit # 注意commit 由 get_db_session_direct 上下文管理器自动处理
await session.commit()
# 设置所有操作的结果 # 设置所有操作的结果
for op, rowcount in results: for op, rowcount in results:
@@ -488,11 +488,12 @@ class AdaptiveBatchScheduler:
except Exception as e: except Exception as e:
logger.error(f"批量删除失败: {e}") logger.error(f"批量删除失败: {e}")
await session.rollback() # 注意rollback 由 get_db_session_direct 上下文管理器自动处理
# 所有操作都失败 # 所有操作都失败
for op in operations: for op in operations:
if op.future and not op.future.done(): if op.future and not op.future.done():
op.future.set_exception(e) op.future.set_exception(e)
raise # 重新抛出异常以触发 rollback
async def _adjust_parameters(self) -> None: async def _adjust_parameters(self) -> None:
"""根据性能自适应调整参数""" """根据性能自适应调整参数"""

View File

@@ -123,6 +123,12 @@ class ConnectionPoolManager:
""" """
获取数据库会话的透明包装器 获取数据库会话的透明包装器
如果有可用连接则复用,否则创建新连接 如果有可用连接则复用,否则创建新连接
事务管理说明:
- 正常退出时自动提交事务
- 发生异常时自动回滚事务
- 如果用户代码已手动调用 commit/rollback再次调用是安全的空操作
- 支持所有数据库类型SQLite、MySQL、PostgreSQL
""" """
connection_info = None connection_info = None
@@ -151,21 +157,30 @@ class ConnectionPoolManager:
yield connection_info.session yield connection_info.session
# 🔧 修复:正常退出时提交事务 # 🔧 正常退出时提交事务
# 这对SQLite至关重要因为SQLite没有autocommit # 这对所有数据库SQLite、MySQL、PostgreSQL都很重要
# 因为 SQLAlchemy 默认使用事务模式,不会自动提交
# 注意:如果用户代码已调用 commit(),这里的 commit() 是安全的空操作
if connection_info and connection_info.session: if connection_info and connection_info.session:
try: try:
await connection_info.session.commit() # 检查事务是否处于活动状态,避免在已回滚的事务上提交
if connection_info.session.is_active:
await connection_info.session.commit()
except Exception as commit_error: except Exception as commit_error:
logger.warning(f"提交事务时出错: {commit_error}") logger.warning(f"提交事务时出错: {commit_error}")
await connection_info.session.rollback() try:
await connection_info.session.rollback()
except Exception:
pass # 忽略回滚错误,因为事务可能已经结束
raise raise
except Exception: except Exception:
# 发生错误时回滚连接 # 发生错误时回滚连接
if connection_info and connection_info.session: if connection_info and connection_info.session:
try: try:
await connection_info.session.rollback() # 检查是否需要回滚(事务是否活动)
if connection_info.session.is_active:
await connection_info.session.rollback()
except Exception as rollback_error: except Exception as rollback_error:
logger.warning(f"回滚连接时出错: {rollback_error}") logger.warning(f"回滚连接时出错: {rollback_error}")
raise raise