- 将bot_config_template.toml中的版本升级至7.9.0 - 增强数据库配置选项以支持PostgreSQL - 引入一个新脚本,用于在SQLite、MySQL和PostgreSQL之间迁移数据 - 实现一个方言适配器,用于处理特定于数据库的行为和配置
282 lines
9.6 KiB
Python
282 lines
9.6 KiB
Python
"""
|
||
透明连接复用管理器
|
||
在不改变原有API的情况下,实现数据库连接的智能复用
|
||
"""
|
||
|
||
import asyncio
|
||
import time
|
||
from contextlib import asynccontextmanager
|
||
from typing import Any
|
||
|
||
from sqlalchemy import text
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("connection_pool_manager")
|
||
|
||
|
||
class ConnectionInfo:
|
||
"""连接信息包装器"""
|
||
|
||
def __init__(self, session: AsyncSession, created_at: float):
|
||
self.session = session
|
||
self.created_at = created_at
|
||
self.last_used = created_at
|
||
self.in_use = False
|
||
self.ref_count = 0
|
||
|
||
def mark_used(self):
|
||
"""标记连接被使用"""
|
||
self.last_used = time.time()
|
||
self.in_use = True
|
||
self.ref_count += 1
|
||
|
||
def mark_released(self):
|
||
"""标记连接被释放"""
|
||
self.in_use = False
|
||
self.ref_count = max(0, self.ref_count - 1)
|
||
|
||
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
|
||
"""检查连接是否过期"""
|
||
current_time = time.time()
|
||
|
||
# 检查总生命周期
|
||
if current_time - self.created_at > max_lifetime:
|
||
return True
|
||
|
||
# 检查空闲时间
|
||
if not self.in_use and current_time - self.last_used > max_idle:
|
||
return True
|
||
|
||
return False
|
||
|
||
async def close(self):
|
||
"""关闭连接"""
|
||
try:
|
||
# 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭
|
||
# 通过 `cast` 明确告知类型检查器 `shield` 的返回类型,避免类型错误
|
||
from typing import cast
|
||
await cast(asyncio.Future, asyncio.shield(self.session.close()))
|
||
logger.debug("连接已关闭")
|
||
except asyncio.CancelledError:
|
||
# 这是一个预期的行为,例如在流式聊天中断时
|
||
logger.debug("关闭连接时任务被取消")
|
||
# 重新抛出异常以确保任务状态正确
|
||
raise
|
||
except Exception as e:
|
||
logger.warning(f"关闭连接时出错: {e}")
|
||
|
||
|
||
class ConnectionPoolManager:
|
||
"""透明的连接池管理器"""
|
||
|
||
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
|
||
self.max_pool_size = max_pool_size
|
||
self.max_lifetime = max_lifetime
|
||
self.max_idle = max_idle
|
||
|
||
# 连接池
|
||
self._connections: set[ConnectionInfo] = set()
|
||
self._lock = asyncio.Lock()
|
||
|
||
# 统计信息
|
||
self._stats = {
|
||
"total_created": 0,
|
||
"total_reused": 0,
|
||
"total_expired": 0,
|
||
"active_connections": 0,
|
||
"pool_hits": 0,
|
||
"pool_misses": 0,
|
||
}
|
||
|
||
# 后台清理任务
|
||
self._cleanup_task: asyncio.Task | None = None
|
||
self._should_cleanup = False
|
||
|
||
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
||
|
||
async def start(self):
|
||
"""启动连接池管理器"""
|
||
if self._cleanup_task is None:
|
||
self._should_cleanup = True
|
||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||
logger.info("连接池管理器已启动")
|
||
|
||
async def stop(self):
|
||
"""停止连接池管理器"""
|
||
self._should_cleanup = False
|
||
|
||
if self._cleanup_task:
|
||
self._cleanup_task.cancel()
|
||
try:
|
||
await self._cleanup_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
self._cleanup_task = None
|
||
|
||
# 关闭所有连接
|
||
await self._close_all_connections()
|
||
logger.info("连接池管理器已停止")
|
||
|
||
@asynccontextmanager
|
||
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
|
||
"""
|
||
获取数据库会话的透明包装器
|
||
如果有可用连接则复用,否则创建新连接
|
||
"""
|
||
connection_info = None
|
||
|
||
try:
|
||
# 尝试获取现有连接
|
||
connection_info = await self._get_reusable_connection(session_factory)
|
||
|
||
if connection_info:
|
||
# 复用现有连接
|
||
connection_info.mark_used()
|
||
self._stats["total_reused"] += 1
|
||
self._stats["pool_hits"] += 1
|
||
logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
|
||
else:
|
||
# 创建新连接
|
||
session = session_factory()
|
||
connection_info = ConnectionInfo(session, time.time())
|
||
|
||
async with self._lock:
|
||
self._connections.add(connection_info)
|
||
|
||
connection_info.mark_used()
|
||
self._stats["total_created"] += 1
|
||
self._stats["pool_misses"] += 1
|
||
logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")
|
||
|
||
yield connection_info.session
|
||
|
||
except Exception:
|
||
# 发生错误时回滚连接
|
||
if connection_info and connection_info.session:
|
||
try:
|
||
await connection_info.session.rollback()
|
||
except Exception as rollback_error:
|
||
logger.warning(f"回滚连接时出错: {rollback_error}")
|
||
raise
|
||
finally:
|
||
# 释放连接回池中
|
||
if connection_info:
|
||
connection_info.mark_released()
|
||
|
||
async def _get_reusable_connection(
|
||
self, session_factory: async_sessionmaker[AsyncSession]
|
||
) -> ConnectionInfo | None:
|
||
"""获取可复用的连接"""
|
||
# 导入方言适配器获取 ping 查询
|
||
from src.common.database.core.dialect_adapter import DialectAdapter
|
||
|
||
ping_query = DialectAdapter.get_ping_query()
|
||
|
||
async with self._lock:
|
||
# 清理过期连接
|
||
await self._cleanup_expired_connections_locked()
|
||
|
||
# 查找可复用的连接
|
||
for connection_info in list(self._connections):
|
||
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
|
||
# 验证连接是否仍然有效
|
||
try:
|
||
# 执行 ping 查询来验证连接
|
||
await connection_info.session.execute(text(ping_query))
|
||
return connection_info
|
||
except Exception as e:
|
||
logger.debug(f"连接验证失败,将移除: {e}")
|
||
await connection_info.close()
|
||
self._connections.remove(connection_info)
|
||
self._stats["total_expired"] += 1
|
||
|
||
# 检查是否可以创建新连接
|
||
if len(self._connections) >= self.max_pool_size:
|
||
logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
|
||
return None
|
||
|
||
return None
|
||
|
||
async def _cleanup_expired_connections_locked(self):
|
||
"""清理过期连接(需要在锁内调用)"""
|
||
time.time()
|
||
expired_connections = [
|
||
connection_info for connection_info in list(self._connections)
|
||
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
|
||
]
|
||
|
||
for connection_info in expired_connections:
|
||
await connection_info.close()
|
||
self._connections.remove(connection_info)
|
||
self._stats["total_expired"] += 1
|
||
|
||
if expired_connections:
|
||
logger.debug(f"清理了 {len(expired_connections)} 个过期连接")
|
||
|
||
async def _cleanup_loop(self):
|
||
"""后台清理循环"""
|
||
while self._should_cleanup:
|
||
try:
|
||
await asyncio.sleep(30.0) # 每30秒清理一次
|
||
|
||
async with self._lock:
|
||
await self._cleanup_expired_connections_locked()
|
||
|
||
# 更新统计信息
|
||
self._stats["active_connections"] = len(self._connections)
|
||
|
||
except asyncio.CancelledError:
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"连接池清理循环出错: {e}")
|
||
await asyncio.sleep(10.0)
|
||
|
||
async def _close_all_connections(self):
|
||
"""关闭所有连接"""
|
||
async with self._lock:
|
||
for connection_info in list(self._connections):
|
||
await connection_info.close()
|
||
|
||
self._connections.clear()
|
||
logger.info("所有连接已关闭")
|
||
|
||
def get_stats(self) -> dict[str, Any]:
|
||
"""获取连接池统计信息"""
|
||
return {
|
||
**self._stats,
|
||
"active_connections": len(self._connections),
|
||
"max_pool_size": self.max_pool_size,
|
||
"pool_efficiency": (
|
||
self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
|
||
)
|
||
* 100,
|
||
}
|
||
|
||
|
||
# 全局连接池管理器实例
|
||
_connection_pool_manager: ConnectionPoolManager | None = None
|
||
|
||
|
||
def get_connection_pool_manager() -> ConnectionPoolManager:
|
||
"""获取全局连接池管理器实例"""
|
||
global _connection_pool_manager
|
||
if _connection_pool_manager is None:
|
||
_connection_pool_manager = ConnectionPoolManager()
|
||
return _connection_pool_manager
|
||
|
||
|
||
async def start_connection_pool():
|
||
"""启动连接池"""
|
||
manager = get_connection_pool_manager()
|
||
await manager.start()
|
||
|
||
|
||
async def stop_connection_pool():
|
||
"""停止连接池"""
|
||
global _connection_pool_manager
|
||
if _connection_pool_manager:
|
||
await _connection_pool_manager.stop()
|
||
_connection_pool_manager = None
|