refactor(database): migrate synchronous database operations to async

Migrate the project's database operations from synchronous to asynchronous mode. The main changes are:

- Replace `with get_db_session()` with `async with get_db_session()`
- Switch synchronous SQLAlchemy query calls to asynchronous execution
- Update the affected method signatures with the async/await keywords
- Fix concurrency and performance issues introduced by the async conversion

These changes improve the concurrency of database operations, avoid blocking the main thread, and raise the system's overall responsiveness. Affected modules include emoji management, anti-prompt-injection statistics, user ban management, the memory system, message storage, and other core components. A minimal before/after sketch of the pattern follows the BREAKING CHANGE note below.

BREAKING CHANGE: every method that performs database operations must now be awaited; synchronous calls no longer work.
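
A minimal before/after sketch of the migration pattern described above. It assumes SQLAlchemy 2.x async sessions (create_async_engine + async_sessionmaker); the get_db_session() helper and the BanRecord model shown here are illustrative stand-ins, not the project's actual code.

# Sketch only: stand-in engine, session factory, and model for illustration.
from contextlib import asynccontextmanager

from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

engine = create_async_engine("sqlite+aiosqlite:///bot.db")
SessionFactory = async_sessionmaker(engine, expire_on_commit=False)


class Base(DeclarativeBase):
    pass


class BanRecord(Base):  # hypothetical model, for illustration only
    __tablename__ = "ban_records"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[str] = mapped_column(index=True)


@asynccontextmanager
async def get_db_session():
    """Async stand-in for the old synchronous context manager."""
    async with SessionFactory() as session:
        async with session.begin():
            yield session


# Before (synchronous; blocks the thread running the event loop):
#     with get_db_session() as session:
#         record = session.query(BanRecord).filter_by(user_id=user_id).first()
#
# After (asynchronous; the query is executed with await):
async def get_ban_record(user_id: str) -> BanRecord | None:
    async with get_db_session() as session:
        result = await session.execute(
            select(BanRecord).where(BanRecord.user_id == user_id)
        )
        return result.scalar_one_or_none()

As the BREAKING CHANGE note says, every caller up the stack must itself become async and use `await get_ban_record(...)` instead of a direct call.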
Author: Windpicker-owo
Date:   2025-09-28 15:42:30 +08:00
Parent: 7508d542f2
Commit: 9836d317b8
35 changed files with 1337 additions and 1068 deletions

View File

@@ -4,7 +4,7 @@
"""
from .message_manager import MessageManager, message_manager
from .context_manager import StreamContextManager, context_manager
from .context_manager import SingleStreamContextManager
from .distribution_manager import (
DistributionManager,
DistributionPriority,
@@ -16,8 +16,7 @@ from .distribution_manager import (
__all__ = [
"MessageManager",
"message_manager",
"StreamContextManager",
"context_manager",
"SingleStreamContextManager",
"DistributionManager",
"DistributionPriority",
"DistributionTask",

View File

@@ -1,12 +1,12 @@
"""
重构后的聊天上下文管理器
提供统一、稳定的聊天上下文管理功能
每个 context_manager 实例只管理一个 stream 的上下文
"""
import asyncio
import time
from typing import Dict, List, Optional, Any, Union, Tuple
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
@@ -17,241 +17,112 @@ from .distribution_manager import distribution_manager
logger = get_logger("context_manager")
class StreamContextManager:
"""流上下文管理器 - 统一管理所有聊天流上下文"""
def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None):
# 上下文存储
self.stream_contexts: Dict[str, Any] = {}
self.context_metadata: Dict[str, Dict[str, Any]] = {}
class SingleStreamContextManager:
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
# 统计信息
self.stats: Dict[str, Union[int, float, str, Dict]] = {
"total_messages": 0,
"total_streams": 0,
"active_streams": 0,
"inactive_streams": 0,
"last_activity": time.time(),
"creation_time": time.time(),
}
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
self.stream_id = stream_id
self.context = context
# 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时
self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True)
self.enable_validation = getattr(global_config.chat, "enable_context_validation", True)
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
# 清理任务
self.cleanup_task: Optional[Any] = None
self.is_running = False
# 元数据
self.created_time = time.time()
self.last_access_time = time.time()
self.access_count = 0
self.total_messages = 0
logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)")
logger.debug(f"单流上下文管理器初始化: {stream_id}")
def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool:
"""添加流上下文
def get_context(self) -> StreamContext:
"""获取流上下文"""
self._update_access_stats()
return self.context
Args:
stream_id: 流ID
context: 上下文对象
metadata: 上下文元数据
Returns:
bool: 是否成功添加
"""
if stream_id in self.stream_contexts:
logger.warning(f"流上下文已存在: {stream_id}")
return False
# 添加上下文
self.stream_contexts[stream_id] = context
# 初始化元数据
self.context_metadata[stream_id] = {
"created_time": time.time(),
"last_access_time": time.time(),
"access_count": 0,
"last_validation_time": 0.0,
"custom_metadata": metadata or {},
}
# 更新统计
self.stats["total_streams"] += 1
self.stats["active_streams"] += 1
self.stats["last_activity"] = time.time()
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
return True
def remove_stream_context(self, stream_id: str) -> bool:
"""移除流上下文
Args:
stream_id: 流ID
Returns:
bool: 是否成功移除
"""
if stream_id in self.stream_contexts:
context = self.stream_contexts[stream_id]
metadata = self.context_metadata.get(stream_id, {})
del self.stream_contexts[stream_id]
if stream_id in self.context_metadata:
del self.context_metadata[stream_id]
self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)
self.stats["inactive_streams"] += 1
self.stats["last_activity"] = time.time()
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
return True
return False
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
"""获取流上下文
Args:
stream_id: 流ID
update_access: 是否更新访问统计
Returns:
Optional[Any]: 上下文对象
"""
context = self.stream_contexts.get(stream_id)
if context and update_access:
# 更新访问统计
if stream_id in self.context_metadata:
metadata = self.context_metadata[stream_id]
metadata["last_access_time"] = time.time()
metadata["access_count"] = metadata.get("access_count", 0) + 1
return context
def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取上下文元数据
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 元数据
"""
return self.context_metadata.get(stream_id)
def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文元数据
Args:
stream_id: 流ID
updates: 更新的元数据
Returns:
bool: 是否成功更新
"""
if stream_id not in self.context_metadata:
return False
self.context_metadata[stream_id].update(updates)
return True
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文
Args:
stream_id: 流ID
message: 消息对象
skip_energy_update: 是否跳过能量更新
Returns:
bool: 是否成功添加
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try:
# 添加消息到上下文
context.add_message(message)
self.context.add_message(message)
# 计算消息兴趣度
interest_value = self._calculate_message_interest(message)
message.interest_value = interest_value
# 更新统计
self.stats["total_messages"] += 1
self.stats["last_activity"] = time.time()
self.total_messages += 1
self.last_access_time = time.time()
# 更新能量和分发
if not skip_energy_update:
self._update_stream_energy(stream_id)
distribution_manager.add_stream_message(stream_id, 1)
self._update_stream_energy()
distribution_manager.add_stream_message(self.stream_id, 1)
logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})")
logger.debug(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
return True
except Exception as e:
logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True)
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool:
def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
"""更新上下文中的消息
Args:
stream_id: 流ID
message_id: 消息ID
updates: 更新的属性
Returns:
bool: 是否成功更新
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
try:
# 更新消息信息
context.update_message_info(message_id, **updates)
self.context.update_message_info(message_id, **updates)
# 如果更新了兴趣度,重新计算能量
if "interest_value" in updates:
self._update_stream_energy(stream_id)
self._update_stream_energy()
logger.debug(f"更新上下文消息: {stream_id}/{message_id}")
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
"""获取上下文消息
Args:
stream_id: 流ID
limit: 消息数量限制
include_unread: 是否包含未读消息
Returns:
List[Any]: 消息列表
List[DatabaseMessages]: 消息列表
"""
context = self.get_stream_context(stream_id)
if not context:
return []
try:
messages = []
if include_unread:
messages.extend(context.get_unread_messages())
messages.extend(self.context.get_unread_messages())
if limit:
messages.extend(context.get_history_messages(limit=limit))
messages.extend(self.context.get_history_messages(limit=limit))
else:
messages.extend(context.get_history_messages())
messages.extend(self.context.get_history_messages())
# 按时间排序
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
messages.sort(key=lambda msg: getattr(msg, "time", 0))
# 应用限制
if limit and len(messages) > limit:
@@ -260,103 +131,124 @@ class StreamContextManager:
return messages
except Exception as e:
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
"""获取未读消息
Args:
stream_id: 流ID
Returns:
List[Any]: 未读消息列表
"""
context = self.get_stream_context(stream_id)
if not context:
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self) -> List[DatabaseMessages]:
"""获取未读消息"""
try:
return context.get_unread_messages()
return self.context.get_unread_messages()
except Exception as e:
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool:
"""标记消息为已读
Args:
stream_id: 流ID
message_ids: 消息ID列表
Returns:
bool: 是否成功标记
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
return False
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
"""标记消息为已读"""
try:
if not hasattr(context, 'mark_message_as_read'):
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}")
if not hasattr(self.context, "mark_message_as_read"):
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}")
return False
marked_count = 0
for message_id in message_ids:
try:
context.mark_message_as_read(message_id)
self.context.mark_message_as_read(message_id)
marked_count += 1
except Exception as e:
logger.warning(f"标记消息已读失败 {message_id}: {e}")
logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)")
logger.debug(f"标记消息为已读: {self.stream_id} ({marked_count}/{len(message_ids)}条)")
return marked_count > 0
except Exception as e:
logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True)
return False
def clear_context(self, stream_id: str) -> bool:
"""清空上下文
Args:
stream_id: 流ID
Returns:
bool: 是否成功清空
"""
context = self.get_stream_context(stream_id)
if not context:
logger.warning(f"流上下文不存在: {stream_id}")
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False
def clear_context(self) -> bool:
"""清空上下文"""
try:
# 清空消息
if hasattr(context, 'unread_messages'):
context.unread_messages.clear()
if hasattr(context, 'history_messages'):
context.history_messages.clear()
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
# 重置状态
reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time']
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(context, attr):
if attr in ['interruption_count', 'afc_threshold_adjustment']:
setattr(context, attr, 0)
if hasattr(self.context, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self.context, attr, 0)
else:
setattr(context, attr, time.time())
setattr(self.context, attr, time.time())
# 重新计算能量
self._update_stream_energy(stream_id)
self._update_stream_energy()
logger.info(f"清空上下文: {stream_id}")
logger.info(f"清空单流上下文: {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def get_statistics(self) -> Dict[str, Any]:
"""获取流统计信息"""
try:
current_time = time.time()
uptime = current_time - self.created_time
unread_messages = getattr(self.context, "unread_messages", [])
history_messages = getattr(self.context, "history_messages", [])
return {
"stream_id": self.stream_id,
"context_type": type(self.context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(self.context, "is_active", True),
"last_check_time": getattr(self.context, "last_check_time", current_time),
"interruption_count": getattr(self.context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0),
"created_time": self.created_time,
"last_access_time": self.last_access_time,
"access_count": self.access_count,
"uptime_seconds": uptime,
"idle_seconds": current_time - self.last_access_time,
}
except Exception as e:
logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True)
return {}
def validate_integrity(self) -> bool:
"""验证上下文完整性"""
try:
# 检查基本属性
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(self.context, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
# 检查消息ID唯一性
all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", [])
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复消息ID: {self.stream_id}")
return False
return True
except Exception as e:
logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}")
return False
def _update_access_stats(self):
"""更新访问统计"""
self.last_access_time = time.time()
self.access_count += 1
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""计算消息兴趣度"""
try:
@@ -373,8 +265,7 @@ class StreamContextManager:
interest_score = loop.run_until_complete(
chatter_interest_scoring_system._calculate_single_message_score(
message=message,
bot_nickname=global_config.bot.nickname
message=message, bot_nickname=global_config.bot.nickname
)
)
interest_value = interest_score.total_score
@@ -391,12 +282,12 @@ class StreamContextManager:
logger.error(f"计算消息兴趣度失败: {e}")
return 0.5
def _update_stream_energy(self, stream_id: str):
async def _update_stream_energy(self):
"""更新流能量"""
try:
# 获取所有消息
all_messages = self.get_context_messages(stream_id, self.max_context_size)
unread_messages = self.get_unread_messages(stream_id)
all_messages = self.get_messages(self.max_context_size)
unread_messages = self.get_unread_messages()
combined_messages = all_messages + unread_messages
# 获取用户ID
@@ -406,248 +297,12 @@ class StreamContextManager:
user_id = last_message.user_info.user_id
# 计算能量
energy = energy_manager.calculate_focus_energy(
stream_id=stream_id,
messages=combined_messages,
user_id=user_id
energy = await energy_manager.calculate_focus_energy(
stream_id=self.stream_id, messages=combined_messages, user_id=user_id
)
# 更新分发管理器
distribution_manager.update_stream_energy(stream_id, energy)
distribution_manager.update_stream_energy(self.stream_id, energy)
except Exception as e:
logger.error(f"更新流能量失败 {stream_id}: {e}")
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取流统计信息
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 统计信息
"""
context = self.get_stream_context(stream_id, update_access=False)
if not context:
return None
try:
metadata = self.context_metadata.get(stream_id, {})
current_time = time.time()
created_time = metadata.get("created_time", current_time)
last_access_time = metadata.get("last_access_time", current_time)
access_count = metadata.get("access_count", 0)
unread_messages = getattr(context, "unread_messages", [])
history_messages = getattr(context, "history_messages", [])
return {
"stream_id": stream_id,
"context_type": type(context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(context, "is_active", True),
"last_check_time": getattr(context, "last_check_time", current_time),
"interruption_count": getattr(context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0),
"created_time": created_time,
"last_access_time": last_access_time,
"access_count": access_count,
"uptime_seconds": current_time - created_time,
"idle_seconds": current_time - last_access_time,
}
except Exception as e:
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
return None
def get_manager_statistics(self) -> Dict[str, Any]:
"""获取管理器统计信息
Returns:
Dict[str, Any]: 管理器统计信息
"""
current_time = time.time()
uptime = current_time - self.stats.get("creation_time", current_time)
return {
**self.stats,
"uptime_hours": uptime / 3600,
"stream_count": len(self.stream_contexts),
"metadata_count": len(self.context_metadata),
"auto_cleanup_enabled": self.auto_cleanup,
"cleanup_interval": self.cleanup_interval,
}
def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int:
"""清理不活跃的上下文
Args:
max_inactive_hours: 最大不活跃小时数
Returns:
int: 清理的上下文数量
"""
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, context in self.stream_contexts.items():
try:
# 获取最后活动时间
metadata = self.context_metadata.get(stream_id, {})
last_activity = metadata.get("last_access_time", metadata.get("created_time", 0))
context_last_activity = getattr(context, "last_check_time", 0)
actual_last_activity = max(last_activity, context_last_activity)
# 检查是否不活跃
unread_count = len(getattr(context, "unread_messages", []))
history_count = len(getattr(context, "history_messages", []))
total_messages = unread_count + history_count
if (current_time - actual_last_activity > max_inactive_seconds and
total_messages == 0):
inactive_streams.append(stream_id)
except Exception as e:
logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}")
continue
# 清理不活跃上下文
cleaned_count = 0
for stream_id in inactive_streams:
if self.remove_stream_context(stream_id):
cleaned_count += 1
if cleaned_count > 0:
logger.info(f"清理了 {cleaned_count} 个不活跃上下文")
return cleaned_count
def validate_context_integrity(self, stream_id: str) -> bool:
"""验证上下文完整性
Args:
stream_id: 流ID
Returns:
bool: 是否完整
"""
context = self.get_stream_context(stream_id)
if not context:
return False
try:
# 检查基本属性
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(context, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
# 检查消息ID唯一性
all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", [])
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复消息ID: {stream_id}")
return False
return True
except Exception as e:
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
return False
async def start(self) -> None:
"""启动上下文管理器"""
if self.is_running:
logger.warning("上下文管理器已经在运行")
return
await self.start_auto_cleanup()
logger.info("上下文管理器已启动")
async def stop(self) -> None:
"""停止上下文管理器"""
if not self.is_running:
return
await self.stop_auto_cleanup()
logger.info("上下文管理器已停止")
async def start_auto_cleanup(self, interval: Optional[float] = None) -> None:
"""启动自动清理
Args:
interval: 清理间隔(秒)
"""
if not self.auto_cleanup:
logger.info("自动清理已禁用")
return
if self.is_running:
logger.warning("自动清理已在运行")
return
self.is_running = True
cleanup_interval = interval or self.cleanup_interval
logger.info(f"启动自动清理(间隔: {cleanup_interval}s")
import asyncio
self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval))
async def stop_auto_cleanup(self) -> None:
"""停止自动清理"""
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await self.cleanup_task
except Exception:
pass
logger.info("自动清理已停止")
async def _cleanup_loop(self, interval: float) -> None:
"""清理循环
Args:
interval: 清理间隔
"""
while self.is_running:
try:
await asyncio.sleep(interval)
self.cleanup_inactive_contexts()
self._cleanup_expired_contexts()
logger.debug("自动清理完成")
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理循环出错: {e}", exc_info=True)
await asyncio.sleep(interval)
def _cleanup_expired_contexts(self) -> None:
"""清理过期上下文"""
current_time = time.time()
expired_contexts = []
for stream_id, metadata in self.context_metadata.items():
created_time = metadata.get("created_time", current_time)
if current_time - created_time > self.context_ttl:
expired_contexts.append(stream_id)
for stream_id in expired_contexts:
self.remove_stream_context(stream_id)
if expired_contexts:
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
def get_active_streams(self) -> List[str]:
"""获取活跃流列表
Returns:
List[str]: 活跃流ID列表
"""
return list(self.stream_contexts.keys())
# 全局上下文管理器实例
context_manager = StreamContextManager()
logger.error(f"更新流能量失败 {self.stream_id}: {e}")

View File

@@ -17,7 +17,7 @@ from src.chat.planner_actions.action_manager import ChatterActionManager
from .sleep_manager.sleep_manager import SleepManager
from .sleep_manager.wakeup_manager import WakeUpManager
from src.config.config import global_config
from .context_manager import context_manager
from src.plugin_system.apis.chat_api import get_chat_manager
if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
@@ -44,8 +44,7 @@ class MessageManager:
self.sleep_manager = SleepManager()
self.wakeup_manager = WakeUpManager(self.sleep_manager)
# 初始化上下文管理器
self.context_manager = context_manager
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager
async def start(self):
"""启动消息管理器"""
@@ -56,7 +55,7 @@ class MessageManager:
self.is_running = True
self.manager_task = asyncio.create_task(self._manager_loop())
await self.wakeup_manager.start()
await self.context_manager.start()
# await self.context_manager.start() # 已删除,需要重构
logger.info("消息管理器已启动")
async def stop(self):
@@ -72,28 +71,31 @@ class MessageManager:
self.manager_task.cancel()
await self.wakeup_manager.stop()
await self.context_manager.stop()
# await self.context_manager.stop() # 已删除,需要重构
logger.info("消息管理器已停止")
def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流"""
# 检查流上下文是否存在,不存在则创建
context = self.context_manager.get_stream_context(stream_id)
if not context:
# 创建新的流上下文
from src.common.data_models.message_manager_data_model import StreamContext
context = StreamContext(stream_id=stream_id)
# 将创建的上下文添加到管理器
self.context_manager.add_stream_context(stream_id, context)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
# 使用 context_manager 添加消息
success = self.context_manager.add_message_to_context(stream_id, message)
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
# 使用 ChatStream 的 context_manager 添加消息
success = chat_stream.context_manager.add_message(message)
if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
def update_message(
self,
@@ -104,17 +106,60 @@ class MessageManager:
should_reply: bool = None,
):
"""更新消息信息"""
# 使用 context_manager 更新消息信息
context = self.context_manager.get_stream_context(stream_id)
if context:
context.update_message_info(message_id, interest_value, actions, should_reply)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return
# 构建更新字典
updates = {}
if interest_value is not None:
updates["interest_value"] = interest_value
if actions is not None:
updates["actions"] = actions
if should_reply is not None:
updates["should_reply"] = should_reply
# 使用 ChatStream 的 context_manager 更新消息
if updates:
success = chat_stream.context_manager.update_message(message_id, updates)
if success:
logger.debug(f"更新消息 {message_id} 成功")
else:
logger.warning(f"更新消息 {message_id} 失败")
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
# 使用 context_manager 添加动作到消息
context = self.context_manager.get_stream_context(stream_id)
if context:
context.add_action_to_message(message_id, action)
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
# 使用 ChatStream 的 context_manager 添加动作
# 注意:这里需要根据实际的 API 调整
# 假设我们可以通过 update_message 来添加动作
success = chat_stream.context_manager.update_message(
message_id, {"actions": [action]}
)
if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else:
logger.warning(f"为消息 {message_id} 添加动作 {action} 失败")
except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
async def _manager_loop(self):
"""管理器主循环 - 独立聊天流分发周期版本"""
@@ -144,38 +189,53 @@ class MessageManager:
active_streams = 0
total_unread = 0
# 使用 context_manager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams()
# 通过 ChatManager 获取所有活跃的流
try:
chat_manager = get_chat_manager()
active_stream_ids = list(chat_manager.streams.keys())
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context:
continue
for stream_id in active_stream_ids:
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
continue
active_streams += 1
# 检查流是否活跃
context = chat_stream.stream_context
if not context.is_active:
continue
# 检查是否有未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
if unread_messages:
total_unread += len(unread_messages)
active_streams += 1
# 如果没有处理任务,创建一个
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 检查是否有未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages:
total_unread += len(unread_messages)
# 更新统计
self.stats.active_streams = active_streams
self.stats.total_unread_messages = total_unread
# 如果没有处理任务,创建一个
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 更新统计
self.stats.active_streams = active_streams
self.stats.total_unread_messages = total_unread
except Exception as e:
logger.error(f"检查所有聊天流时发生错误: {e}")
async def _process_stream_messages(self, stream_id: str):
"""处理指定聊天流的消息"""
context = self.context_manager.get_stream_context(stream_id)
if not context:
return
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"处理消息失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
# 获取未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages:
return
@@ -249,8 +309,15 @@ class MessageManager:
def deactivate_stream(self, stream_id: str):
"""停用聊天流"""
context = self.context_manager.get_stream_context(stream_id)
if context:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context.is_active = False
# 取消处理任务
@@ -259,27 +326,50 @@ class MessageManager:
logger.info(f"停用聊天流: {stream_id}")
except Exception as e:
logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}")
def activate_stream(self, stream_id: str):
"""激活聊天流"""
context = self.context_manager.get_stream_context(stream_id)
if context:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context.is_active = True
logger.info(f"激活聊天流: {stream_id}")
except Exception as e:
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
"""获取聊天流统计"""
context = self.context_manager.get_stream_context(stream_id)
if not context:
return None
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
return None
return StreamStats(
stream_id=stream_id,
is_active=context.is_active,
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
)
context = chat_stream.stream_context
unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats(
stream_id=stream_id,
is_active=context.is_active,
unread_count=unread_count,
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
)
except Exception as e:
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
return None
def get_manager_stats(self) -> Dict[str, Any]:
"""获取管理器统计"""
@@ -294,9 +384,36 @@ class MessageManager:
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
"""清理不活跃的聊天流"""
# 使用 context_manager 的自动清理功能
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
logger.info("已启动不活跃聊天流清理")
try:
# 通过 ChatManager 清理不活跃的流
chat_manager = get_chat_manager()
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, chat_stream in chat_manager.streams.items():
# 检查最后活跃时间
if current_time - chat_stream.last_active_time > max_inactive_seconds:
inactive_streams.append(stream_id)
# 清理不活跃的流
for stream_id in inactive_streams:
try:
# 清理流的内容
chat_stream.context_manager.clear_context()
# 从 ChatManager 中移除
del chat_manager.streams[stream_id]
logger.info(f"清理不活跃聊天流: {stream_id}")
except Exception as e:
logger.error(f"清理聊天流 {stream_id} 失败: {e}")
if inactive_streams:
logger.info(f"已清理 {len(inactive_streams)} 个不活跃聊天流")
else:
logger.debug("没有需要清理的不活跃聊天流")
except Exception as e:
logger.error(f"清理不活跃聊天流时发生错误: {e}")
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
"""检查并处理消息打断"""
@@ -375,116 +492,124 @@ class MessageManager:
min_delay = float("inf")
# 找到最近需要检查的流
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active:
continue
try:
chat_manager = get_chat_manager()
for _stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active:
continue
time_until_check = context.next_check_time - current_time
if time_until_check > 0:
min_delay = min(min_delay, time_until_check)
else:
min_delay = 0.1 # 立即检查
break
time_until_check = context.next_check_time - current_time
if time_until_check > 0:
min_delay = min(min_delay, time_until_check)
else:
min_delay = 0.1 # 立即检查
break
# 如果没有活跃流,使用默认间隔
if min_delay == float("inf"):
# 如果没有活跃流,使用默认间隔
if min_delay == float("inf"):
return self.check_interval
# 确保最小延迟
return max(0.1, min(min_delay, self.check_interval))
except Exception as e:
logger.error(f"计算下次检查延迟时发生错误: {e}")
return self.check_interval
# 确保最小延迟
return max(0.1, min(min_delay, self.check_interval))
async def _check_streams_with_individual_intervals(self):
"""检查所有达到检查时间的聊天流"""
current_time = time.time()
processed_streams = 0
# 使用 context_manager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams()
# 通过 ChatManager 获取活跃的流
try:
chat_manager = get_chat_manager()
for stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active:
continue
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active:
continue
# 检查是否达到检查时间
if current_time >= context.next_check_time:
# 更新检查时间
context.last_check_time = current_time
# 检查是否达到检查时间
if current_time >= context.next_check_time:
# 更新检查时间
context.last_check_time = current_time
# 计算下次检查时间和分发周期
if global_config.chat.dynamic_distribution_enabled:
context.distribution_interval = self._calculate_stream_distribution_interval(context)
else:
context.distribution_interval = self.check_interval
# 计算下次检查时间和分发周期
if global_config.chat.dynamic_distribution_enabled:
context.distribution_interval = self._calculate_stream_distribution_interval(context)
else:
context.distribution_interval = self.check_interval
# 设置下次检查时间
context.next_check_time = current_time + context.distribution_interval
# 设置下次检查时间
context.next_check_time = current_time + context.distribution_interval
# 检查未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages:
processed_streams += 1
self.stats.total_unread_messages = len(unread_messages)
# 检查未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
if unread_messages:
processed_streams += 1
self.stats.total_unread_messages = len(unread_messages)
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
focus_energy = chat_stream.focus_energy
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
from src.plugin_system.apis.chat_api import get_chat_manager
# 根据优先级记录日志
if focus_energy >= 0.7:
logger.info(
f"高优先级流 {stream_id} 开始处理 | "
f"focus_energy: {focus_energy:.3f} | "
f"分发周期: {context.distribution_interval:.2f}s | "
f"未读消息: {len(unread_messages)}"
)
else:
logger.debug(
f"{stream_id} 开始处理 | "
f"focus_energy: {focus_energy:.3f} | "
f"分发周期: {context.distribution_interval:.2f}s"
)
chat_stream = get_chat_manager().get_stream(context.stream_id)
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 根据优先级记录日志
if focus_energy >= 0.7:
logger.info(
f"高优先级流 {stream_id} 开始处理 | "
f"focus_energy: {focus_energy:.3f} | "
f"分发周期: {context.distribution_interval:.2f}s | "
f"未读消息: {len(unread_messages)}"
)
else:
logger.debug(
f"{stream_id} 开始处理 | "
f"focus_energy: {focus_energy:.3f} | "
f"分发周期: {context.distribution_interval:.2f}s"
)
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
except Exception as e:
logger.error(f"检查独立分发周期的聊天流时发生错误: {e}")
# 更新活跃流计数
active_count = len(self.context_manager.get_active_streams())
self.stats.active_streams = active_count
try:
chat_manager = get_chat_manager()
active_count = len([s for s in chat_manager.streams.values() if s.stream_context.is_active])
self.stats.active_streams = active_count
if processed_streams > 0:
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
if processed_streams > 0:
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
except Exception as e:
logger.error(f"更新活跃流计数时发生错误: {e}")
async def _check_all_streams_with_priority(self):
"""按优先级检查所有聊天流高focus_energy的流优先处理"""
if not self.context_manager.get_active_streams():
try:
chat_manager = get_chat_manager()
if not chat_manager.streams:
return
# 获取活跃的聊天流并按focus_energy排序
active_streams = []
for stream_id, chat_stream in chat_manager.streams.items():
context = chat_stream.stream_context
if not context or not context.is_active:
continue
# 获取focus_energy
focus_energy = chat_stream.focus_energy
# 计算流优先级分数
priority_score = self._calculate_stream_priority(context, focus_energy)
active_streams.append((priority_score, stream_id, context))
except Exception as e:
logger.error(f"获取活跃流列表时发生错误: {e}")
return
# 获取活跃的聊天流并按focus_energy排序
active_streams = []
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active:
continue
# 获取focus_energy,如果不存在则使用默认值
from src.plugin_system.apis.chat_api import get_chat_manager
chat_stream = get_chat_manager().get_stream(context.stream_id)
focus_energy = 0.5
if hasattr(context, 'chat_stream') and context.chat_stream:
focus_energy = context.chat_stream.focus_energy
# 计算流优先级分数
priority_score = self._calculate_stream_priority(context, focus_energy)
active_streams.append((priority_score, stream_id, context))
# 按优先级降序排序
active_streams.sort(reverse=True, key=lambda x: x[0])
@@ -496,21 +621,29 @@ class MessageManager:
active_stream_count += 1
# 检查是否有未读消息
unread_messages = self.context_manager.get_unread_messages(stream_id)
if unread_messages:
total_unread += len(unread_messages)
try:
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
continue
# 如果没有处理任务,创建一个
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
unread_messages = chat_stream.context_manager.get_unread_messages()
if unread_messages:
total_unread += len(unread_messages)
# 高优先级流的额外日志
if priority_score > 0.7:
logger.info(
f"高优先级流 {stream_id} 开始处理 | "
f"优先级: {priority_score:.3f} | "
f"未读消息: {len(unread_messages)}"
)
# 如果没有处理任务,创建一个
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 高优先级流的额外日志
if priority_score > 0.7:
logger.info(
f"高优先级流 {stream_id} 开始处理 | "
f"优先级: {priority_score:.3f} | "
f"未读消息: {len(unread_messages)}"
)
except Exception as e:
logger.error(f"处理流 {stream_id} 的未读消息时发生错误: {e}")
continue
# 更新统计
self.stats.active_streams = active_stream_count
@@ -535,22 +668,33 @@ class MessageManager:
def _clear_all_unread_messages(self, stream_id: str):
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
unread_messages = self.context_manager.get_unread_messages(stream_id)
if not unread_messages:
return
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在")
return
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 获取未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
if not unread_messages:
return
# 将所有未读消息标记为已读
context = self.context_manager.get_stream_context(stream_id)
if context:
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
try:
context.mark_message_as_read(msg.message_id)
self.stats.total_processed_messages += 1
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
except Exception as e:
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 将所有未读消息标记为已读
message_ids = [msg.message_id for msg in unread_messages]
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
if success:
self.stats.total_processed_messages += len(unread_messages)
logger.debug(f"强制清除 {len(unread_messages)} 条消息,标记为已读")
else:
logger.error("标记未读消息为已读失败")
except Exception as e:
logger.error(f"清除未读消息时发生错误: {e}")
# 创建全局消息管理器实例