Files
Mofox-Core/src/common/database/connection_pool_manager.py
Windpicker-owo 43483b934e feat: 更新机器人配置并添加数据库迁移脚本
- 将bot_config_template.toml中的版本升级至7.9.0
- 增强数据库配置选项以支持PostgreSQL
- 引入一个新脚本,用于在SQLite、MySQL和PostgreSQL之间迁移数据
- 实现一个方言适配器,用于处理特定于数据库的行为和配置
2025-11-27 18:45:01 +08:00

282 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
透明连接复用管理器
在不改变原有API的情况下实现数据库连接的智能复用
"""
import asyncio
import time
from contextlib import asynccontextmanager
from typing import Any
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.common.logger import get_logger
logger = get_logger("connection_pool_manager")
class ConnectionInfo:
"""连接信息包装器"""
def __init__(self, session: AsyncSession, created_at: float):
self.session = session
self.created_at = created_at
self.last_used = created_at
self.in_use = False
self.ref_count = 0
def mark_used(self):
"""标记连接被使用"""
self.last_used = time.time()
self.in_use = True
self.ref_count += 1
def mark_released(self):
"""标记连接被释放"""
self.in_use = False
self.ref_count = max(0, self.ref_count - 1)
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
"""检查连接是否过期"""
current_time = time.time()
# 检查总生命周期
if current_time - self.created_at > max_lifetime:
return True
# 检查空闲时间
if not self.in_use and current_time - self.last_used > max_idle:
return True
return False
async def close(self):
"""关闭连接"""
try:
# 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭
# 通过 `cast` 明确告知类型检查器 `shield` 的返回类型,避免类型错误
from typing import cast
await cast(asyncio.Future, asyncio.shield(self.session.close()))
logger.debug("连接已关闭")
except asyncio.CancelledError:
# 这是一个预期的行为,例如在流式聊天中断时
logger.debug("关闭连接时任务被取消")
# 重新抛出异常以确保任务状态正确
raise
except Exception as e:
logger.warning(f"关闭连接时出错: {e}")
class ConnectionPoolManager:
"""透明的连接池管理器"""
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
self.max_pool_size = max_pool_size
self.max_lifetime = max_lifetime
self.max_idle = max_idle
# 连接池
self._connections: set[ConnectionInfo] = set()
self._lock = asyncio.Lock()
# 统计信息
self._stats = {
"total_created": 0,
"total_reused": 0,
"total_expired": 0,
"active_connections": 0,
"pool_hits": 0,
"pool_misses": 0,
}
# 后台清理任务
self._cleanup_task: asyncio.Task | None = None
self._should_cleanup = False
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
async def start(self):
"""启动连接池管理器"""
if self._cleanup_task is None:
self._should_cleanup = True
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("连接池管理器已启动")
async def stop(self):
"""停止连接池管理器"""
self._should_cleanup = False
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
# 关闭所有连接
await self._close_all_connections()
logger.info("连接池管理器已停止")
@asynccontextmanager
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
"""
获取数据库会话的透明包装器
如果有可用连接则复用,否则创建新连接
"""
connection_info = None
try:
# 尝试获取现有连接
connection_info = await self._get_reusable_connection(session_factory)
if connection_info:
# 复用现有连接
connection_info.mark_used()
self._stats["total_reused"] += 1
self._stats["pool_hits"] += 1
logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
else:
# 创建新连接
session = session_factory()
connection_info = ConnectionInfo(session, time.time())
async with self._lock:
self._connections.add(connection_info)
connection_info.mark_used()
self._stats["total_created"] += 1
self._stats["pool_misses"] += 1
logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")
yield connection_info.session
except Exception:
# 发生错误时回滚连接
if connection_info and connection_info.session:
try:
await connection_info.session.rollback()
except Exception as rollback_error:
logger.warning(f"回滚连接时出错: {rollback_error}")
raise
finally:
# 释放连接回池中
if connection_info:
connection_info.mark_released()
async def _get_reusable_connection(
self, session_factory: async_sessionmaker[AsyncSession]
) -> ConnectionInfo | None:
"""获取可复用的连接"""
# 导入方言适配器获取 ping 查询
from src.common.database.core.dialect_adapter import DialectAdapter
ping_query = DialectAdapter.get_ping_query()
async with self._lock:
# 清理过期连接
await self._cleanup_expired_connections_locked()
# 查找可复用的连接
for connection_info in list(self._connections):
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
# 验证连接是否仍然有效
try:
# 执行 ping 查询来验证连接
await connection_info.session.execute(text(ping_query))
return connection_info
except Exception as e:
logger.debug(f"连接验证失败,将移除: {e}")
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
# 检查是否可以创建新连接
if len(self._connections) >= self.max_pool_size:
logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
return None
return None
async def _cleanup_expired_connections_locked(self):
"""清理过期连接(需要在锁内调用)"""
time.time()
expired_connections = [
connection_info for connection_info in list(self._connections)
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
]
for connection_info in expired_connections:
await connection_info.close()
self._connections.remove(connection_info)
self._stats["total_expired"] += 1
if expired_connections:
logger.debug(f"清理了 {len(expired_connections)} 个过期连接")
async def _cleanup_loop(self):
"""后台清理循环"""
while self._should_cleanup:
try:
await asyncio.sleep(30.0) # 每30秒清理一次
async with self._lock:
await self._cleanup_expired_connections_locked()
# 更新统计信息
self._stats["active_connections"] = len(self._connections)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"连接池清理循环出错: {e}")
await asyncio.sleep(10.0)
async def _close_all_connections(self):
"""关闭所有连接"""
async with self._lock:
for connection_info in list(self._connections):
await connection_info.close()
self._connections.clear()
logger.info("所有连接已关闭")
def get_stats(self) -> dict[str, Any]:
"""获取连接池统计信息"""
return {
**self._stats,
"active_connections": len(self._connections),
"max_pool_size": self.max_pool_size,
"pool_efficiency": (
self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
)
* 100,
}
# 全局连接池管理器实例
_connection_pool_manager: ConnectionPoolManager | None = None
def get_connection_pool_manager() -> ConnectionPoolManager:
"""获取全局连接池管理器实例"""
global _connection_pool_manager
if _connection_pool_manager is None:
_connection_pool_manager = ConnectionPoolManager()
return _connection_pool_manager
async def start_connection_pool():
"""启动连接池"""
manager = get_connection_pool_manager()
await manager.start()
async def stop_connection_pool():
"""停止连接池"""
global _connection_pool_manager
if _connection_pool_manager:
await _connection_pool_manager.stop()
_connection_pool_manager = None