diff --git a/src/common/database/connection_pool_manager.py b/src/common/database/connection_pool_manager.py index e295e2284..e1df6fd2e 100644 --- a/src/common/database/connection_pool_manager.py +++ b/src/common/database/connection_pool_manager.py @@ -8,6 +8,7 @@ import time from contextlib import asynccontextmanager from typing import Any +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from src.common.logger import get_logger @@ -53,10 +54,16 @@ class ConnectionInfo: async def close(self): """关闭连接""" try: - await self.session.close() + # 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭 + # 通过 `cast` 明确告知类型检查器 `shield` 的返回类型,避免类型错误 + from typing import cast + await cast(asyncio.Future, asyncio.shield(self.session.close())) logger.debug("连接已关闭") except asyncio.CancelledError: - logger.warning("关闭连接时任务被取消") + # 这是一个预期的行为,例如在流式聊天中断时 + logger.debug("关闭连接时任务被取消") + # 重新抛出异常以确保任务状态正确 + raise except Exception as e: logger.warning(f"关闭连接时出错: {e}") @@ -172,7 +179,7 @@ class ConnectionPoolManager: # 验证连接是否仍然有效 try: # 执行一个简单的查询来验证连接 - await connection_info.session.execute("SELECT 1") + await connection_info.session.execute(text("SELECT 1")) return connection_info except Exception as e: logger.debug(f"连接验证失败,将移除: {e}")