refactor(memory): remove deprecated memory-system backup files; streamline the message manager architecture
Removed all deprecated memory-system files under the deprecated_backup directory, including the enhanced memory adapter, hooks, integration layer, reranker, metadata index, multi-stage retrieval, and vector storage modules. Also streamlined the message manager by integrating a batch database writer, a stream cache manager, and an adaptive stream manager, improving performance and maintainability.
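For orientation, a minimal sketch of how the three new components are wired into the message manager's lifecycle; the start/stop order below mirrors the MessageManager changes in this commit, while the wrapper functions themselves are hypothetical:

from src.chat.message_manager.adaptive_stream_manager import (
    init_adaptive_stream_manager, shutdown_adaptive_stream_manager)
from src.chat.message_manager.batch_database_writer import (
    init_batch_writer, shutdown_batch_writer)
from src.chat.message_manager.stream_cache_manager import (
    init_stream_cache_manager, shutdown_stream_cache_manager)


async def start_message_manager_services():
    # Order mirrors MessageManager.start(): writer, cache, then admission control
    await init_batch_writer()
    await init_stream_cache_manager()
    await init_adaptive_stream_manager()


async def stop_message_manager_services():
    await shutdown_batch_writer()
    await shutdown_stream_cache_manager()
    await shutdown_adaptive_stream_manager()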
src/chat/message_manager/adaptive_stream_manager.py (new file, 489 lines)
@@ -0,0 +1,489 @@
"""
Adaptive stream manager - dynamic concurrency limits and async stream-pool management.

Adjusts the concurrency limit dynamically based on system load and stream priority.
"""

import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set

import psutil

from src.common.logger import get_logger

logger = get_logger("adaptive_stream_manager")

class StreamPriority(Enum):
    """Stream priority levels."""
    LOW = 1
    NORMAL = 2
    HIGH = 3
    CRITICAL = 4


@dataclass
class SystemMetrics:
    """A point-in-time sample of system load."""
    cpu_usage: float = 0.0
    memory_usage: float = 0.0
    active_coroutines: int = 0
    event_loop_lag: float = 0.0
    timestamp: float = field(default_factory=time.time)


@dataclass
class StreamMetrics:
    """Per-stream activity metrics."""
    stream_id: str
    priority: StreamPriority
    message_rate: float = 0.0  # messages per minute
    response_time: float = 0.0  # average response time in seconds
    last_activity: float = field(default_factory=time.time)
    consecutive_failures: int = 0
    is_active: bool = True

class AdaptiveStreamManager:
    """Adaptive stream manager."""

    def __init__(
        self,
        base_concurrent_limit: int = 50,
        max_concurrent_limit: int = 200,
        min_concurrent_limit: int = 10,
        metrics_window: float = 60.0,  # window over which system metrics are kept
        adjustment_interval: float = 30.0,  # seconds between limit adjustments
        cpu_threshold_high: float = 0.8,  # high CPU load threshold
        cpu_threshold_low: float = 0.3,  # low CPU load threshold
        memory_threshold_high: float = 0.85,  # high memory load threshold
    ):
        self.base_concurrent_limit = base_concurrent_limit
        self.max_concurrent_limit = max_concurrent_limit
        self.min_concurrent_limit = min_concurrent_limit
        self.metrics_window = metrics_window
        self.adjustment_interval = adjustment_interval
        self.cpu_threshold_high = cpu_threshold_high
        self.cpu_threshold_low = cpu_threshold_low
        self.memory_threshold_high = memory_threshold_high

        # Current state
        self.current_limit = base_concurrent_limit
        self.active_streams: Set[str] = set()
        self.pending_streams: Set[str] = set()
        self.stream_metrics: Dict[str, StreamMetrics] = {}
        # Streams admitted past the semaphores by _force_acquire_slot; they
        # must not release a semaphore permit they never acquired.
        self.force_acquired: Set[str] = set()

        # Async semaphores
        self.semaphore = asyncio.Semaphore(base_concurrent_limit)
        self.priority_semaphore = asyncio.Semaphore(5)  # dedicated slots for high-priority streams

        # System monitoring
        self.system_metrics: List[SystemMetrics] = []
        self.last_adjustment_time = 0.0

        # Statistics
        self.stats = {
            "total_requests": 0,
            "accepted_requests": 0,
            "rejected_requests": 0,
            "priority_accepts": 0,
            "limit_adjustments": 0,
            "avg_concurrent_streams": 0,
            "peak_concurrent_streams": 0,
        }

        # Background tasks
        self.monitor_task: Optional[asyncio.Task] = None
        self.adjustment_task: Optional[asyncio.Task] = None
        self.is_running = False

        logger.info(f"Adaptive stream manager initialized (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")

    async def start(self):
        """Start the adaptive manager."""
        if self.is_running:
            logger.warning("Adaptive stream manager is already running")
            return

        self.is_running = True
        self.monitor_task = asyncio.create_task(self._system_monitor_loop(), name="system_monitor")
        self.adjustment_task = asyncio.create_task(self._adjustment_loop(), name="limit_adjustment")
        logger.info("Adaptive stream manager started")

    async def stop(self):
        """Stop the adaptive manager."""
        if not self.is_running:
            return

        self.is_running = False

        # Stop the background tasks. Awaiting a cancelled task raises
        # CancelledError, which must be caught explicitly (it is not an
        # Exception subclass on Python 3.8+).
        if self.monitor_task and not self.monitor_task.done():
            self.monitor_task.cancel()
            try:
                await asyncio.wait_for(self.monitor_task, timeout=10.0)
            except asyncio.CancelledError:
                pass
            except asyncio.TimeoutError:
                logger.warning("Timed out stopping the system monitor task")
            except Exception as e:
                logger.error(f"Error while stopping the system monitor task: {e}")

        if self.adjustment_task and not self.adjustment_task.done():
            self.adjustment_task.cancel()
            try:
                await asyncio.wait_for(self.adjustment_task, timeout=10.0)
            except asyncio.CancelledError:
                pass
            except asyncio.TimeoutError:
                logger.warning("Timed out stopping the limit-adjustment task")
            except Exception as e:
                logger.error(f"Error while stopping the limit-adjustment task: {e}")

        logger.info("Adaptive stream manager stopped")

    async def acquire_stream_slot(
        self,
        stream_id: str,
        priority: StreamPriority = StreamPriority.NORMAL,
        force: bool = False
    ) -> bool:
        """
        Acquire a stream-processing slot.

        Args:
            stream_id: stream ID
            priority: stream priority
            force: force acquisition (may exceed the current limit)

        Returns:
            bool: whether a slot was acquired
        """
        # If the manager has not been started, admit the stream unconditionally
        if not self.is_running:
            logger.warning(f"Adaptive stream manager is not running; admitting stream {stream_id} directly")
            return True

        self.stats["total_requests"] += 1
        current_time = time.time()

        # Update stream metrics
        if stream_id not in self.stream_metrics:
            self.stream_metrics[stream_id] = StreamMetrics(
                stream_id=stream_id,
                priority=priority
            )
        self.stream_metrics[stream_id].last_activity = current_time

        # Already active?
        if stream_id in self.active_streams:
            logger.debug(f"Stream {stream_id} is already in the active set")
            return True

        # Priority handling
        if priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
            return await self._acquire_priority_slot(stream_id, priority, force)

        # Check whether dispatch should be forced (message backlog)
        if not force and self._should_force_dispatch(stream_id):
            force = True
            logger.info(f"Stream {stream_id} has a heavy message backlog; forcing dispatch")

        # Try to acquire the regular semaphore
        try:
            # wait_for with a near-zero timeout makes the acquire effectively non-blocking
            acquired = await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
            if acquired:
                self.active_streams.add(stream_id)
                self.stats["accepted_requests"] += 1
                logger.debug(f"Stream {stream_id} acquired a regular slot (currently active: {len(self.active_streams)})")
                return True
        except asyncio.TimeoutError:
            logger.debug(f"Regular semaphore exhausted: {stream_id}")
        except Exception as e:
            logger.warning(f"Error acquiring regular slot: {e}")

        # If forcing, try to exceed the limit
        if force:
            return await self._force_acquire_slot(stream_id)

        # No slot available
        self.stats["rejected_requests"] += 1
        logger.debug(f"Stream {stream_id} failed to acquire a slot; current limit: {self.current_limit}, active streams: {len(self.active_streams)}")
        return False

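    # Usage sketch (hypothetical caller; process_stream is not part of this
    # module):
    #
    #     manager = get_adaptive_stream_manager()
    #     if await manager.acquire_stream_slot("stream-42", StreamPriority.NORMAL):
    #         try:
    #             await process_stream("stream-42")
    #         finally:
    #             manager.release_stream_slot("stream-42")
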
    async def _acquire_priority_slot(self, stream_id: str, priority: StreamPriority, force: bool) -> bool:
        """Acquire a priority slot."""
        try:
            # The priority semaphore holds a small number of dedicated slots
            acquired = await asyncio.wait_for(self.priority_semaphore.acquire(), timeout=0.001)
            if acquired:
                self.active_streams.add(stream_id)
                self.stats["priority_accepts"] += 1
                self.stats["accepted_requests"] += 1
                logger.debug(f"Stream {stream_id} acquired a priority slot (priority: {priority.name})")
                return True
        except asyncio.TimeoutError:
            logger.debug(f"Priority semaphore exhausted: {stream_id}")
        except Exception as e:
            logger.warning(f"Error acquiring priority slot: {e}")

        # Priority slots are full as well; fall back to forcing if allowed
        if force or priority == StreamPriority.CRITICAL:
            return await self._force_acquire_slot(stream_id)

        return False

    async def _force_acquire_slot(self, stream_id: str) -> bool:
        """Force-acquire a slot (exceeding the current limit)."""
        # Never exceed the hard maximum
        if len(self.active_streams) >= self.max_concurrent_limit:
            logger.warning(f"Maximum concurrency limit {self.max_concurrent_limit} reached; cannot force-dispatch stream {stream_id}")
            return False

        # Admit the stream without taking a semaphore permit
        self.active_streams.add(stream_id)
        self.force_acquired.add(stream_id)
        self.stats["accepted_requests"] += 1
        logger.warning(f"Stream {stream_id} force-dispatched past the concurrency limit (currently active: {len(self.active_streams)})")
        return True

    def release_stream_slot(self, stream_id: str):
        """Release a stream-processing slot."""
        if stream_id in self.active_streams:
            self.active_streams.remove(stream_id)

            # Force-admitted streams never acquired a permit, so releasing one
            # for them would inflate the semaphore's value.
            if stream_id in self.force_acquired:
                self.force_acquired.discard(stream_id)
            else:
                metrics = self.stream_metrics.get(stream_id)
                if metrics and metrics.priority in [StreamPriority.HIGH, StreamPriority.CRITICAL]:
                    self.priority_semaphore.release()
                else:
                    self.semaphore.release()

            logger.debug(f"Stream {stream_id} released its slot (currently active: {len(self.active_streams)})")

    def _should_force_dispatch(self, stream_id: str) -> bool:
        """Decide whether dispatch should be forced for a stream."""
        # A backlog-based heuristic could live here; this simplified version
        # uses the stream's recent activity and priority.
        metrics = self.stream_metrics.get(stream_id)
        if not metrics:
            return False

        # High-priority streams are forced more readily
        if metrics.priority == StreamPriority.HIGH:
            return True

        # Recent activity combined with slow responses suggests a backlog
        current_time = time.time()
        if (current_time - metrics.last_activity < 300 and  # active within 5 minutes
                metrics.response_time > 5.0):  # average response slower than 5 seconds
            return True

        return False

    async def _system_monitor_loop(self):
        """System monitoring loop."""
        logger.info("System monitoring loop started")

        while self.is_running:
            try:
                await asyncio.sleep(5.0)  # sample every 5 seconds
                await self._collect_system_metrics()
            except asyncio.CancelledError:
                logger.info("System monitoring loop cancelled")
                break
            except Exception as e:
                logger.error(f"System monitoring error: {e}")

        logger.info("System monitoring loop finished")

    async def _collect_system_metrics(self):
        """Collect system metrics."""
        try:
            # CPU usage (non-blocking sample since the previous call)
            cpu_usage = psutil.cpu_percent(interval=None) / 100.0

            # Memory usage
            memory = psutil.virtual_memory()
            memory_usage = memory.percent / 100.0

            # Number of active coroutines
            try:
                active_coroutines = len(asyncio.all_tasks())
            except Exception:
                active_coroutines = 0

            # Event-loop lag: how long a zero-delay sleep actually takes
            event_loop_lag = 0.0
            try:
                start_time = time.time()
                await asyncio.sleep(0)
                event_loop_lag = time.time() - start_time
            except Exception:
                pass

            metrics = SystemMetrics(
                cpu_usage=cpu_usage,
                memory_usage=memory_usage,
                active_coroutines=active_coroutines,
                event_loop_lag=event_loop_lag,
                timestamp=time.time()
            )

            self.system_metrics.append(metrics)

            # Trim samples outside the metrics window
            cutoff_time = time.time() - self.metrics_window
            self.system_metrics = [
                m for m in self.system_metrics
                if m.timestamp > cutoff_time
            ]

            # Update statistics (exponential moving average)
            self.stats["avg_concurrent_streams"] = (
                self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
            )
            self.stats["peak_concurrent_streams"] = max(
                self.stats["peak_concurrent_streams"],
                len(self.active_streams)
            )

        except Exception as e:
            logger.error(f"Failed to collect system metrics: {e}")

    async def _adjustment_loop(self):
        """Limit-adjustment loop."""
        logger.info("Limit-adjustment loop started")

        while self.is_running:
            try:
                await asyncio.sleep(self.adjustment_interval)
                await self._adjust_concurrent_limit()
            except asyncio.CancelledError:
                logger.info("Limit-adjustment loop cancelled")
                break
            except Exception as e:
                logger.error(f"Limit-adjustment error: {e}")

        logger.info("Limit-adjustment loop finished")

    async def _adjust_concurrent_limit(self):
        """Adjust the concurrency limit."""
        if not self.system_metrics:
            return

        current_time = time.time()
        if current_time - self.last_adjustment_time < self.adjustment_interval:
            return

        # Average the most recent system metrics
        recent_metrics = self.system_metrics[-10:] if len(self.system_metrics) >= 10 else self.system_metrics
        if not recent_metrics:
            return

        avg_cpu = sum(m.cpu_usage for m in recent_metrics) / len(recent_metrics)
        avg_memory = sum(m.memory_usage for m in recent_metrics) / len(recent_metrics)
        avg_coroutines = sum(m.active_coroutines for m in recent_metrics) / len(recent_metrics)

        # Adjustment policy: scale the current limit by a load-dependent factor
        old_limit = self.current_limit
        adjustment_factor = 1.0

        # CPU load
        if avg_cpu > self.cpu_threshold_high:
            adjustment_factor *= 0.8  # shrink by 20%
        elif avg_cpu < self.cpu_threshold_low:
            adjustment_factor *= 1.2  # grow by 20%

        # Memory load
        if avg_memory > self.memory_threshold_high:
            adjustment_factor *= 0.7  # shrink by 30%

        # Coroutine count
        if avg_coroutines > 1000:
            adjustment_factor *= 0.9  # shrink by 10%

        # Apply the adjustment, clamped to the configured bounds
        new_limit = int(self.current_limit * adjustment_factor)
        new_limit = max(self.min_concurrent_limit, min(self.max_concurrent_limit, new_limit))

        # Resize the semaphore if the limit changed
        if new_limit != self.current_limit:
            await self._adjust_semaphore(self.current_limit, new_limit)
            self.current_limit = new_limit
            self.stats["limit_adjustments"] += 1
            self.last_adjustment_time = current_time

            logger.info(
                f"Concurrency limit adjusted: {old_limit} -> {new_limit} "
                f"(CPU: {avg_cpu:.2f}, memory: {avg_memory:.2f}, coroutines: {avg_coroutines:.0f})"
            )

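    # Worked example of the adjustment math above: with current_limit = 50,
    # avg_cpu = 0.85 (> cpu_threshold_high) and avg_memory = 0.90
    # (> memory_threshold_high), the factor is 0.8 * 0.7 = 0.56, so
    # new_limit = int(50 * 0.56) = 28, which already lies inside
    # [min_concurrent_limit, max_concurrent_limit] = [10, 200].
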
    async def _adjust_semaphore(self, old_limit: int, new_limit: int):
        """Resize the semaphore."""
        if new_limit > old_limit:
            # Grow: release extra permits
            for _ in range(new_limit - old_limit):
                self.semaphore.release()
        elif new_limit < old_limit:
            # Shrink: absorb permits as they become available
            reduction = old_limit - new_limit
            for _ in range(reduction):
                try:
                    await asyncio.wait_for(self.semaphore.acquire(), timeout=0.001)
                except asyncio.TimeoutError:
                    # No permit is immediately available, meaning usage is close
                    # to the limit; the shrink is applied only partially.
                    break

    def update_stream_metrics(self, stream_id: str, **kwargs):
        """Update metrics for a stream."""
        if stream_id not in self.stream_metrics:
            return

        metrics = self.stream_metrics[stream_id]
        for key, value in kwargs.items():
            if hasattr(metrics, key):
                setattr(metrics, key, value)

    def get_stats(self) -> Dict:
        """Return statistics."""
        stats = self.stats.copy()
        stats.update({
            "current_limit": self.current_limit,
            "active_streams": len(self.active_streams),
            "pending_streams": len(self.pending_streams),
            "is_running": self.is_running,
            "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
            "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
        })

        # Acceptance rate
        if stats["total_requests"] > 0:
            stats["acceptance_rate"] = stats["accepted_requests"] / stats["total_requests"]
        else:
            stats["acceptance_rate"] = 0

        return stats

# Global adaptive manager instance
_adaptive_manager: Optional[AdaptiveStreamManager] = None


def get_adaptive_stream_manager() -> AdaptiveStreamManager:
    """Return the adaptive stream manager singleton."""
    global _adaptive_manager
    if _adaptive_manager is None:
        _adaptive_manager = AdaptiveStreamManager()
    return _adaptive_manager


async def init_adaptive_stream_manager():
    """Initialize the adaptive stream manager."""
    manager = get_adaptive_stream_manager()
    await manager.start()


async def shutdown_adaptive_stream_manager():
    """Shut down the adaptive stream manager."""
    manager = get_adaptive_stream_manager()
    await manager.stop()

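A side note on the acquire pattern used throughout this file: wrapping Semaphore.acquire() in wait_for with a 1 ms timeout emulates a non-blocking acquire, which asyncio does not provide directly. An equivalent sketch without the timeout machinery checks locked() first; when a permit is free, acquire() returns without suspending, so the check-then-acquire pair is safe on a single-threaded event loop:

import asyncio


async def try_acquire(sem: asyncio.Semaphore) -> bool:
    # locked() reports whether acquire() would block; the fast path of
    # acquire() does not yield to the loop when a permit is available.
    if sem.locked():
        return False
    await sem.acquire()
    return True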
src/chat/message_manager/batch_database_writer.py (new file, 348 lines)
@@ -0,0 +1,348 @@
"""
Asynchronous batch database writer.

Optimizes frequent database writes to reduce I/O blocking.
"""

import asyncio
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config

logger = get_logger("batch_database_writer")

@dataclass
class StreamUpdatePayload:
    """A pending stream update."""
    stream_id: str
    update_data: Dict[str, Any]
    priority: int = 0  # higher numbers are written first
    timestamp: float = field(default_factory=time.time)

class BatchDatabaseWriter:
    """Asynchronous batch database writer."""

    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0, max_queue_size: int = 1000):
        """
        Initialize the batch writer.

        Args:
            batch_size: number of updates per batch
            flush_interval: flush interval in seconds
            max_queue_size: maximum queue size
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.max_queue_size = max_queue_size

        # Async queue
        self.write_queue: asyncio.Queue[StreamUpdatePayload] = asyncio.Queue(maxsize=max_queue_size)

        # Running state
        self.is_running = False
        self.writer_task: Optional[asyncio.Task] = None

        # Statistics
        self.stats = {
            "total_writes": 0,
            "batch_writes": 0,
            "failed_writes": 0,
            "queue_size": 0,
            "avg_batch_size": 0,
            "last_flush_time": 0,
        }

        # Batches grouped by priority
        self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)

        logger.info(f"Batch database writer initialized (batch_size={batch_size}, interval={flush_interval}s)")

    async def start(self):
        """Start the batch writer."""
        if self.is_running:
            logger.warning("Batch writer is already running")
            return

        self.is_running = True
        self.writer_task = asyncio.create_task(self._batch_writer_loop(), name="batch_database_writer")
        logger.info("Batch database writer started")

    async def stop(self):
        """Stop the batch writer."""
        if not self.is_running:
            return

        self.is_running = False

        # Wait for the current batch to finish
        if self.writer_task and not self.writer_task.done():
            try:
                # Flush whatever is still queued
                await self._flush_all_batches()
                # Then cancel the task; awaiting it raises CancelledError,
                # which is caught explicitly below
                self.writer_task.cancel()
                await asyncio.wait_for(self.writer_task, timeout=10.0)
            except asyncio.CancelledError:
                pass
            except asyncio.TimeoutError:
                logger.warning("Timed out stopping the batch writer")
            except Exception as e:
                logger.error(f"Error while stopping the batch writer: {e}")

        logger.info("Batch database writer stopped")

    async def schedule_stream_update(
        self,
        stream_id: str,
        update_data: Dict[str, Any],
        priority: int = 0
    ) -> bool:
        """
        Schedule a stream update.

        Args:
            stream_id: stream ID
            update_data: fields to update
            priority: priority

        Returns:
            bool: whether the update was enqueued
        """
        try:
            if not self.is_running:
                logger.warning("Batch writer is not running; writing to the database directly")
                await self._direct_write(stream_id, update_data)
                return True

            # Build the update payload
            payload = StreamUpdatePayload(
                stream_id=stream_id,
                update_data=update_data,
                priority=priority
            )

            # Enqueue without blocking
            try:
                self.write_queue.put_nowait(payload)
                self.stats["total_writes"] += 1
                self.stats["queue_size"] = self.write_queue.qsize()
                return True
            except asyncio.QueueFull:
                logger.warning(f"Write queue is full; dropping update: stream_id={stream_id}")
                return False

        except Exception as e:
            logger.error(f"Failed to schedule stream update: {e}")
            return False

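    # Usage sketch (hypothetical caller and fields; the columns actually
    # available depend on the ChatStreams model definition):
    #
    #     writer = get_batch_writer()
    #     await writer.schedule_stream_update(
    #         "stream-42",
    #         {"last_active_time": time.time()},
    #         priority=1,
    #     )
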
    async def _batch_writer_loop(self):
        """Main batch-writing loop."""
        logger.info("Batch-writing loop started")

        while self.is_running:
            try:
                # Wait until a batch fills up or the flush interval elapses
                batch = await self._collect_batch()

                if batch:
                    await self._write_batch(batch)

                # Update statistics
                self.stats["queue_size"] = self.write_queue.qsize()

            except asyncio.CancelledError:
                logger.info("Batch-writing loop cancelled")
                break
            except Exception as e:
                logger.error(f"Batch-writing loop error: {e}")
                # Back off briefly before continuing
                await asyncio.sleep(1.0)

        # Flush whatever remains before exiting
        await self._flush_all_batches()
        logger.info("Batch-writing loop finished")

    async def _collect_batch(self) -> List[StreamUpdatePayload]:
        """Collect one batch of updates."""
        batch = []
        deadline = time.time() + self.flush_interval

        while len(batch) < self.batch_size and time.time() < deadline:
            try:
                # Wait no longer than the remaining time before the deadline
                remaining_time = max(0, deadline - time.time())
                if remaining_time == 0:
                    break

                payload = await asyncio.wait_for(
                    self.write_queue.get(),
                    timeout=remaining_time
                )
                batch.append(payload)

            except asyncio.TimeoutError:
                break

        return batch

    async def _write_batch(self, batch: List[StreamUpdatePayload]):
        """Write a batch to the database."""
        if not batch:
            return

        start_time = time.time()

        try:
            # Sort by priority (descending), then by age
            batch.sort(key=lambda x: (-x.priority, x.timestamp))

            # Merge updates for the same stream ID, keeping the newest
            merged_updates = {}
            for payload in batch:
                if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp:
                    merged_updates[payload.stream_id] = payload

            # Batch write
            await self._batch_write_to_database(list(merged_updates.values()))

            # Update statistics (sliding average of the batch size)
            self.stats["batch_writes"] += 1
            self.stats["avg_batch_size"] = (
                self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
            )
            self.stats["last_flush_time"] = start_time

            logger.debug(f"Batch write finished: {len(batch)} updates in {time.time() - start_time:.3f}s")

        except Exception as e:
            self.stats["failed_writes"] += 1
            logger.error(f"Batch write failed: {e}")
            # Fall back to writing one update at a time
            for payload in batch:
                try:
                    await self._direct_write(payload.stream_id, payload.update_data)
                except Exception as single_e:
                    logger.error(f"Single write also failed: {single_e}")

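    # Worked example of the merge above: three payloads for "stream-7" with
    # timestamps t1 < t2 < t3 collapse into the one carrying t3, so a batch of
    # 50 raw updates may issue far fewer than 50 upserts when streams are
    # chatty.
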
    def _build_upsert_stmt(self, stream_id: str, update_data: Dict[str, Any]):
        """Build a dialect-specific upsert statement for one stream update."""
        if global_config.database.database_type == "mysql":
            from sqlalchemy.dialects.mysql import insert as mysql_insert
            stmt = mysql_insert(ChatStreams).values(
                stream_id=stream_id, **update_data
            )
            return stmt.on_duplicate_key_update(
                **{key: value for key, value in update_data.items() if key != "stream_id"}
            )

        # SQLite syntax, which also serves as the default
        from sqlalchemy.dialects.sqlite import insert as sqlite_insert
        stmt = sqlite_insert(ChatStreams).values(
            stream_id=stream_id, **update_data
        )
        return stmt.on_conflict_do_update(
            index_elements=["stream_id"],
            set_=update_data
        )

    async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
        """Write a list of payloads to the database in one session."""
        async with get_db_session() as session:
            for payload in payloads:
                stmt = self._build_upsert_stmt(payload.stream_id, payload.update_data)
                await session.execute(stmt)

            await session.commit()

    async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
        """Write a single update directly (fallback path)."""
        async with get_db_session() as session:
            stmt = self._build_upsert_stmt(stream_id, update_data)
            await session.execute(stmt)
            await session.commit()

    async def _flush_all_batches(self):
        """Flush all remaining batches."""
        # Drain everything still in the queue
        remaining_batch = []
        while not self.write_queue.empty():
            try:
                payload = self.write_queue.get_nowait()
                remaining_batch.append(payload)
            except asyncio.QueueEmpty:
                break

        if remaining_batch:
            await self._write_batch(remaining_batch)

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics."""
        stats = self.stats.copy()
        stats["is_running"] = self.is_running
        stats["current_queue_size"] = self.write_queue.qsize() if self.is_running else 0
        return stats

# Global batch writer instance
_batch_writer: Optional[BatchDatabaseWriter] = None


def get_batch_writer() -> BatchDatabaseWriter:
    """Return the batch writer singleton."""
    global _batch_writer
    if _batch_writer is None:
        _batch_writer = BatchDatabaseWriter()
    return _batch_writer


async def init_batch_writer():
    """Initialize the batch writer."""
    writer = get_batch_writer()
    await writer.start()


async def shutdown_batch_writer():
    """Shut down the batch writer."""
    writer = get_batch_writer()
    await writer.stop()

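The dialect-specific upsert built above is standard SQLAlchemy. A minimal standalone sketch of the SQLite variant, using a hypothetical table in place of ChatStreams:

from sqlalchemy import Column, Float, String, create_engine
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Stream(Base):  # hypothetical stand-in for ChatStreams
    __tablename__ = "streams"
    stream_id = Column(String, primary_key=True)
    last_active_time = Column(Float)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

update_data = {"last_active_time": 123.0}
stmt = sqlite_insert(Stream).values(stream_id="s1", **update_data)
# Compiles to: INSERT ... ON CONFLICT (stream_id) DO UPDATE SET last_active_time = ...
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)

with engine.begin() as conn:
    conn.execute(stmt)  # first run inserts the row
    conn.execute(stmt)  # second run updates the same row in place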
@@ -23,6 +23,8 @@ class StreamLoopManager:
    def __init__(self, max_concurrent_streams: int | None = None):
        # Stream-loop task management
        self.stream_loops: dict[str, asyncio.Task] = {}
        # Track which manager type each stream uses
        self.stream_management_type: dict[str, str] = {}  # stream_id -> "adaptive" or "fallback"

        # Statistics
        self.stats: dict[str, Any] = {
@@ -99,7 +101,7 @@ class StreamLoopManager:
        logger.info("Stream loop manager stopped")

    async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
-       """Start the loop task for the given stream
+       """Start the loop task for the given stream - optimized version using the adaptive manager

        Args:
            stream_id: stream ID
@@ -113,6 +115,71 @@ class StreamLoopManager:
            logger.debug(f"Stream {stream_id} loop is already running")
            return True

        # Acquire a processing slot via the adaptive stream manager
        use_adaptive = False
        try:
            from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
            adaptive_manager = get_adaptive_stream_manager()

            if adaptive_manager.is_running:
                # Determine the stream's priority
                priority = self._determine_stream_priority(stream_id)

                # Acquire a processing slot
                slot_acquired = await adaptive_manager.acquire_stream_slot(
                    stream_id=stream_id,
                    priority=priority,
                    force=force
                )

                if slot_acquired:
                    use_adaptive = True
                    logger.debug(f"Acquired stream-processing slot: {stream_id} (priority: {priority.name})")
                else:
                    logger.debug(f"Adaptive manager rejected the slot request: {stream_id}; trying the fallback")
            else:
                logger.debug("Adaptive manager is not running; using the original method")

        except Exception as e:
            logger.debug(f"Adaptive manager slot acquisition failed; using the original method: {e}")

        # If the adaptive manager failed or is not running, use the fallback
        if not use_adaptive:
            if not await self._fallback_acquire_slot(stream_id, force):
                logger.debug(f"Fallback also failed: {stream_id}")
                return False

        # Create the stream-loop task
        try:
            loop_task = asyncio.create_task(
                self._stream_loop_worker(stream_id),
                name=f"stream_loop_{stream_id}"
            )
            self.stream_loops[stream_id] = loop_task
            # Record the manager type
            self.stream_management_type[stream_id] = "adaptive" if use_adaptive else "fallback"

            # Update statistics
            self.stats["active_streams"] += 1
            self.stats["total_loops"] += 1

            logger.info(f"Started stream-loop task: {stream_id} (manager: {'adaptive' if use_adaptive else 'fallback'})")
            return True

        except Exception as e:
            logger.error(f"Failed to start stream-loop task {stream_id}: {e}")
            # Release the slot
            if use_adaptive:
                try:
                    from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
                    adaptive_manager = get_adaptive_stream_manager()
                    adaptive_manager.release_stream_slot(stream_id)
                except Exception:
                    pass
            return False

    async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
        """Fallback: acquire a slot using the original method."""
        # Decide whether dispatch should be forced
        should_force = force or self._should_force_dispatch_for_stream(stream_id)

@@ -149,6 +216,28 @@ class StreamLoopManager:
                del self.stream_loops[stream_id]
                current_streams -= 1  # update the current stream count

        return True

    def _determine_stream_priority(self, stream_id: str) -> "StreamPriority":
        """Determine a stream's priority."""
        try:
            from src.chat.message_manager.adaptive_stream_manager import StreamPriority

            # Priority could be based on stream history, user identity, etc.;
            # this simplified version buckets streams by the hash of their ID.
            hash_value = hash(stream_id) % 10

            if hash_value >= 8:  # 20% high priority
                return StreamPriority.HIGH
            elif hash_value >= 5:  # 30% normal priority
                return StreamPriority.NORMAL
            else:  # 50% low priority
                return StreamPriority.LOW

        except Exception:
            from src.chat.message_manager.adaptive_stream_manager import StreamPriority
            return StreamPriority.NORMAL

        # Create the stream-loop task
        try:
            task = asyncio.create_task(
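One caveat with the hash-based bucketing above: Python's built-in hash() for strings is salted per process (PYTHONHASHSEED), so the priority bucket a given stream_id lands in can change across restarts. A stable sketch, not part of this commit, would hash explicitly:

import hashlib


def stable_bucket(stream_id: str, buckets: int = 10) -> int:
    # md5 is fine here; this is bucketing, not security
    digest = hashlib.md5(stream_id.encode("utf-8")).hexdigest()
    return int(digest, 16) % buckets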
@@ -201,13 +290,13 @@ class StreamLoopManager:
            logger.info(f"Stopped stream loop: {stream_id} (remaining: {len(self.stream_loops)})")
            return True

-   async def _stream_loop(self, stream_id: str) -> None:
-       """Endless loop for a single stream
+   async def _stream_loop_worker(self, stream_id: str) -> None:
+       """Worker loop for a single stream - optimized version

        Args:
            stream_id: stream ID
        """
-       logger.info(f"Stream loop starting: {stream_id}")
+       logger.info(f"Stream-loop worker starting: {stream_id}")

        try:
            while self.is_running:
@@ -223,6 +312,18 @@ class StreamLoopManager:
                unread_count = self._get_unread_count(context)
                force_dispatch = self._needs_force_dispatch_for_context(context, unread_count)

                # 3. Update the adaptive manager's metrics
                try:
                    from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
                    adaptive_manager = get_adaptive_stream_manager()
                    adaptive_manager.update_stream_metrics(
                        stream_id,
                        message_rate=unread_count / 5.0 if unread_count > 0 else 0.0,  # simplified estimate
                        last_activity=time.time()
                    )
                except Exception as e:
                    logger.debug(f"Failed to update stream metrics: {e}")

                has_messages = force_dispatch or await self._has_messages_to_process(context)

                if has_messages:
@@ -278,6 +379,24 @@ class StreamLoopManager:
                del self.stream_loops[stream_id]
                logger.debug(f"Cleared stream-loop marker: {stream_id}")

            # Release the slot held by whichever manager admitted the stream
            management_type = self.stream_management_type.get(stream_id, "fallback")
            if management_type == "adaptive":
                # Release the adaptive manager's slot
                try:
                    from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
                    adaptive_manager = get_adaptive_stream_manager()
                    adaptive_manager.release_stream_slot(stream_id)
                    logger.debug(f"Released adaptive stream-processing slot: {stream_id}")
                except Exception as e:
                    logger.debug(f"Failed to release adaptive stream-processing slot: {e}")
            else:
                logger.debug(f"Stream {stream_id} used the fallback; no adaptive slot to release")

            # Clear the manager-type record
            if stream_id in self.stream_management_type:
                del self.stream_management_type[stream_id]

            logger.info(f"Stream loop finished: {stream_id}")

    async def _get_stream_context(self, stream_id: str) -> Any | None:

@@ -56,6 +56,30 @@ class MessageManager:

        self.is_running = True

        # Start the batch database writer
        try:
            from src.chat.message_manager.batch_database_writer import init_batch_writer
            await init_batch_writer()
            logger.info("📦 Batch database writer started")
        except Exception as e:
            logger.error(f"Failed to start the batch database writer: {e}")

        # Start the stream cache manager
        try:
            from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
            await init_stream_cache_manager()
            logger.info("🗄️ Stream cache manager started")
        except Exception as e:
            logger.error(f"Failed to start the stream cache manager: {e}")

        # Start the adaptive stream manager
        try:
            from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
            await init_adaptive_stream_manager()
            logger.info("🎯 Adaptive stream manager started")
        except Exception as e:
            logger.error(f"Failed to start the adaptive stream manager: {e}")

        # Start the sleep/wakeup manager
        await self.wakeup_manager.start()

@@ -72,6 +96,30 @@ class MessageManager:

        self.is_running = False

        # Stop the batch database writer
        try:
            from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
            await shutdown_batch_writer()
            logger.info("📦 Batch database writer stopped")
        except Exception as e:
            logger.error(f"Failed to stop the batch database writer: {e}")

        # Stop the stream cache manager
        try:
            from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
            await shutdown_stream_cache_manager()
            logger.info("🗄️ Stream cache manager stopped")
        except Exception as e:
            logger.error(f"Failed to stop the stream cache manager: {e}")

        # Stop the adaptive stream manager
        try:
            from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
            await shutdown_adaptive_stream_manager()
            logger.info("🎯 Adaptive stream manager stopped")
        except Exception as e:
            logger.error(f"Failed to stop the adaptive stream manager: {e}")

        # Stop the sleep/wakeup manager
        await self.wakeup_manager.stop()

src/chat/message_manager/stream_cache_manager.py (new file, 381 lines)
@@ -0,0 +1,381 @@
"""
Stream cache manager - optimized chat streams with a tiered caching strategy.

Provides layered caching and automatic cleanup.
"""

import asyncio
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Optional, Set

from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream

logger = get_logger("stream_cache_manager")

@dataclass
class StreamCacheStats:
    """Cache statistics."""
    hot_cache_size: int = 0
    warm_storage_size: int = 0
    cold_storage_size: int = 0
    total_memory_usage: int = 0  # estimated memory usage in bytes
    cache_hits: int = 0
    cache_misses: int = 0
    evictions: int = 0
    last_cleanup_time: float = 0
    hit_rate: float = 0.0  # filled in by get_stats()

class TieredStreamCache:
    """Tiered stream cache manager."""

    def __init__(
        self,
        max_hot_size: int = 100,
        max_warm_size: int = 500,
        max_cold_size: int = 2000,
        cleanup_interval: float = 300.0,  # clean up every 5 minutes
        hot_timeout: float = 1800.0,  # demote to warm after 30 minutes idle
        warm_timeout: float = 7200.0,  # demote to cold after 2 hours idle
        cold_timeout: float = 86400.0,  # delete after 24 hours idle
    ):
        self.max_hot_size = max_hot_size
        self.max_warm_size = max_warm_size
        self.max_cold_size = max_cold_size
        self.cleanup_interval = cleanup_interval
        self.hot_timeout = hot_timeout
        self.warm_timeout = warm_timeout
        self.cold_timeout = cold_timeout

        # Three cache tiers
        self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict()  # hot data (LRU)
        self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}  # warm data (with last-access time)
        self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}  # cold data (with last-access time)

        # Statistics
        self.stats = StreamCacheStats()

        # Cleanup task
        self.cleanup_task: Optional[asyncio.Task] = None
        self.is_running = False

        logger.info(f"Tiered stream cache manager initialized (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")

    async def start(self):
        """Start the cache manager."""
        if self.is_running:
            logger.warning("Cache manager is already running")
            return

        self.is_running = True
        self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
        logger.info("Tiered stream cache manager started")

    async def stop(self):
        """Stop the cache manager."""
        if not self.is_running:
            return

        self.is_running = False

        if self.cleanup_task and not self.cleanup_task.done():
            self.cleanup_task.cancel()
            try:
                await asyncio.wait_for(self.cleanup_task, timeout=10.0)
            except asyncio.CancelledError:
                pass
            except asyncio.TimeoutError:
                logger.warning("Timed out stopping the cache cleanup task")
            except Exception as e:
                logger.error(f"Error while stopping the cache cleanup task: {e}")

        logger.info("Tiered stream cache manager stopped")

    async def get_or_create_stream(
        self,
        stream_id: str,
        platform: str,
        user_info: UserInfo,
        group_info: Optional[GroupInfo] = None,
        data: Optional[Dict] = None,
    ) -> OptimizedChatStream:
        """Get or create a stream - optimized version."""
        current_time = time.time()

        # 1. Check the hot cache
        if stream_id in self.hot_cache:
            stream = self.hot_cache[stream_id]
            # Move to the end (LRU refresh)
            self.hot_cache.move_to_end(stream_id)
            self.stats.cache_hits += 1
            logger.debug(f"Hot cache hit: {stream_id}")
            return stream.create_snapshot()

        # 2. Check warm storage
        if stream_id in self.warm_storage:
            stream, last_access = self.warm_storage[stream_id]
            self.warm_storage[stream_id] = (stream, current_time)
            self.stats.cache_hits += 1
            logger.debug(f"Warm cache hit: {stream_id}")
            # Promote to the hot cache
            await self._promote_to_hot(stream_id, stream)
            return stream.create_snapshot()

        # 3. Check cold storage
        if stream_id in self.cold_storage:
            stream, last_access = self.cold_storage[stream_id]
            self.cold_storage[stream_id] = (stream, current_time)
            self.stats.cache_hits += 1
            logger.debug(f"Cold cache hit: {stream_id}")
            # Promote to warm storage
            await self._promote_to_warm(stream_id, stream)
            return stream.create_snapshot()

        # 4. Cache miss: create a new stream
        self.stats.cache_misses += 1
        stream = create_optimized_chat_stream(
            stream_id=stream_id,
            platform=platform,
            user_info=user_info,
            group_info=group_info,
            data=data
        )
        logger.debug(f"Cache miss; created new stream: {stream_id}")

        # Add it to the hot cache; note that, unlike cache hits, a newly
        # created stream is returned directly rather than as a snapshot
        await self._add_to_hot(stream_id, stream)

        return stream

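    # Usage sketch (hypothetical caller; user_info would come from the
    # incoming message):
    #
    #     cache = get_stream_cache_manager()
    #     stream = await cache.get_or_create_stream(
    #         stream_id="qq:12345", platform="qq", user_info=user_info
    #     )
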
    async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
        """Add a stream to the hot cache."""
        # Evict first if the hot cache is full
        if len(self.hot_cache) >= self.max_hot_size:
            await self._evict_from_hot()

        self.hot_cache[stream_id] = stream
        self.stats.hot_cache_size = len(self.hot_cache)

    async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
        """Promote a stream to the hot cache."""
        # Remove it from warm storage
        if stream_id in self.warm_storage:
            del self.warm_storage[stream_id]
            self.stats.warm_storage_size = len(self.warm_storage)

        # Add it to the hot cache
        await self._add_to_hot(stream_id, stream)
        logger.debug(f"Stream {stream_id} promoted to the hot cache")

    async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
        """Promote a stream to warm storage."""
        # Remove it from cold storage
        if stream_id in self.cold_storage:
            del self.cold_storage[stream_id]
            self.stats.cold_storage_size = len(self.cold_storage)

        # Add it to warm storage, evicting first if full
        if len(self.warm_storage) >= self.max_warm_size:
            await self._evict_from_warm()

        current_time = time.time()
        self.warm_storage[stream_id] = (stream, current_time)
        self.stats.warm_storage_size = len(self.warm_storage)
        logger.debug(f"Stream {stream_id} promoted to warm storage")

    async def _evict_from_hot(self):
        """Evict the least recently used stream from the hot cache."""
        if not self.hot_cache:
            return

        # LRU eviction
        stream_id, stream = self.hot_cache.popitem(last=False)
        self.stats.evictions += 1
        logger.debug(f"Evicted from the hot cache: {stream_id}")

        # Demote to warm storage
        if len(self.warm_storage) < self.max_warm_size:
            current_time = time.time()
            self.warm_storage[stream_id] = (stream, current_time)
            self.stats.warm_storage_size = len(self.warm_storage)
        else:
            # Warm storage is also full; drop the stream
            logger.debug(f"Warm storage is full; dropping stream: {stream_id}")

        self.stats.hot_cache_size = len(self.hot_cache)

    async def _evict_from_warm(self):
        """Evict the least recently used stream from warm storage."""
        if not self.warm_storage:
            return

        # Find the least recently accessed stream
        oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
        stream, last_access = self.warm_storage.pop(oldest_stream_id)
        self.stats.evictions += 1
        logger.debug(f"Evicted from warm storage: {oldest_stream_id}")

        # Demote to cold storage
        if len(self.cold_storage) < self.max_cold_size:
            current_time = time.time()
            self.cold_storage[oldest_stream_id] = (stream, current_time)
            self.stats.cold_storage_size = len(self.cold_storage)
        else:
            # Cold storage is also full; drop the stream
            logger.debug(f"Cold storage is full; dropping stream: {oldest_stream_id}")

        self.stats.warm_storage_size = len(self.warm_storage)

    async def _cleanup_loop(self):
        """Cleanup loop."""
        logger.info("Stream cache cleanup loop started")

        while self.is_running:
            try:
                await asyncio.sleep(self.cleanup_interval)
                await self._perform_cleanup()
            except asyncio.CancelledError:
                logger.info("Stream cache cleanup loop cancelled")
                break
            except Exception as e:
                logger.error(f"Stream cache cleanup error: {e}")

        logger.info("Stream cache cleanup loop finished")

    async def _perform_cleanup(self):
        """Run one cleanup pass."""
        current_time = time.time()
        cleanup_stats = {
            "hot_to_warm": 0,
            "warm_to_cold": 0,
            "cold_removed": 0,
        }

        # 1. Demote hot-cache entries that have timed out
        hot_to_demote = []
        for stream_id, stream in self.hot_cache.items():
            # Approximate the last access with the stream's own activity time
            last_access = getattr(stream, 'last_active_time', stream.create_time)
            if current_time - last_access > self.hot_timeout:
                hot_to_demote.append(stream_id)

        for stream_id in hot_to_demote:
            stream = self.hot_cache.pop(stream_id)
            self.warm_storage[stream_id] = (stream, time.time())
            cleanup_stats["hot_to_warm"] += 1

        # 2. Demote warm-storage entries that have timed out
        warm_to_demote = []
        for stream_id, (stream, last_access) in self.warm_storage.items():
            if current_time - last_access > self.warm_timeout:
                warm_to_demote.append(stream_id)

        for stream_id in warm_to_demote:
            stream, last_access = self.warm_storage.pop(stream_id)
            self.cold_storage[stream_id] = (stream, last_access)
            cleanup_stats["warm_to_cold"] += 1

        # 3. Remove cold-storage entries that have timed out
        cold_to_remove = []
        for stream_id, (stream, last_access) in self.cold_storage.items():
            if current_time - last_access > self.cold_timeout:
                cold_to_remove.append(stream_id)

        for stream_id in cold_to_remove:
            self.cold_storage.pop(stream_id)
            cleanup_stats["cold_removed"] += 1

        # Update statistics
        self.stats.hot_cache_size = len(self.hot_cache)
        self.stats.warm_storage_size = len(self.warm_storage)
        self.stats.cold_storage_size = len(self.cold_storage)
        self.stats.last_cleanup_time = current_time

        # Rough memory estimate
        self.stats.total_memory_usage = (
            len(self.hot_cache) * 1024 +  # ~1 KB per hot stream
            len(self.warm_storage) * 512 +  # ~512 B per warm stream
            len(self.cold_storage) * 256  # ~256 B per cold stream
        )

        if sum(cleanup_stats.values()) > 0:
            logger.info(
                f"Cache cleanup finished: {cleanup_stats['hot_to_warm']} hot→warm, "
                f"{cleanup_stats['warm_to_cold']} warm→cold, "
                f"{cleanup_stats['cold_removed']} cold removed"
            )

    def get_stats(self) -> StreamCacheStats:
        """Return cache statistics."""
        # Hit rate
        total_requests = self.stats.cache_hits + self.stats.cache_misses
        hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0

        return StreamCacheStats(
            hot_cache_size=self.stats.hot_cache_size,
            warm_storage_size=self.stats.warm_storage_size,
            cold_storage_size=self.stats.cold_storage_size,
            total_memory_usage=self.stats.total_memory_usage,
            cache_hits=self.stats.cache_hits,
            cache_misses=self.stats.cache_misses,
            evictions=self.stats.evictions,
            last_cleanup_time=self.stats.last_cleanup_time,
            hit_rate=hit_rate,
        )

    def clear_cache(self):
        """Clear all cache tiers."""
        self.hot_cache.clear()
        self.warm_storage.clear()
        self.cold_storage.clear()

        self.stats.hot_cache_size = 0
        self.stats.warm_storage_size = 0
        self.stats.cold_storage_size = 0
        self.stats.total_memory_usage = 0

        logger.info("All caches cleared")

    async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
        """Return a snapshot of a stream without touching cache state."""
        if stream_id in self.hot_cache:
            return self.hot_cache[stream_id].create_snapshot()
        elif stream_id in self.warm_storage:
            return self.warm_storage[stream_id][0].create_snapshot()
        elif stream_id in self.cold_storage:
            return self.cold_storage[stream_id][0].create_snapshot()
        return None

    def get_cached_stream_ids(self) -> Set[str]:
        """Return the IDs of all cached streams."""
        return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())

# Global cache manager instance
_cache_manager: Optional[TieredStreamCache] = None


def get_stream_cache_manager() -> TieredStreamCache:
    """Return the stream cache manager singleton."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = TieredStreamCache()
    return _cache_manager


async def init_stream_cache_manager():
    """Initialize the stream cache manager."""
    manager = get_stream_cache_manager()
    await manager.start()


async def shutdown_stream_cache_manager():
    """Shut down the stream cache manager."""
    manager = get_stream_cache_manager()
    await manager.stop()
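The hot tier's LRU behaviour rests entirely on two OrderedDict operations used above: move_to_end on a hit and popitem(last=False) on eviction. A standalone sketch of that mechanic:

from collections import OrderedDict

lru: OrderedDict[str, str] = OrderedDict()
lru["a"] = "stream-a"
lru["b"] = "stream-b"
lru["c"] = "stream-c"

lru.move_to_end("a")                 # "a" was just accessed, so it becomes newest
victim, _ = lru.popitem(last=False)  # evicts the oldest remaining entry
print(victim)                        # -> "b"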