引入Redis
This commit is contained in:
554
src/common/database/optimization/redis_cache.py
Normal file
554
src/common/database/optimization/redis_cache.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""Redis 缓存后端实现
|
||||
|
||||
基于 redis-py 的异步 Redis 缓存实现,支持:
|
||||
- 异步连接池
|
||||
- 自动序列化/反序列化
|
||||
- TTL 过期管理
|
||||
- 模式删除
|
||||
- 批量操作
|
||||
- 统计信息
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.optimization.cache_backend import CacheBackend, CacheStats
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("redis_cache")
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
|
||||
class RedisCache(CacheBackend):
|
||||
"""Redis 缓存后端
|
||||
|
||||
特性:
|
||||
- 分布式缓存:支持多实例共享
|
||||
- 自动序列化:支持 JSON 和 Pickle
|
||||
- TTL 管理:Redis 原生过期机制
|
||||
- 模式删除:支持通配符删除
|
||||
- 连接池:高效连接复用
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
password: str | None = None,
|
||||
db: int = 0,
|
||||
key_prefix: str = "mofox:",
|
||||
default_ttl: int = 600,
|
||||
pool_size: int = 10,
|
||||
socket_timeout: float = 5.0,
|
||||
ssl: bool = False,
|
||||
serializer: str = "json", # "json" 或 "pickle"
|
||||
):
|
||||
"""初始化 Redis 缓存
|
||||
|
||||
Args:
|
||||
host: Redis 服务器地址
|
||||
port: Redis 服务器端口
|
||||
password: Redis 密码(可选)
|
||||
db: Redis 数据库编号
|
||||
key_prefix: 缓存键前缀
|
||||
default_ttl: 默认过期时间(秒)
|
||||
pool_size: 连接池大小
|
||||
socket_timeout: socket 超时时间(秒)
|
||||
ssl: 是否启用 SSL
|
||||
serializer: 序列化方式(json 或 pickle)
|
||||
"""
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.password = password if password else None
|
||||
self.db = db
|
||||
self.key_prefix = key_prefix
|
||||
self.default_ttl = default_ttl
|
||||
self.pool_size = pool_size
|
||||
self.socket_timeout = socket_timeout
|
||||
self.ssl = ssl
|
||||
self.serializer = serializer
|
||||
|
||||
# 连接池和客户端(延迟初始化)
|
||||
self._pool: Any = None
|
||||
self._client: Any = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._is_closing = False
|
||||
|
||||
# 统计信息
|
||||
self._stats = CacheStats()
|
||||
self._stats_lock = asyncio.Lock()
|
||||
|
||||
logger.info(
|
||||
f"Redis 缓存初始化: {host}:{port}/{db}, "
|
||||
f"前缀={key_prefix}, TTL={default_ttl}s, "
|
||||
f"序列化={serializer}"
|
||||
)
|
||||
|
||||
async def _ensure_connection(self) -> Any:
|
||||
"""确保 Redis 连接已建立"""
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
async with self._lock:
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
try:
|
||||
# 创建连接池 (使用 aioredis 模块确保类型安全)
|
||||
self._pool = aioredis.ConnectionPool(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
db=self.db,
|
||||
max_connections=self.pool_size,
|
||||
socket_timeout=self.socket_timeout,
|
||||
socket_connect_timeout=self.socket_timeout,
|
||||
decode_responses=False, # 我们自己处理序列化
|
||||
ssl=self.ssl,
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
self._client = aioredis.Redis(connection_pool=self._pool)
|
||||
|
||||
# 测试连接
|
||||
await self._client.ping()
|
||||
logger.info(f"Redis 连接成功: {self.host}:{self.port}/{self.db}")
|
||||
|
||||
return self._client
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis 连接失败: {e}")
|
||||
self._client = None
|
||||
self._pool = None
|
||||
raise
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""生成带前缀的完整键名"""
|
||||
return f"{self.key_prefix}{key}"
|
||||
|
||||
def _serialize(self, value: Any) -> bytes:
|
||||
"""序列化值"""
|
||||
if self.serializer == "json":
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False, default=str).encode("utf-8")
|
||||
except (TypeError, ValueError):
|
||||
# JSON 序列化失败,回退到 pickle
|
||||
return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def _deserialize(self, data: bytes) -> Any:
|
||||
"""反序列化值"""
|
||||
if self.serializer == "json":
|
||||
try:
|
||||
return json.loads(data.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# JSON 反序列化失败,尝试 pickle
|
||||
try:
|
||||
return pickle.loads(data)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return pickle.loads(data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""从缓存获取数据"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
|
||||
data = await client.get(full_key)
|
||||
|
||||
async with self._stats_lock:
|
||||
if data is not None:
|
||||
self._stats.hits += 1
|
||||
return self._deserialize(data)
|
||||
else:
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis GET 失败 [{key}]: {e}")
|
||||
async with self._stats_lock:
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: float | None = None,
|
||||
) -> None:
|
||||
"""设置缓存值"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
data = self._serialize(value)
|
||||
|
||||
# 使用 TTL
|
||||
expire_time = int(ttl) if ttl is not None else self.default_ttl
|
||||
|
||||
await client.setex(full_key, expire_time, data)
|
||||
|
||||
logger.debug(f"Redis SET: {key} (TTL={expire_time}s)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis SET 失败 [{key}]: {e}")
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""删除缓存条目"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
|
||||
result = await client.delete(full_key)
|
||||
|
||||
if result > 0:
|
||||
async with self._stats_lock:
|
||||
self._stats.evictions += 1
|
||||
logger.debug(f"Redis DEL: {key}")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis DEL 失败 [{key}]: {e}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""检查键是否存在"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
return bool(await client.exists(full_key))
|
||||
except Exception as e:
|
||||
logger.error(f"Redis EXISTS 失败 [{key}]: {e}")
|
||||
return False
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""清空所有带前缀的缓存"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
pattern = self._make_key("*")
|
||||
|
||||
# 使用 SCAN 避免阻塞
|
||||
cursor = 0
|
||||
deleted_count = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await client.delete(*keys)
|
||||
deleted_count += len(keys)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
async with self._stats_lock:
|
||||
self._stats = CacheStats()
|
||||
|
||||
logger.info(f"Redis 缓存已清空: 删除 {deleted_count} 个键")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis CLEAR 失败: {e}")
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""删除匹配模式的所有键
|
||||
|
||||
Args:
|
||||
pattern: 键模式(支持 * 通配符)
|
||||
|
||||
Returns:
|
||||
删除的键数量
|
||||
"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_pattern = self._make_key(pattern)
|
||||
|
||||
# 使用 SCAN 避免阻塞
|
||||
cursor = 0
|
||||
deleted_count = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=full_pattern, count=100)
|
||||
if keys:
|
||||
await client.delete(*keys)
|
||||
deleted_count += len(keys)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
async with self._stats_lock:
|
||||
self._stats.evictions += deleted_count
|
||||
|
||||
logger.debug(f"Redis 模式删除: {pattern} -> {deleted_count} 个键")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis 模式删除失败 [{pattern}]: {e}")
|
||||
return 0
|
||||
|
||||
async def mget(self, keys: list[str]) -> dict[str, Any]:
|
||||
"""批量获取多个键的值"""
|
||||
if not keys:
|
||||
return {}
|
||||
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_keys = [self._make_key(k) for k in keys]
|
||||
|
||||
values = await client.mget(full_keys)
|
||||
|
||||
result = {}
|
||||
hits = 0
|
||||
misses = 0
|
||||
|
||||
for key, value in zip(keys, values):
|
||||
if value is not None:
|
||||
result[key] = self._deserialize(value)
|
||||
hits += 1
|
||||
else:
|
||||
misses += 1
|
||||
|
||||
async with self._stats_lock:
|
||||
self._stats.hits += hits
|
||||
self._stats.misses += misses
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis MGET 失败: {e}")
|
||||
return {}
|
||||
|
||||
async def mset(
|
||||
self,
|
||||
mapping: dict[str, Any],
|
||||
ttl: float | None = None,
|
||||
) -> None:
|
||||
"""批量设置多个键值对"""
|
||||
if not mapping:
|
||||
return
|
||||
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
expire_time = int(ttl) if ttl is not None else self.default_ttl
|
||||
|
||||
# 使用 pipeline 提高效率
|
||||
async with client.pipeline(transaction=False) as pipe:
|
||||
for key, value in mapping.items():
|
||||
full_key = self._make_key(key)
|
||||
data = self._serialize(value)
|
||||
pipe.setex(full_key, expire_time, data)
|
||||
|
||||
await pipe.execute()
|
||||
|
||||
logger.debug(f"Redis MSET: {len(mapping)} 个键")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis MSET 失败: {e}")
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
|
||||
# 获取 Redis 服务器信息
|
||||
info = await client.info("memory")
|
||||
# keyspace_info 可用于扩展统计, 暂时不获取避免开销
|
||||
# keyspace_info = await client.info("keyspace")
|
||||
|
||||
# 统计带前缀的键数量
|
||||
pattern = self._make_key("*")
|
||||
key_count = 0
|
||||
cursor = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
|
||||
key_count += len(keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
async with self._stats_lock:
|
||||
return {
|
||||
"backend": "redis",
|
||||
"hits": self._stats.hits,
|
||||
"misses": self._stats.misses,
|
||||
"hit_rate": self._stats.hit_rate,
|
||||
"evictions": self._stats.evictions,
|
||||
"key_count": key_count,
|
||||
"redis_memory_used_mb": info.get("used_memory", 0) / (1024 * 1024),
|
||||
"redis_memory_peak_mb": info.get("used_memory_peak", 0) / (1024 * 1024),
|
||||
"redis_connected_clients": info.get("connected_clients", 0),
|
||||
"key_prefix": self.key_prefix,
|
||||
"default_ttl": self.default_ttl,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Redis 统计信息失败: {e}")
|
||||
async with self._stats_lock:
|
||||
return {
|
||||
"backend": "redis",
|
||||
"hits": self._stats.hits,
|
||||
"misses": self._stats.misses,
|
||||
"hit_rate": self._stats.hit_rate,
|
||||
"evictions": self._stats.evictions,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Redis 连接"""
|
||||
self._is_closing = True
|
||||
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
logger.info("Redis 连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭 Redis 连接失败: {e}")
|
||||
finally:
|
||||
self._client = None
|
||||
self._pool = None
|
||||
|
||||
@property
|
||||
def backend_type(self) -> str:
|
||||
"""返回缓存后端类型标识"""
|
||||
return "redis"
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
"""Redis 是分布式缓存"""
|
||||
return True
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
await client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def ttl(self, key: str) -> int:
|
||||
"""获取键的剩余 TTL
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1 表示无过期时间,-2 表示键不存在
|
||||
"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
return await client.ttl(full_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis TTL 失败 [{key}]: {e}")
|
||||
return -2
|
||||
|
||||
async def expire(self, key: str, ttl: int) -> bool:
|
||||
"""更新键的 TTL
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
ttl: 新的过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
return bool(await client.expire(full_key, ttl))
|
||||
except Exception as e:
|
||||
logger.error(f"Redis EXPIRE 失败 [{key}]: {e}")
|
||||
return False
|
||||
|
||||
async def incr(self, key: str, amount: int = 1) -> int:
|
||||
"""原子递增
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
amount: 递增量
|
||||
|
||||
Returns:
|
||||
递增后的值
|
||||
"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
return await client.incrby(full_key, amount)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis INCR 失败 [{key}]: {e}")
|
||||
return 0
|
||||
|
||||
async def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""原子递减
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
amount: 递减量
|
||||
|
||||
Returns:
|
||||
递减后的值
|
||||
"""
|
||||
try:
|
||||
client = await self._ensure_connection()
|
||||
full_key = self._make_key(key)
|
||||
return await client.decrby(full_key, amount)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis DECR 失败 [{key}]: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 全局 Redis 缓存实例
|
||||
_global_redis_cache: RedisCache | None = None
|
||||
_redis_cache_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_redis_cache() -> RedisCache:
|
||||
"""获取全局 Redis 缓存实例(单例)"""
|
||||
global _global_redis_cache
|
||||
|
||||
if _global_redis_cache is None:
|
||||
async with _redis_cache_lock:
|
||||
if _global_redis_cache is None:
|
||||
# 从配置加载参数
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
db_config = global_config.database
|
||||
|
||||
_global_redis_cache = RedisCache(
|
||||
host=db_config.redis_host,
|
||||
port=db_config.redis_port,
|
||||
password=db_config.redis_password or None,
|
||||
db=db_config.redis_db,
|
||||
key_prefix=db_config.redis_key_prefix,
|
||||
default_ttl=db_config.redis_default_ttl,
|
||||
pool_size=db_config.redis_connection_pool_size,
|
||||
socket_timeout=db_config.redis_socket_timeout,
|
||||
ssl=db_config.redis_ssl,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法从配置加载 Redis 参数,使用默认值: {e}")
|
||||
_global_redis_cache = RedisCache()
|
||||
|
||||
return _global_redis_cache
|
||||
|
||||
|
||||
async def close_redis_cache() -> None:
|
||||
"""关闭全局 Redis 缓存"""
|
||||
global _global_redis_cache
|
||||
|
||||
if _global_redis_cache is not None:
|
||||
await _global_redis_cache.close()
|
||||
_global_redis_cache = None
|
||||
logger.info("全局 Redis 缓存已关闭")
|
||||
|
||||
Reference in New Issue
Block a user