diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index c8bd18a08..7b67424f9 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -3,13 +3,11 @@ 提供统一的消息管理、上下文管理和流循环调度功能 """ -from .context_manager import SingleStreamContextManager from .distribution_manager import StreamLoopManager, stream_loop_manager from .message_manager import MessageManager, message_manager __all__ = [ "MessageManager", - "SingleStreamContextManager", "StreamLoopManager", "message_manager", "stream_loop_manager", diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py deleted file mode 100644 index b26e660b4..000000000 --- a/src/chat/message_manager/context_manager.py +++ /dev/null @@ -1,529 +0,0 @@ -""" -重构后的聊天上下文管理器 -提供统一、稳定的聊天上下文管理功能 -每个 context_manager 实例只管理一个 stream 的上下文 -""" - -import asyncio -import time -from typing import TYPE_CHECKING, Any - -from src.chat.energy_system import energy_manager -from src.common.data_models.database_data_model import DatabaseMessages -from src.common.logger import get_logger -from src.config.config import global_config -from src.plugin_system.base.component_types import ChatType - -if TYPE_CHECKING: - from src.common.data_models.message_manager_data_model import StreamContext - -logger = get_logger("context_manager") - -# 全局背景任务集合(用于异步初始化等后台任务) -_background_tasks = set() - -# 三层记忆系统的延迟导入(避免循环依赖) -_unified_memory_manager = None - - -def _get_unified_memory_manager(): - """获取统一记忆管理器(延迟导入)""" - global _unified_memory_manager - if _unified_memory_manager is None: - try: - from src.memory_graph.manager_singleton import get_unified_memory_manager - - _unified_memory_manager = get_unified_memory_manager() - except Exception as e: - logger.warning(f"获取统一记忆管理器失败(可能未启用): {e}") - _unified_memory_manager = False # 标记为禁用,避免重复尝试 - return _unified_memory_manager if _unified_memory_manager is not False else None - - -class SingleStreamContextManager: - """单流上下文管理器 - 每个实例只管理一个 stream 的上下文""" - - def __init__(self, stream_id: str, context: "StreamContext", max_context_size: int | None = None): - self.stream_id = stream_id - self.context = context - - # 配置参数 - self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) - - # 元数据 - self.created_time = time.time() - self.last_access_time = time.time() - self.access_count = 0 - self.total_messages = 0 - - # 标记是否已初始化历史消息 - self._history_initialized = False - - logger.debug(f"单流上下文管理器初始化: {stream_id}") - - # 异步初始化历史消息(不阻塞构造函数) - task = asyncio.create_task(self._initialize_history_from_db()) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) - - def get_context(self) -> "StreamContext": - """获取流上下文""" - self._update_access_stats() - return self.context - - async def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool: - """添加消息到上下文 - - Args: - message: 消息对象 - skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略) - - Returns: - bool: 是否成功添加 - """ - try: - # 检查并配置StreamContext的缓存系统 - cache_enabled = global_config.chat.enable_message_cache - if cache_enabled and not self.context.is_cache_enabled: - self.context.enable_cache(True) - logger.debug(f"为StreamContext {self.stream_id} 启用缓存系统") - - # 新消息默认占位兴趣值,延迟到 Chatter 批量处理阶段 - if message.interest_value is None: - message.interest_value = 0.3 - message.should_reply = False - message.should_act = False - message.interest_calculated = False - message.semantic_embedding = None - message.is_read = False - - # 使用StreamContext的智能缓存功能 - success = self.context.add_message_with_cache_check(message, force_direct=not cache_enabled) - - if success: - # 自动检测和更新chat type - self._detect_chat_type(message) - - self.total_messages += 1 - self.last_access_time = time.time() - - # 如果使用了缓存系统,输出调试信息 - if cache_enabled and self.context.is_cache_enabled: - if self.context.is_chatter_processing: - logger.debug(f"消息已缓存到StreamContext,等待处理完成: stream={self.stream_id}") - else: - logger.debug(f"消息直接添加到StreamContext未读列表: stream={self.stream_id}") - else: - logger.debug(f"消息添加到StreamContext(缓存禁用): {self.stream_id}") - - # 三层记忆系统集成:将消息添加到感知记忆层 - try: - if global_config.memory and global_config.memory.enable: - unified_manager = _get_unified_memory_manager() - if unified_manager: - # 构建消息字典 - message_dict = { - "message_id": str(message.message_id), - "sender_id": message.user_info.user_id, - "sender_name": message.user_info.user_nickname, - "content": message.processed_plain_text or message.display_message or "", - "timestamp": message.time, - "platform": message.chat_info.platform, - "stream_id": self.stream_id, - } - await unified_manager.add_message(message_dict) - logger.debug(f"消息已添加到三层记忆系统: {message.message_id}") - except Exception as e: - # 记忆系统错误不应影响主流程 - logger.error(f"添加消息到三层记忆系统失败: {e}", exc_info=True) - - return True - else: - logger.error(f"StreamContext消息添加失败: {self.stream_id}") - return False - - except Exception as e: - logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) - return False - - async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool: - """更新上下文中的消息 - - Args: - message_id: 消息ID - updates: 更新的属性 - - Returns: - bool: 是否成功更新 - """ - try: - # 直接在未读消息中查找并更新(统一转字符串比较) - for message in self.context.unread_messages: - if str(message.message_id) == str(message_id): - if "interest_value" in updates: - message.interest_value = updates["interest_value"] - if "actions" in updates: - message.actions = updates["actions"] - if "should_reply" in updates: - message.should_reply = updates["should_reply"] - break - - # 在历史消息中查找并更新(统一转字符串比较) - for message in self.context.history_messages: - if str(message.message_id) == str(message_id): - if "interest_value" in updates: - message.interest_value = updates["interest_value"] - if "actions" in updates: - message.actions = updates["actions"] - if "should_reply" in updates: - message.should_reply = updates["should_reply"] - break - - logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}") - return True - except Exception as e: - logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True) - return False - - def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]: - """获取上下文消息 - - Args: - limit: 消息数量限制 - include_unread: 是否包含未读消息 - - Returns: - List[DatabaseMessages]: 消息列表 - """ - try: - messages = [] - if include_unread: - messages.extend(self.context.get_unread_messages()) - - if limit: - messages.extend(self.context.get_history_messages(limit=limit)) - else: - messages.extend(self.context.get_history_messages()) - - # 按时间排序 - messages.sort(key=lambda msg: getattr(msg, "time", 0)) - - # 应用限制 - if limit and len(messages) > limit: - messages = messages[-limit:] - - return messages - - except Exception as e: - logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True) - return [] - - def get_unread_messages(self) -> list[DatabaseMessages]: - """获取未读消息""" - try: - return self.context.get_unread_messages() - except Exception as e: - logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True) - return [] - - def mark_messages_as_read(self, message_ids: list[str]) -> bool: - """标记消息为已读""" - try: - if not hasattr(self.context, "mark_message_as_read"): - logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}") - return False - - marked_count = 0 - failed_ids = [] - for message_id in message_ids: - try: - # 传递最大历史消息数量限制 - self.context.mark_message_as_read(message_id, max_history_size=self.max_context_size) - marked_count += 1 - except Exception as e: - failed_ids.append(str(message_id)[:8]) - logger.warning(f"标记消息已读失败 {message_id}: {e}") - - return marked_count > 0 - - except Exception as e: - logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True) - return False - - async def clear_context(self) -> bool: - """清空上下文""" - try: - if hasattr(self.context, "unread_messages"): - self.context.unread_messages.clear() - if hasattr(self.context, "history_messages"): - self.context.history_messages.clear() - reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"] - for attr in reset_attrs: - if hasattr(self.context, attr): - if attr in ["interruption_count", "afc_threshold_adjustment"]: - setattr(self.context, attr, 0) - else: - setattr(self.context, attr, time.time()) - await self._update_stream_energy() - logger.debug(f"清空单流上下文: {self.stream_id}") - return True - except Exception as e: - logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True) - return False - - def get_statistics(self) -> dict[str, Any]: - """获取流统计信息""" - try: - current_time = time.time() - uptime = current_time - self.created_time - - unread_messages = getattr(self.context, "unread_messages", []) - history_messages = getattr(self.context, "history_messages", []) - - stats = { - "stream_id": self.stream_id, - "context_type": type(self.context).__name__, - "total_messages": len(history_messages) + len(unread_messages), - "unread_messages": len(unread_messages), - "history_messages": len(history_messages), - "is_active": getattr(self.context, "is_active", True), - "last_check_time": getattr(self.context, "last_check_time", current_time), - "interruption_count": getattr(self.context, "interruption_count", 0), - "afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0), - "created_time": self.created_time, - "last_access_time": self.last_access_time, - "access_count": self.access_count, - "uptime_seconds": uptime, - "idle_seconds": current_time - self.last_access_time, - } - - # 添加缓存统计信息 - if hasattr(self.context, "get_cache_stats"): - stats["cache_stats"] = self.context.get_cache_stats() - - return stats - except Exception as e: - logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True) - return {} - - def flush_cached_messages(self) -> list[DatabaseMessages]: - """ - 刷新StreamContext中的缓存消息到未读列表 - - Returns: - list[DatabaseMessages]: 刷新的消息列表 - """ - try: - if hasattr(self.context, "flush_cached_messages"): - cached_messages = self.context.flush_cached_messages() - if cached_messages: - logger.debug(f"从StreamContext刷新缓存消息: stream={self.stream_id}, 数量={len(cached_messages)}") - return cached_messages - else: - logger.debug(f"StreamContext不支持缓存刷新: stream={self.stream_id}") - return [] - except Exception as e: - logger.error(f"刷新StreamContext缓存失败: stream={self.stream_id}, error={e}") - return [] - - def get_cache_stats(self) -> dict[str, Any]: - """获取StreamContext的缓存统计信息""" - try: - if hasattr(self.context, "get_cache_stats"): - return self.context.get_cache_stats() - else: - return {"error": "StreamContext不支持缓存统计"} - except Exception as e: - logger.error(f"获取StreamContext缓存统计失败: stream={self.stream_id}, error={e}") - return {"error": str(e)} - - def validate_integrity(self) -> bool: - """验证上下文完整性""" - try: - # 检查基本属性 - required_attrs = ["stream_id", "unread_messages", "history_messages"] - for attr in required_attrs: - if not hasattr(self.context, attr): - logger.warning(f"上下文缺少必要属性: {attr}") - return False - - # 检查消息ID唯一性 - all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", []) - message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")] - if len(message_ids) != len(set(message_ids)): - logger.warning(f"上下文中存在重复消息ID: {self.stream_id}") - return False - - return True - - except Exception as e: - logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}") - return False - - def _update_access_stats(self): - """更新访问统计""" - self.last_access_time = time.time() - self.access_count += 1 - - async def _initialize_history_from_db(self): - """从数据库初始化历史消息到context中""" - if self._history_initialized: - logger.debug(f"历史消息已初始化,跳过: {self.stream_id}, 当前历史消息数: {len(self.context.history_messages)}") - return - - # 立即设置标志,防止并发重复加载 - logger.info(f"🔄 [历史加载] 开始从数据库加载历史消息: {self.stream_id}") - self._history_initialized = True - - try: - logger.debug(f"开始从数据库加载历史消息: {self.stream_id}") - - from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat - - # 加载历史消息(限制数量为max_context_size) - db_messages = await get_raw_msg_before_timestamp_with_chat( - chat_id=self.stream_id, - timestamp=time.time(), - limit=self.max_context_size, - ) - - if db_messages: - logger.info(f"📥 [历史加载] 从数据库获取到 {len(db_messages)} 条消息") - # 将数据库消息转换为 DatabaseMessages 对象并添加到历史 - loaded_count = 0 - for msg_dict in db_messages: - try: - # 使用 ** 解包字典作为关键字参数 - db_msg = DatabaseMessages(**msg_dict) - - # 标记为已读 - db_msg.is_read = True - - # 添加到历史消息 - self.context.history_messages.append(db_msg) - loaded_count += 1 - - except Exception as e: - logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}") - continue - - # 应用历史消息长度限制 - if len(self.context.history_messages) > self.max_context_size: - removed_count = len(self.context.history_messages) - self.max_context_size - self.context.history_messages = self.context.history_messages[-self.max_context_size:] - logger.debug(f"📝 [历史加载] 移除了 {removed_count} 条过旧的历史消息以保持上下文大小限制") - - logger.info(f"✅ [历史加载] 成功加载 {loaded_count} 条历史消息到内存: {self.stream_id}") - else: - logger.debug(f"没有历史消息需要加载: {self.stream_id}") - - except Exception as e: - logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True) - # 加载失败时重置标志,允许重试 - self._history_initialized = False - - async def ensure_history_initialized(self): - """确保历史消息已初始化(供外部调用)""" - if not self._history_initialized: - await self._initialize_history_from_db() - - async def _calculate_message_interest(self, message: DatabaseMessages) -> float: - """ - 在上下文管理器中计算消息的兴趣度 - """ - try: - from src.chat.interest_system.interest_manager import get_interest_manager - - interest_manager = get_interest_manager() - - if interest_manager.has_calculator(): - # 使用兴趣值计算组件计算 - result = await interest_manager.calculate_interest(message) - - if result.success: - # 更新消息对象的兴趣值相关字段 - message.interest_value = result.interest_value - message.should_reply = result.should_reply - message.should_act = result.should_act - message.interest_calculated = True - - logger.debug( - f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " - f"should_reply: {result.should_reply}, should_act: {result.should_act}" - ) - return result.interest_value - else: - logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}") - message.interest_calculated = False - return 0.5 - else: - logger.debug("未找到兴趣值计算器,使用默认兴趣值") - return 0.5 - - except Exception as e: - logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True) - if hasattr(message, "interest_calculated"): - message.interest_calculated = False - return 0.5 - - def _detect_chat_type(self, message: DatabaseMessages): - """根据消息内容自动检测聊天类型""" - # 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型 - if len(self.context.unread_messages) == 1: # 只有这条消息 - # 如果消息包含群组信息,则为群聊 - if message.chat_info.group_info: - self.context.chat_type = ChatType.GROUP - else: - self.context.chat_type = ChatType.PRIVATE - - async def clear_context_async(self) -> bool: - """异步实现的 clear_context:清空消息并 await 能量重算。""" - try: - if hasattr(self.context, "unread_messages"): - self.context.unread_messages.clear() - if hasattr(self.context, "history_messages"): - self.context.history_messages.clear() - - reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"] - for attr in reset_attrs: - if hasattr(self.context, attr): - if attr in ["interruption_count", "afc_threshold_adjustment"]: - setattr(self.context, attr, 0) - else: - setattr(self.context, attr, time.time()) - - await self._update_stream_energy() - logger.info(f"清空单流上下文(异步): {self.stream_id}") - return True - except Exception as e: - logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True) - return False - - async def refresh_focus_energy_from_history(self) -> None: - """基于历史消息刷新聚焦能量""" - await self._update_stream_energy(include_unread=False) - - async def _update_stream_energy(self, include_unread: bool = False) -> None: - """更新流能量""" - try: - history_messages = self.context.get_history_messages(limit=self.max_context_size) - messages: list[DatabaseMessages] = list(history_messages) - - if include_unread: - messages.extend(self.get_unread_messages()) - - # 获取用户ID(优先使用最新历史消息) - user_id = None - if messages: - last_message = messages[-1] - if hasattr(last_message, "user_info") and last_message.user_info: - user_id = last_message.user_info.user_id - - await energy_manager.calculate_focus_energy( - stream_id=self.stream_id, - messages=messages, - user_id=user_id, - ) - - except Exception as e: - logger.error(f"更新单流能量失败 {self.stream_id}: {e}") diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index b8e940748..41654971c 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -81,7 +81,7 @@ class StreamLoopManager: # 创建任务列表以便并发取消 cancel_tasks = [] for chat_stream in all_streams.values(): - context = chat_stream.context_manager.context + context = chat_stream.context if context.stream_loop_task and not context.stream_loop_task.done(): context.stream_loop_task.cancel() cancel_tasks.append((chat_stream.stream_id, context.stream_loop_task)) @@ -309,7 +309,7 @@ class StreamLoopManager: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) if chat_stream: - return chat_stream.context_manager.context + return chat_stream.context return None except Exception as e: logger.error(f"获取流上下文失败 {stream_id}: {e}") @@ -463,7 +463,7 @@ class StreamLoopManager: logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新") return - # 从 context_manager 获取消息(包括未读和历史消息) + # 从 context 获取消息(包括未读和历史消息) # 合并未读消息和历史消息 all_messages = [] @@ -573,7 +573,7 @@ class StreamLoopManager: if not chat_stream: return False - unread = getattr(chat_stream.context_manager.context, "unread_messages", []) + unread = getattr(chat_stream.context, "unread_messages", []) return len(unread) > self.force_dispatch_unread_threshold except Exception as e: logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}") @@ -628,7 +628,7 @@ class StreamLoopManager: logger.debug(f"刷新能量时未找到聊天流: {stream_id}") return - await chat_stream.context_manager.refresh_focus_energy_from_history() + await chat_stream.context.refresh_focus_energy_from_history() logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量") except Exception as e: logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}") diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 29da4f068..516c56456 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -41,7 +41,7 @@ class MessageManager: self.action_manager = ChatterActionManager() self.chatter_manager = ChatterManager(self.action_manager) - # 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager + # 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context # 全局Notice管理器 self.notice_manager = global_notice_manager @@ -115,7 +115,7 @@ class MessageManager: # 启动steam loop任务(如果尚未启动) await stream_loop_manager.start_stream_loop(stream_id) await self._check_and_handle_interruption(chat_stream, message) - await chat_stream.context_manager.add_message(message) + await chat_stream.context.add_message(message) except Exception as e: logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}") @@ -143,7 +143,7 @@ class MessageManager: if should_reply is not None: updates["should_reply"] = should_reply if updates: - success = await chat_stream.context_manager.update_message(message_id, updates) + success = await chat_stream.context.update_message(message_id, updates) if success: logger.debug(f"更新消息 {message_id} 成功") else: @@ -160,7 +160,7 @@ class MessageManager: if not chat_stream: logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在") return - success = await chat_stream.context_manager.update_message(message_id, {"actions": [action]}) + success = await chat_stream.context.update_message(message_id, {"actions": [action]}) if success: logger.debug(f"为消息 {message_id} 添加动作 {action} 成功") else: @@ -178,7 +178,7 @@ class MessageManager: logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在") return - context = chat_stream.context_manager.context + context = chat_stream.context context.is_active = False # 取消处理任务 @@ -200,7 +200,7 @@ class MessageManager: logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在") return - context = chat_stream.context_manager.context + context = chat_stream.context context.is_active = True logger.debug(f"激活聊天流: {stream_id}") @@ -216,8 +216,8 @@ class MessageManager: if not chat_stream: return None - context = chat_stream.context_manager.context - unread_count = len(chat_stream.context_manager.get_unread_messages()) + context = chat_stream.context + unread_count = len(chat_stream.context.get_unread_messages()) return StreamStats( stream_id=stream_id, @@ -265,7 +265,7 @@ class MessageManager: logger.debug(f"聊天流 {stream_id} 在清理时已不存在,跳过") continue - await chat_stream.context_manager.clear_context() + await chat_stream.context.clear_context() # 安全删除流(若已被其他地方删除则捕获) try: @@ -289,7 +289,7 @@ class MessageManager: return # 检查是否正在回复,以及是否允许在回复时打断 - if chat_stream.context_manager.context.is_replying: + if chat_stream.context.is_replying: if not global_config.chat.allow_reply_interruption: logger.debug(f"聊天流 {chat_stream.stream_id} 正在回复中,且配置不允许回复时打断,跳过打断检查") return @@ -302,7 +302,7 @@ class MessageManager: return # 检查上下文 - context = chat_stream.context_manager.context + context = chat_stream.context # 只有当 Chatter 真正在处理时才检查打断 if not context.is_chatter_processing: @@ -379,7 +379,7 @@ class MessageManager: await asyncio.sleep(0.1) # 获取当前的stream context - context = chat_stream.context_manager.context + context = chat_stream.context # 确保有未读消息需要处理 unread_messages = context.get_unread_messages() @@ -411,7 +411,7 @@ class MessageManager: return # 获取未读消息 - unread_messages = chat_stream.context_manager.get_unread_messages() + unread_messages = chat_stream.context.get_unread_messages() if not unread_messages: logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 无未读消息需要清除") return @@ -423,7 +423,7 @@ class MessageManager: # 将所有未读消息标记为已读 message_ids = [msg.message_id for msg in unread_messages] - success = chat_stream.context_manager.mark_messages_as_read(message_ids) + success = chat_stream.context.mark_messages_as_read(message_ids) if success: self.stats.total_processed_messages += len(unread_messages) @@ -443,7 +443,7 @@ class MessageManager: logger.warning(f"clear_stream_unread_messages: 聊天流 {stream_id} 不存在") return - context = chat_stream.context_manager.context + context = chat_stream.context if hasattr(context, "unread_messages") and context.unread_messages: unread_count = len(context.unread_messages) @@ -453,7 +453,7 @@ class MessageManager: message_ids = [msg.message_id for msg in context.unread_messages] # 标记为已读(会移到历史消息) - success = chat_stream.context_manager.mark_messages_as_read(message_ids) + success = chat_stream.context.mark_messages_as_read(message_ids) if success: logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读") @@ -481,8 +481,8 @@ class MessageManager: try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"): - chat_stream.context_manager.context.is_chatter_processing = is_processing + if chat_stream and hasattr(chat_stream.context, "is_chatter_processing"): + chat_stream.context.is_chatter_processing = is_processing logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}") except Exception as e: logger.debug(f"更新StreamContext状态失败: stream={stream_id}, error={e}") @@ -517,8 +517,8 @@ class MessageManager: try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"): - return chat_stream.context_manager.context.is_chatter_processing + if chat_stream and hasattr(chat_stream.context, "is_chatter_processing"): + return chat_stream.context.is_chatter_processing except Exception: pass return False diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 2dc1f5696..8777e852f 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -8,6 +8,8 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.message_manager_data_model import StreamContext +from src.plugin_system.base.component_types import ChatMode, ChatType from src.common.database.api.crud import CRUDBase from src.common.database.api.specialized import get_or_create_chat_stream from src.common.database.compatibility import get_db_session @@ -41,18 +43,10 @@ class ChatStream: self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.saved = False - # 创建单流上下文管理器(包含StreamContext) - from src.chat.message_manager.context_manager import SingleStreamContextManager - from src.common.data_models.message_manager_data_model import StreamContext - from src.plugin_system.base.component_types import ChatMode, ChatType - - self.context_manager: SingleStreamContextManager = SingleStreamContextManager( + self.context: StreamContext = StreamContext( stream_id=stream_id, - context=StreamContext( - stream_id=stream_id, - chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, - chat_mode=ChatMode.FOCUS, - ), + chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, + chat_mode=ChatMode.FOCUS, ) # 基础参数 @@ -73,11 +67,11 @@ class ChatStream: "focus_energy": self.focus_energy, # 基础兴趣度 "base_interest_energy": self.base_interest_energy, - # stream_context基本信息(通过context_manager访问) - "stream_context_chat_type": self.context_manager.context.chat_type.value, - "stream_context_chat_mode": self.context_manager.context.chat_mode.value, + # stream_context基本信息 + "stream_context_chat_type": self.context.chat_type.value, + "stream_context_chat_mode": self.context.chat_mode.value, # 统计信息 - "interruption_count": self.context_manager.context.interruption_count, + "interruption_count": self.context.interruption_count, } @classmethod @@ -94,19 +88,19 @@ class ChatStream: data=data, ) - # 恢复stream_context信息(通过context_manager访问) + # 恢复stream_context信息 if "stream_context_chat_type" in data: from src.plugin_system.base.component_types import ChatMode, ChatType - instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"]) + instance.context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: from src.plugin_system.base.component_types import ChatMode, ChatType - instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"]) + instance.context.chat_mode = ChatMode(data["stream_context_chat_mode"]) # 恢复interruption_count信息 if "interruption_count" in data: - instance.context_manager.context.interruption_count = data["interruption_count"] + instance.context.interruption_count = data["interruption_count"] return instance @@ -131,15 +125,7 @@ class ChatStream: message: DatabaseMessages 对象,直接使用不需要转换 """ # 直接使用传入的 DatabaseMessages,设置到上下文中 - self.context_manager.context.set_current_message(message) - - # 设置优先级信息(如果存在) - priority_mode = getattr(message, "priority_mode", None) - priority_info = getattr(message, "priority_info", None) - if priority_mode: - self.context_manager.context.priority_mode = priority_mode - if priority_info: - self.context_manager.context.priority_info = priority_info + self.context.set_current_message(message) # 调试日志 logger.debug( @@ -253,7 +239,7 @@ class ChatStream: """异步计算focus_energy""" try: # 使用单流上下文管理器获取消息 - all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size) + all_messages = self.context.get_messages(limit=global_config.chat.max_context_size) # 获取用户ID user_id = None @@ -318,7 +304,6 @@ class ChatManager: def __init__(self): if not self._initialized: - from src.common.data_models.database_data_model import DatabaseMessages self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message @@ -409,135 +394,87 @@ class ChatManager: async def get_or_create_stream( self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None ) -> ChatStream: - """获取或创建聊天流 - 优化版本使用缓存管理器 - - Args: - platform: 平台标识 - user_info: 用户信息 - group_info: 群组信息(可选) - - Returns: - ChatStream: 聊天流对象 - """ - # 生成stream_id + """获取或创建聊天流 - 优化版本使用缓存机制""" try: stream_id = self._generate_stream_id(platform, user_info, group_info) - # 检查内存中是否存在 if stream_id in self.streams: stream = self.streams[stream_id] - - # 更新用户信息和群组信息 stream.update_active_time() if user_info.platform and user_info.user_id: stream.user_info = user_info if group_info: stream.group_info = group_info - - # 检查是否有最后一条消息(现在使用 DatabaseMessages) - from src.common.data_models.database_data_model import DatabaseMessages - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): - await stream.set_context(self.last_messages[stream_id]) - else: - logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息") - return stream - - # 使用优化后的API查询(带缓存) - current_time = time.time() - model_instance, _ = await get_or_create_chat_stream( - stream_id=stream_id, - platform=platform, - defaults={ - "create_time": current_time, - "last_active_time": current_time, - "user_platform": user_info.platform if user_info else platform, - "user_id": user_info.user_id if user_info else "", - "user_nickname": user_info.user_nickname if user_info else "", - "user_cardname": user_info.user_cardname if user_info else "", - "group_platform": group_info.platform if group_info else None, - "group_id": group_info.group_id if group_info else None, - "group_name": group_info.group_name if group_info else None, - } - ) - - if model_instance: - # 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式 - user_info_data = { - "platform": model_instance.user_platform, - "user_id": model_instance.user_id, - "user_nickname": model_instance.user_nickname, - "user_cardname": model_instance.user_cardname or "", - } - group_info_data = None - if model_instance and getattr(model_instance, "group_id", None): - group_info_data = { - "platform": model_instance.group_platform, - "group_id": model_instance.group_id, - "group_name": model_instance.group_name, - } - - data_for_from_dict = { - "stream_id": model_instance.stream_id, - "platform": model_instance.platform, - "user_info": user_info_data, - "group_info": group_info_data, - "create_time": model_instance.create_time, - "last_active_time": model_instance.last_active_time, - "energy_value": model_instance.energy_value, - "sleep_pressure": model_instance.sleep_pressure, - } - stream = ChatStream.from_dict(data_for_from_dict) - # 更新用户信息和群组信息 - stream.user_info = user_info - if group_info: - stream.group_info = group_info - stream.update_active_time() else: - # 创建新的聊天流 - stream = ChatStream( + current_time = time.time() + model_instance, _ = await get_or_create_chat_stream( stream_id=stream_id, platform=platform, - user_info=user_info, - group_info=group_info, + defaults={ + "create_time": current_time, + "last_active_time": current_time, + "user_platform": user_info.platform if user_info else platform, + "user_id": user_info.user_id if user_info else "", + "user_nickname": user_info.user_nickname if user_info else "", + "user_cardname": user_info.user_cardname if user_info else "", + "group_platform": group_info.platform if group_info else None, + "group_id": group_info.group_id if group_info else None, + "group_name": group_info.group_name if group_info else None, + }, ) + + if model_instance: + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if getattr(model_instance, "group_id", None): + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + "energy_value": model_instance.energy_value, + "sleep_pressure": model_instance.sleep_pressure, + } + stream = ChatStream.from_dict(data_for_from_dict) + stream.user_info = user_info + if group_info: + stream.group_info = group_info + stream.update_active_time() + else: + stream = ChatStream( + stream_id=stream_id, + platform=platform, + user_info=user_info, + group_info=group_info, + ) except Exception as e: logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e - from src.common.data_models.database_data_model import DatabaseMessages - if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages): await stream.set_context(self.last_messages[stream_id]) else: logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") - # 确保 ChatStream 有自己的 context_manager - if not hasattr(stream, "context_manager") or stream.context_manager is None: - from src.chat.message_manager.context_manager import SingleStreamContextManager - from src.common.data_models.message_manager_data_model import StreamContext - from src.plugin_system.base.component_types import ChatMode, ChatType - - logger.info(f"为 stream {stream_id} 创建新的 context_manager") - stream.context_manager = SingleStreamContextManager( - stream_id=stream_id, - context=StreamContext( - stream_id=stream_id, - chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE, - chat_mode=ChatMode.FOCUS, - ), - ) - else: - logger.info(f"stream {stream_id} 已有 context_manager,跳过创建") - - # 保存到内存和数据库 self.streams[stream_id] = stream await self._save_stream(stream) return stream async def get_stream(self, stream_id: str) -> ChatStream | None: """通过stream_id获取聊天流""" - from src.common.data_models.database_data_model import DatabaseMessages - stream = self.streams.get(stream_id) if not stream: return None @@ -765,23 +702,6 @@ class ChatManager: # if stream.stream_id in self.last_messages: # await stream.set_context(self.last_messages[stream.stream_id]) - # 确保 ChatStream 有自己的 context_manager - if not hasattr(stream, "context_manager") or stream.context_manager is None: - from src.chat.message_manager.context_manager import SingleStreamContextManager - from src.common.data_models.message_manager_data_model import StreamContext - from src.plugin_system.base.component_types import ChatMode, ChatType - - logger.debug(f"为加载的 stream {stream.stream_id} 创建新的 context_manager") - stream.context_manager = SingleStreamContextManager( - stream_id=stream.stream_id, - context=StreamContext( - stream_id=stream.stream_id, - chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE, - chat_mode=ChatMode.FOCUS, - ), - ) - else: - logger.debug(f"加载的 stream {stream.stream_id} 已有 context_manager") except Exception as e: logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index c6ff81f6f..073356d44 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -103,8 +103,8 @@ class HeartFCSender: try: # 将MessageSending转换为DatabaseMessages db_message = await self._convert_to_database_message(message) - if db_message and message.chat_stream.context_manager: - context = message.chat_stream.context_manager.context + if db_message and message.chat_stream.context: + context = message.chat_stream.context # 应用历史消息长度限制 from src.config.config import global_config diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index a0e72ed73..92f94ef64 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -183,7 +183,7 @@ class ChatterActionManager: } # 设置正在回复的状态 - chat_stream.context_manager.context.is_replying = True + chat_stream.context.is_replying = True if action_name == "no_action": return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} @@ -342,7 +342,7 @@ class ChatterActionManager: finally: # 确保重置正在回复的状态 if chat_stream: - chat_stream.context_manager.context.is_replying = False + chat_stream.context.is_replying = False async def _record_action_to_message(self, chat_stream, action_name, target_message, action_data): """ @@ -387,7 +387,7 @@ class ChatterActionManager: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) if chat_stream: - context = chat_stream.context_manager + context = chat_stream.context if context.context.interruption_count > 0: old_count = context.context.interruption_count # old_afc_adjustment = context.context.get_afc_threshold_adjustment() diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 4815d9c38..4cc2992f5 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -139,7 +139,7 @@ class ActionModifier: if not self.chat_stream: logger.error(f"{self.log_prefix} chat_stream 未初始化,无法执行第二阶段") return - chat_context = self.chat_stream.context_manager.context + chat_context = self.chat_stream.context current_actions_s2 = self.action_manager.get_using_actions() type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index a760e6025..a3524dc83 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -396,7 +396,7 @@ class DefaultReplyer: try: # 设置正在回复的状态 - self.chat_stream.context_manager.context.is_replying = True + self.chat_stream.context.is_replying = True content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt) logger.debug(f"replyer生成内容: {content}") llm_response = { @@ -413,7 +413,7 @@ class DefaultReplyer: return False, None, prompt # LLM 调用失败则无法生成回复 finally: # 重置正在回复的状态 - self.chat_stream.context_manager.context.is_replying = False + self.chat_stream.context.is_replying = False # 触发 AFTER_LLM 事件 if not from_plugin: @@ -910,7 +910,7 @@ class DefaultReplyer: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(chat_id) if chat_stream: - stream_context = chat_stream.context_manager + stream_context = chat_stream.context # 确保历史消息已从数据库加载 await stream_context.ensure_history_initialized() @@ -1140,7 +1140,7 @@ class DefaultReplyer: chat_stream_obj = await chat_manager.get_stream(chat_id) if chat_stream_obj: - unread_messages = chat_stream_obj.context_manager.get_unread_messages() + unread_messages = chat_stream_obj.context.get_unread_messages() if unread_messages: # 使用最后一条未读消息作为参考 last_msg = unread_messages[-1] @@ -1262,12 +1262,12 @@ class DefaultReplyer: if chat_stream_obj: # 确保历史消息已初始化 - await chat_stream_obj.context_manager.ensure_history_initialized() + await chat_stream_obj.context.ensure_history_initialized() # 获取所有消息(历史+未读) all_messages = ( - chat_stream_obj.context_manager.context.history_messages + - chat_stream_obj.context_manager.get_unread_messages() + chat_stream_obj.context.history_messages + + chat_stream_obj.context.get_unread_messages() ) # 转换为字典格式 @@ -1639,12 +1639,12 @@ class DefaultReplyer: if chat_stream_obj: # 确保历史消息已初始化 - await chat_stream_obj.context_manager.ensure_history_initialized() + await chat_stream_obj.context.ensure_history_initialized() # 获取所有消息(历史+未读) all_messages = ( - chat_stream_obj.context_manager.context.history_messages + - chat_stream_obj.context_manager.get_unread_messages() + chat_stream_obj.context.history_messages + + chat_stream_obj.context.get_unread_messages() ) # 转换为字典格式,限制数量 @@ -2071,12 +2071,12 @@ class DefaultReplyer: if chat_stream_obj: # 确保历史消息已初始化 - await chat_stream_obj.context_manager.ensure_history_initialized() + await chat_stream_obj.context.ensure_history_initialized() # 获取所有消息(历史+未读) all_messages = ( - chat_stream_obj.context_manager.context.history_messages + - chat_stream_obj.context_manager.get_unread_messages() + chat_stream_obj.context.history_messages + + chat_stream_obj.context.get_unread_messages() ) # 转换为字典格式,限制数量 diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index cf2a6445d..8957ac817 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -8,9 +8,10 @@ import time from collections import deque from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from src.common.logger import get_logger +from src.config.config import global_config from src.plugin_system.base.component_types import ChatMode, ChatType from . import BaseDataModel @@ -20,6 +21,23 @@ if TYPE_CHECKING: logger = get_logger("stream_context") +_background_tasks: set[asyncio.Task] = set() +_unified_memory_manager = None + + +def _get_unified_memory_manager(): + """获取记忆体系单例""" + global _unified_memory_manager + if _unified_memory_manager is None: + try: + from src.memory_graph.manager_singleton import get_unified_memory_manager + + _unified_memory_manager = get_unified_memory_manager() + except Exception as e: + logger.warning(f"获取统一记忆管理器失败,可能未实现: {e}") + _unified_memory_manager = False # ���Ϊ���ã������ظ����� + return _unified_memory_manager if _unified_memory_manager is not False else None + class MessageStatus(Enum): """消息状态枚举""" @@ -44,6 +62,7 @@ class StreamContext(BaseDataModel): stream_id: str chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊 chat_mode: ChatMode = ChatMode.FOCUS # 聊天模式,默认为专注模式 + max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100)) unread_messages: list["DatabaseMessages"] = field(default_factory=list) history_messages: list["DatabaseMessages"] = field(default_factory=list) last_check_time: float = field(default_factory=time.time) @@ -54,22 +73,15 @@ class StreamContext(BaseDataModel): interruption_count: int = 0 # 打断计数器 last_interruption_time: float = 0.0 # 上次打断时间 - # 独立分发周期字段 - next_check_time: float = field(default_factory=time.time) # 下次检查时间 - distribution_interval: float = 5.0 # 当前分发周期(秒) - - # 新增字段以替代ChatMessageContext功能 current_message: Optional["DatabaseMessages"] = None - priority_mode: str | None = None - priority_info: dict | None = None - triggering_user_id: str | None = None # 触发当前聊天流的用户ID - is_replying: bool = False # 是否正在生成回复 + triggering_user_id: str | None = None # 记录当前触发的用户ID + is_replying: bool = False # 是否正在进行回复 processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID,用于防止重复回复 decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史 # 消息缓存系统相关字段 message_cache: deque["DatabaseMessages"] = field(default_factory=deque) # 消息缓存队列 - is_cache_enabled: bool = False # 是否为此流启用缓存 + is_cache_enabled: bool = False # 是否为当前用户启用缓存 cache_stats: dict = field(default_factory=lambda: { "total_cached_messages": 0, "total_flushed_messages": 0, @@ -77,6 +89,117 @@ class StreamContext(BaseDataModel): "cache_misses": 0 }) # 缓存统计信息 + created_time: float = field(default_factory=time.time) + last_access_time: float = field(default_factory=time.time) + access_count: int = 0 + total_messages: int = 0 + _history_initialized: bool = field(default=False, init=False) + + def __post_init__(self): + """初始化历史消息异步加载""" + if not self.max_context_size or self.max_context_size <= 0: + self.max_context_size = getattr(global_config.chat, "max_context_size", 100) + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + task = asyncio.create_task(self._initialize_history_from_db()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + except RuntimeError: + # 事件循环未运行时,await ensure_history_initialized 进行初始化 + pass + + def _update_access_stats(self): + """更新访问统计信息,记录最后访问时间""" + self.last_access_time = time.time() + self.access_count += 1 + + async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool: + """添加消息到上下文,支持跳过能量更新的选项""" + try: + cache_enabled = global_config.chat.enable_message_cache + if cache_enabled and not self.is_cache_enabled: + self.enable_cache(True) + logger.debug(f"为StreamContext {self.stream_id} 启用消息缓存系统") + + if message.interest_value is None: + message.interest_value = 0.3 + message.should_reply = False + message.should_act = False + message.interest_calculated = False + message.semantic_embedding = None + message.is_read = False + + success = self.add_message_with_cache_check(message, force_direct=not cache_enabled) + if not success: + logger.error(f"StreamContext消息添加失败: {self.stream_id}") + return False + + self._detect_chat_type(message) + self.total_messages += 1 + self._update_access_stats() + + if cache_enabled and self.is_cache_enabled: + if self.is_chatter_processing: + logger.debug(f"消息已缓存到StreamContext等待处理: stream={self.stream_id}") + else: + logger.debug(f"消息直接添加到StreamContext未处理列表: stream={self.stream_id}") + else: + logger.debug(f"消息添加到StreamContext成功: {self.stream_id}") + # ͬ�����ݵ�ͳһ������� + try: + if global_config.memory and global_config.memory.enable: + unified_manager = _get_unified_memory_manager() + if unified_manager: + message_dict = { + "message_id": str(message.message_id), + "sender_id": message.user_info.user_id, + "sender_name": message.user_info.user_nickname, + "content": message.processed_plain_text or message.display_message or "", + "timestamp": message.time, + "platform": message.chat_info.platform, + "stream_id": self.stream_id, + } + await unified_manager.add_message(message_dict) + logger.debug(f"��Ϣ�����ӵ��������ϵͳ: {message.message_id}") + except Exception as e: + logger.error(f"������Ϣ���������ϵͳʧ��: {e}", exc_info=True) + + return True + except Exception as e: + logger.error(f"������Ϣ������������ʧ�� {self.stream_id}: {e}", exc_info=True) + return False + + async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool: + """�����������е���Ϣ""" + try: + for message in self.unread_messages: + if str(message.message_id) == str(message_id): + if "interest_value" in updates: + message.interest_value = updates["interest_value"] + if "actions" in updates: + message.actions = updates["actions"] + if "should_reply" in updates: + message.should_reply = updates["should_reply"] + break + + for message in self.history_messages: + if str(message.message_id) == str(message_id): + if "interest_value" in updates: + message.interest_value = updates["interest_value"] + if "actions" in updates: + message.actions = updates["actions"] + if "should_reply" in updates: + message.should_reply = updates["should_reply"] + break + + logger.debug(f"���µ�����������Ϣ: {self.stream_id}/{message_id}") + return True + except Exception as e: + logger.error(f"���µ�����������Ϣʧ�� {self.stream_id}/{message_id}: {e}", exc_info=True) + return False + def add_action_to_message(self, message_id: str, action: str): """ 向指定消息添加执行的动作 @@ -113,9 +236,7 @@ class StreamContext(BaseDataModel): # 应用历史消息长度限制 if max_history_size is None: - # 从全局配置获取最大历史消息数量 - from src.config.config import global_config - max_history_size = getattr(global_config.chat, "max_context_size", 40) + max_history_size = self.max_context_size # 如果历史消息已达到最大长度,移除最旧的消息 if len(self.history_messages) >= max_history_size: @@ -136,6 +257,44 @@ class StreamContext(BaseDataModel): recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages return recent_history + def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list["DatabaseMessages"]: + """获取上下文中的消息集合""" + try: + messages: list["DatabaseMessages"] = [] + if include_unread: + messages.extend(self.get_unread_messages()) + + if limit: + messages.extend(self.get_history_messages(limit=limit)) + else: + messages.extend(self.get_history_messages()) + + messages.sort(key=lambda msg: getattr(msg, "time", 0)) + + if limit and len(messages) > limit: + messages = messages[-limit:] + + self._update_access_stats() + return messages + except Exception as e: + logger.error(f"获取上下文消息失败 {self.stream_id}: {e}", exc_info=True) + return [] + + def mark_messages_as_read(self, message_ids: list[str]) -> bool: + """批量标记消息为已读""" + try: + marked_count = 0 + for message_id in message_ids: + try: + self.mark_message_as_read(message_id, max_history_size=self.max_context_size) + marked_count += 1 + except Exception as e: + logger.warning(f"标记消息已读失败 {message_id}: {e}") + return marked_count > 0 + except Exception as e: + logger.error(f"批量标记消息已读失败 {self.stream_id}: {e}", exc_info=True) + return False + def calculate_interruption_probability(self, max_limit: int, min_probability: float = 0.1, probability_factor: float | None = None) -> float: """计算打断概率 - 使用反比例函数模型 @@ -175,6 +334,75 @@ class StreamContext(BaseDataModel): probability = max(min_probability, probability) return max(0.0, min(1.0, probability)) + async def clear_context(self) -> bool: + """清空上下文的未读与历史消息并重置状态""" + try: + self.unread_messages.clear() + self.history_messages.clear() + for attr in ["interruption_count", "afc_threshold_adjustment", "last_check_time"]: + if hasattr(self, attr): + if attr in ["interruption_count", "afc_threshold_adjustment"]: + setattr(self, attr, 0) + else: + setattr(self, attr, time.time()) + await self._update_stream_energy() + logger.debug(f"清空上下文成功: {self.stream_id}") + return True + except Exception as e: + logger.error(f"清空上下文失败 {self.stream_id}: {e}", exc_info=True) + return False + + def get_statistics(self) -> dict[str, Any]: + """获取上下文统计信息""" + try: + current_time = time.time() + uptime = current_time - self.created_time + + stats = { + "stream_id": self.stream_id, + "context_type": type(self).__name__, + "total_messages": len(self.history_messages) + len(self.unread_messages), + "unread_messages": len(self.unread_messages), + "history_messages": len(self.history_messages), + "is_active": self.is_active, + "last_check_time": self.last_check_time, + "interruption_count": self.interruption_count, + "afc_threshold_adjustment": getattr(self, "afc_threshold_adjustment", 0.0), + "created_time": self.created_time, + "last_access_time": self.last_access_time, + "access_count": self.access_count, + "uptime_seconds": uptime, + "idle_seconds": current_time - self.last_access_time, + } + + stats["cache_stats"] = self.get_cache_stats() + return stats + except Exception as e: + logger.error(f"获取上下文统计失败 {self.stream_id}: {e}", exc_info=True) + return {} + + def validate_integrity(self) -> bool: + """校验上下文结构完整性""" + try: + required_attrs = ["stream_id", "unread_messages", "history_messages"] + for attr in required_attrs: + if not hasattr(self, attr): + logger.warning(f"上下文缺少必要属性: {attr}") + return False + + all_messages = self.unread_messages + self.history_messages + message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")] + if len(message_ids) != len(set(message_ids)): + logger.warning(f"上下文中存在重复的消息ID: {self.stream_id}") + return False + + return True + + except Exception as e: + logger.error(f"校验上下文完整性失败 {self.stream_id}: {e}") + return False + + async def increment_interruption_count(self): """增加打断计数""" self.interruption_count += 1 @@ -239,6 +467,131 @@ class StreamContext(BaseDataModel): return self.history_messages[-1] return None + async def ensure_history_initialized(self): + """初始化历史消息异步加载""" + if not self._history_initialized: + await self._initialize_history_from_db() + + async def refresh_focus_energy_from_history(self) -> None: + """根据历史消息刷新关注能量""" + await self._update_stream_energy(include_unread=False) + + async def _update_stream_energy(self, include_unread: bool = False) -> None: + """使用当前上下文消息更新关注能量""" + try: + history_messages = self.get_history_messages(limit=self.max_context_size) + messages: list["DatabaseMessages"] = list(history_messages) + + if include_unread: + messages.extend(self.get_unread_messages()) + + user_id = None + if messages: + last_message = messages[-1] + if hasattr(last_message, "user_info") and last_message.user_info: + user_id = last_message.user_info.user_id + + from src.chat.energy_system import energy_manager + + await energy_manager.calculate_focus_energy( + stream_id=self.stream_id, + messages=messages, + user_id=user_id, + ) + + except Exception as e: + logger.error(f"更新能量体系失败 {self.stream_id}: {e}") + + async def _initialize_history_from_db(self): + """Load history messages from database into context.""" + if self._history_initialized: + logger.debug(f"历史信息已初始化,stream={self.stream_id}, 当前条数={len(self.history_messages)}") + return + + logger.info(f"?? [历史加载] 开始从数据库读取历史消息: {self.stream_id}") + self._history_initialized = True + + try: + logger.debug(f"开始加载数据库历史消息: {self.stream_id}") + + from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat + + db_messages = await get_raw_msg_before_timestamp_with_chat( + chat_id=self.stream_id, + timestamp=time.time(), + limit=self.max_context_size, + ) + + if db_messages: + logger.info(f"[历史加载] 从数据库获取到 {len(db_messages)} 条历史消息") + loaded_count = 0 + for msg_dict in db_messages: + try: + db_msg = DatabaseMessages(**msg_dict) + db_msg.is_read = True + self.history_messages.append(db_msg) + loaded_count += 1 + + except Exception as e: + logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}") + continue + + if len(self.history_messages) > self.max_context_size: + removed_count = len(self.history_messages) - self.max_context_size + self.history_messages = self.history_messages[-self.max_context_size :] + logger.debug(f"[历史加载] 移除了 {removed_count} 条最早的消息以适配当前容量限制") + + logger.info(f"[历史加载] 成功加载 {loaded_count} 条历史消息到内存: {self.stream_id}") + else: + logger.debug(f"无历史消息需要加载: {self.stream_id}") + + except Exception as e: + logger.error(f"从数据库加载历史消息失败: {self.stream_id}, {e}", exc_info=True) + self._history_initialized = False + + def _detect_chat_type(self, message: "DatabaseMessages"): + """基于消息内容检测聊天类型""" + if len(self.unread_messages) == 1: + if message.chat_info.group_info: + self.chat_type = ChatType.GROUP + else: + self.chat_type = ChatType.PRIVATE + + async def _calculate_message_interest(self, message: "DatabaseMessages") -> float: + """调用兴趣系统计算消息兴趣值""" + try: + from src.chat.interest_system.interest_manager import get_interest_manager + + interest_manager = get_interest_manager() + + if interest_manager.has_calculator(): + result = await interest_manager.calculate_interest(message) + + if result.success: + message.interest_value = result.interest_value + message.should_reply = result.should_reply + message.should_act = result.should_act + message.interest_calculated = True + + logger.debug( + f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, " + f"should_reply: {result.should_reply}, should_act: {result.should_act}" + ) + return result.interest_value + else: + logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}") + message.interest_calculated = False + return 0.5 + else: + logger.debug("未找到兴趣计算器,使用默认兴趣值") + return 0.5 + + except Exception as e: + logger.error(f"计算消息兴趣时出现异常: {e}", exc_info=True) + if hasattr(message, "interest_calculated"): + message.interest_calculated = False + return 0.5 + def check_types(self, types: list) -> bool: """ 检查当前消息是否支持指定的类型 @@ -332,14 +685,6 @@ class StreamContext(BaseDataModel): logger.debug("[check_types] ✅ 备用方案通过所有类型检查") return True - def get_priority_mode(self) -> str | None: - """获取优先级模式""" - return self.priority_mode - - def get_priority_info(self) -> dict | None: - """获取优先级信息""" - return self.priority_info - # ==================== 消息缓存系统方法 ==================== def enable_cache(self, enabled: bool = True): diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_executor.py index cc03c0656..1032d5271 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_executor.py @@ -477,7 +477,7 @@ class ChatterPlanExecutor: ) # 添加到chat_stream的已读消息中 - chat_stream.context_manager.context.history_messages.append(bot_message) + chat_stream.context.history_messages.append(bot_message) logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...") except Exception as e: diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py index 61892c1ed..06f5990b3 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/plan_filter.py @@ -169,7 +169,7 @@ class ChatterPlanFilter: logger.debug("尝试添加空的决策历史,已跳过") return - context = chat_stream.context_manager.context + context = chat_stream.context new_record = DecisionRecord(thought=thought, action=action) # 添加新记录 @@ -204,7 +204,7 @@ class ChatterPlanFilter: if not chat_stream: return "" - context = chat_stream.context_manager.context + context = chat_stream.context if not context.decision_history: return "" @@ -344,7 +344,7 @@ class ChatterPlanFilter: logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在") return "最近没有聊天内容。", "没有未读消息。", [] - stream_context = chat_stream.context_manager + stream_context = chat_stream.context # 获取真正的已读和未读消息 read_messages = ( diff --git a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py index cebb32a66..c12bf71ab 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner/planner.py @@ -599,7 +599,7 @@ class ChatterActionPlanner: if chat_manager: chat_stream = await chat_manager.get_stream(context.stream_id) if chat_stream: - chat_stream.context_manager.context.chat_mode = context.chat_mode + chat_stream.context.chat_mode = context.chat_mode chat_stream.saved = False # 标记需要保存 logger.debug(f"已同步chat_mode {context.chat_mode.value} 到ChatStream {context.stream_id}") except Exception as e: diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py index 23d19cc23..14a52473a 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive/proactive_thinking_executor.py @@ -564,7 +564,7 @@ async def execute_proactive_thinking(stream_id: str): chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - if chat_stream and chat_stream.context_manager.context.is_chatter_processing: + if chat_stream and chat_stream.context.is_chatter_processing: logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息") return except Exception as e: diff --git a/src/plugins/built_in/social_toolkit_plugin/plugin.py b/src/plugins/built_in/social_toolkit_plugin/plugin.py index a4cd165c4..69ee3d5c7 100644 --- a/src/plugins/built_in/social_toolkit_plugin/plugin.py +++ b/src/plugins/built_in/social_toolkit_plugin/plugin.py @@ -61,7 +61,7 @@ class ReminderTask(AsyncTask): logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒") extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}。\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}" - last_message = self.chat_stream.context_manager.context.get_last_message() + last_message = self.chat_stream.context.get_last_message() reply_message_dict = last_message.flatten() if last_message else None success, reply_set, _ = await generator_api.generate_reply( chat_stream=self.chat_stream, @@ -523,7 +523,7 @@ class RemindAction(BaseAction): # 4. 生成并发送确认消息 extra_info = f"你已经成功设置了一个提醒,请以一种符合你人设的、俏皮的方式回复用户。\n提醒时间: {target_time.strftime('%Y-%m-%d %H:%M:%S')}\n提醒对象: {user_name_to_remind}\n提醒内容: {event_details}" - last_message = self.chat_stream.context_manager.context.get_last_message() + last_message = self.chat_stream.context.get_last_message() reply_message_dict = last_message.flatten() if last_message else None success, reply_set, _ = await generator_api.generate_reply( chat_stream=self.chat_stream, diff --git a/src/plugins/built_in/tts_plugin/plugin.py b/src/plugins/built_in/tts_plugin/plugin.py index 2fd272dfa..a056964c9 100644 --- a/src/plugins/built_in/tts_plugin/plugin.py +++ b/src/plugins/built_in/tts_plugin/plugin.py @@ -54,7 +54,7 @@ class TTSAction(BaseAction): success, response_set, _ = await generate_reply( chat_stream=self.chat_stream, - reply_message=self.chat_stream.context_manager.context.get_last_message(), + reply_message=self.chat_stream.context.get_last_message(), enable_tool=global_config.tool.enable_tool, request_type="chat.tts", from_plugin=False,