重构ChatStream和StreamContext:移除context_manager引用

- 在ChatStream及相关类中,将所有context_manager的实例替换为直接上下文访问。
- 更新方法,利用新的上下文结构来管理聊天状态和消息。
- 增强的StreamContext,增加了用于消息处理、统计和历史管理的方法。
- 在重构过程中改进了错误处理和日志记录。
This commit is contained in:
Windpicker-owo
2025-11-25 12:01:26 +08:00
parent d30b0544b5
commit 1ebdc37b22
16 changed files with 487 additions and 753 deletions

View File

@@ -3,13 +3,11 @@
提供统一的消息管理、上下文管理和流循环调度功能
"""
from .context_manager import SingleStreamContextManager
from .distribution_manager import StreamLoopManager, stream_loop_manager
from .message_manager import MessageManager, message_manager
__all__ = [
"MessageManager",
"SingleStreamContextManager",
"StreamLoopManager",
"message_manager",
"stream_loop_manager",

View File

@@ -1,529 +0,0 @@
"""
重构后的聊天上下文管理器
提供统一、稳定的聊天上下文管理功能
每个 context_manager 实例只管理一个 stream 的上下文
"""
import asyncio
import time
from typing import TYPE_CHECKING, Any
from src.chat.energy_system import energy_manager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.component_types import ChatType
if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
logger = get_logger("context_manager")
# 全局背景任务集合(用于异步初始化等后台任务)
_background_tasks = set()
# 三层记忆系统的延迟导入(避免循环依赖)
_unified_memory_manager = None
def _get_unified_memory_manager():
"""获取统一记忆管理器(延迟导入)"""
global _unified_memory_manager
if _unified_memory_manager is None:
try:
from src.memory_graph.manager_singleton import get_unified_memory_manager
_unified_memory_manager = get_unified_memory_manager()
except Exception as e:
logger.warning(f"获取统一记忆管理器失败(可能未启用): {e}")
_unified_memory_manager = False # 标记为禁用,避免重复尝试
return _unified_memory_manager if _unified_memory_manager is not False else None
class SingleStreamContextManager:
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
def __init__(self, stream_id: str, context: "StreamContext", max_context_size: int | None = None):
self.stream_id = stream_id
self.context = context
# 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
# 元数据
self.created_time = time.time()
self.last_access_time = time.time()
self.access_count = 0
self.total_messages = 0
# 标记是否已初始化历史消息
self._history_initialized = False
logger.debug(f"单流上下文管理器初始化: {stream_id}")
# 异步初始化历史消息(不阻塞构造函数)
task = asyncio.create_task(self._initialize_history_from_db())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
def get_context(self) -> "StreamContext":
"""获取流上下文"""
self._update_access_stats()
return self.context
async def add_message(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文
Args:
message: 消息对象
skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略)
Returns:
bool: 是否成功添加
"""
try:
# 检查并配置StreamContext的缓存系统
cache_enabled = global_config.chat.enable_message_cache
if cache_enabled and not self.context.is_cache_enabled:
self.context.enable_cache(True)
logger.debug(f"为StreamContext {self.stream_id} 启用缓存系统")
# 新消息默认占位兴趣值,延迟到 Chatter 批量处理阶段
if message.interest_value is None:
message.interest_value = 0.3
message.should_reply = False
message.should_act = False
message.interest_calculated = False
message.semantic_embedding = None
message.is_read = False
# 使用StreamContext的智能缓存功能
success = self.context.add_message_with_cache_check(message, force_direct=not cache_enabled)
if success:
# 自动检测和更新chat type
self._detect_chat_type(message)
self.total_messages += 1
self.last_access_time = time.time()
# 如果使用了缓存系统,输出调试信息
if cache_enabled and self.context.is_cache_enabled:
if self.context.is_chatter_processing:
logger.debug(f"消息已缓存到StreamContext等待处理完成: stream={self.stream_id}")
else:
logger.debug(f"消息直接添加到StreamContext未读列表: stream={self.stream_id}")
else:
logger.debug(f"消息添加到StreamContext缓存禁用: {self.stream_id}")
# 三层记忆系统集成:将消息添加到感知记忆层
try:
if global_config.memory and global_config.memory.enable:
unified_manager = _get_unified_memory_manager()
if unified_manager:
# 构建消息字典
message_dict = {
"message_id": str(message.message_id),
"sender_id": message.user_info.user_id,
"sender_name": message.user_info.user_nickname,
"content": message.processed_plain_text or message.display_message or "",
"timestamp": message.time,
"platform": message.chat_info.platform,
"stream_id": self.stream_id,
}
await unified_manager.add_message(message_dict)
logger.debug(f"消息已添加到三层记忆系统: {message.message_id}")
except Exception as e:
# 记忆系统错误不应影响主流程
logger.error(f"添加消息到三层记忆系统失败: {e}", exc_info=True)
return True
else:
logger.error(f"StreamContext消息添加失败: {self.stream_id}")
return False
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
"""更新上下文中的消息
Args:
message_id: 消息ID
updates: 更新的属性
Returns:
bool: 是否成功更新
"""
try:
# 直接在未读消息中查找并更新(统一转字符串比较)
for message in self.context.unread_messages:
if str(message.message_id) == str(message_id):
if "interest_value" in updates:
message.interest_value = updates["interest_value"]
if "actions" in updates:
message.actions = updates["actions"]
if "should_reply" in updates:
message.should_reply = updates["should_reply"]
break
# 在历史消息中查找并更新(统一转字符串比较)
for message in self.context.history_messages:
if str(message.message_id) == str(message_id):
if "interest_value" in updates:
message.interest_value = updates["interest_value"]
if "actions" in updates:
message.actions = updates["actions"]
if "should_reply" in updates:
message.should_reply = updates["should_reply"]
break
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]:
"""获取上下文消息
Args:
limit: 消息数量限制
include_unread: 是否包含未读消息
Returns:
List[DatabaseMessages]: 消息列表
"""
try:
messages = []
if include_unread:
messages.extend(self.context.get_unread_messages())
if limit:
messages.extend(self.context.get_history_messages(limit=limit))
else:
messages.extend(self.context.get_history_messages())
# 按时间排序
messages.sort(key=lambda msg: getattr(msg, "time", 0))
# 应用限制
if limit and len(messages) > limit:
messages = messages[-limit:]
return messages
except Exception as e:
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self) -> list[DatabaseMessages]:
"""获取未读消息"""
try:
return self.context.get_unread_messages()
except Exception as e:
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def mark_messages_as_read(self, message_ids: list[str]) -> bool:
"""标记消息为已读"""
try:
if not hasattr(self.context, "mark_message_as_read"):
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {self.stream_id}")
return False
marked_count = 0
failed_ids = []
for message_id in message_ids:
try:
# 传递最大历史消息数量限制
self.context.mark_message_as_read(message_id, max_history_size=self.max_context_size)
marked_count += 1
except Exception as e:
failed_ids.append(str(message_id)[:8])
logger.warning(f"标记消息已读失败 {message_id}: {e}")
return marked_count > 0
except Exception as e:
logger.error(f"标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False
async def clear_context(self) -> bool:
"""清空上下文"""
try:
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(self.context, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self.context, attr, 0)
else:
setattr(self.context, attr, time.time())
await self._update_stream_energy()
logger.debug(f"清空单流上下文: {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def get_statistics(self) -> dict[str, Any]:
"""获取流统计信息"""
try:
current_time = time.time()
uptime = current_time - self.created_time
unread_messages = getattr(self.context, "unread_messages", [])
history_messages = getattr(self.context, "history_messages", [])
stats = {
"stream_id": self.stream_id,
"context_type": type(self.context).__name__,
"total_messages": len(history_messages) + len(unread_messages),
"unread_messages": len(unread_messages),
"history_messages": len(history_messages),
"is_active": getattr(self.context, "is_active", True),
"last_check_time": getattr(self.context, "last_check_time", current_time),
"interruption_count": getattr(self.context, "interruption_count", 0),
"afc_threshold_adjustment": getattr(self.context, "afc_threshold_adjustment", 0.0),
"created_time": self.created_time,
"last_access_time": self.last_access_time,
"access_count": self.access_count,
"uptime_seconds": uptime,
"idle_seconds": current_time - self.last_access_time,
}
# 添加缓存统计信息
if hasattr(self.context, "get_cache_stats"):
stats["cache_stats"] = self.context.get_cache_stats()
return stats
except Exception as e:
logger.error(f"获取单流统计失败 {self.stream_id}: {e}", exc_info=True)
return {}
def flush_cached_messages(self) -> list[DatabaseMessages]:
"""
刷新StreamContext中的缓存消息到未读列表
Returns:
list[DatabaseMessages]: 刷新的消息列表
"""
try:
if hasattr(self.context, "flush_cached_messages"):
cached_messages = self.context.flush_cached_messages()
if cached_messages:
logger.debug(f"从StreamContext刷新缓存消息: stream={self.stream_id}, 数量={len(cached_messages)}")
return cached_messages
else:
logger.debug(f"StreamContext不支持缓存刷新: stream={self.stream_id}")
return []
except Exception as e:
logger.error(f"刷新StreamContext缓存失败: stream={self.stream_id}, error={e}")
return []
def get_cache_stats(self) -> dict[str, Any]:
"""获取StreamContext的缓存统计信息"""
try:
if hasattr(self.context, "get_cache_stats"):
return self.context.get_cache_stats()
else:
return {"error": "StreamContext不支持缓存统计"}
except Exception as e:
logger.error(f"获取StreamContext缓存统计失败: stream={self.stream_id}, error={e}")
return {"error": str(e)}
def validate_integrity(self) -> bool:
"""验证上下文完整性"""
try:
# 检查基本属性
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(self.context, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
# 检查消息ID唯一性
all_messages = getattr(self.context, "unread_messages", []) + getattr(self.context, "history_messages", [])
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复消息ID: {self.stream_id}")
return False
return True
except Exception as e:
logger.error(f"验证单流上下文完整性失败 {self.stream_id}: {e}")
return False
def _update_access_stats(self):
"""更新访问统计"""
self.last_access_time = time.time()
self.access_count += 1
async def _initialize_history_from_db(self):
"""从数据库初始化历史消息到context中"""
if self._history_initialized:
logger.debug(f"历史消息已初始化,跳过: {self.stream_id}, 当前历史消息数: {len(self.context.history_messages)}")
return
# 立即设置标志,防止并发重复加载
logger.info(f"🔄 [历史加载] 开始从数据库加载历史消息: {self.stream_id}")
self._history_initialized = True
try:
logger.debug(f"开始从数据库加载历史消息: {self.stream_id}")
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
# 加载历史消息限制数量为max_context_size
db_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=self.max_context_size,
)
if db_messages:
logger.info(f"📥 [历史加载] 从数据库获取到 {len(db_messages)} 条消息")
# 将数据库消息转换为 DatabaseMessages 对象并添加到历史
loaded_count = 0
for msg_dict in db_messages:
try:
# 使用 ** 解包字典作为关键字参数
db_msg = DatabaseMessages(**msg_dict)
# 标记为已读
db_msg.is_read = True
# 添加到历史消息
self.context.history_messages.append(db_msg)
loaded_count += 1
except Exception as e:
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
continue
# 应用历史消息长度限制
if len(self.context.history_messages) > self.max_context_size:
removed_count = len(self.context.history_messages) - self.max_context_size
self.context.history_messages = self.context.history_messages[-self.max_context_size:]
logger.debug(f"📝 [历史加载] 移除了 {removed_count} 条过旧的历史消息以保持上下文大小限制")
logger.info(f"✅ [历史加载] 成功加载 {loaded_count} 条历史消息到内存: {self.stream_id}")
else:
logger.debug(f"没有历史消息需要加载: {self.stream_id}")
except Exception as e:
logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True)
# 加载失败时重置标志,允许重试
self._history_initialized = False
async def ensure_history_initialized(self):
"""确保历史消息已初始化(供外部调用)"""
if not self._history_initialized:
await self._initialize_history_from_db()
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""
在上下文管理器中计算消息的兴趣度
"""
try:
from src.chat.interest_system.interest_manager import get_interest_manager
interest_manager = get_interest_manager()
if interest_manager.has_calculator():
# 使用兴趣值计算组件计算
result = await interest_manager.calculate_interest(message)
if result.success:
# 更新消息对象的兴趣值相关字段
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
logger.debug(
f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
)
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
message.interest_calculated = False
return 0.5
else:
logger.debug("未找到兴趣值计算器,使用默认兴趣值")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
if hasattr(message, "interest_calculated"):
message.interest_calculated = False
return 0.5
def _detect_chat_type(self, message: DatabaseMessages):
"""根据消息内容自动检测聊天类型"""
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
if len(self.context.unread_messages) == 1: # 只有这条消息
# 如果消息包含群组信息,则为群聊
if message.chat_info.group_info:
self.context.chat_type = ChatType.GROUP
else:
self.context.chat_type = ChatType.PRIVATE
async def clear_context_async(self) -> bool:
"""异步实现的 clear_context清空消息并 await 能量重算。"""
try:
if hasattr(self.context, "unread_messages"):
self.context.unread_messages.clear()
if hasattr(self.context, "history_messages"):
self.context.history_messages.clear()
reset_attrs = ["interruption_count", "afc_threshold_adjustment", "last_check_time"]
for attr in reset_attrs:
if hasattr(self.context, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self.context, attr, 0)
else:
setattr(self.context, attr, time.time())
await self._update_stream_energy()
logger.info(f"清空单流上下文(异步): {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
async def refresh_focus_energy_from_history(self) -> None:
"""基于历史消息刷新聚焦能量"""
await self._update_stream_energy(include_unread=False)
async def _update_stream_energy(self, include_unread: bool = False) -> None:
"""更新流能量"""
try:
history_messages = self.context.get_history_messages(limit=self.max_context_size)
messages: list[DatabaseMessages] = list(history_messages)
if include_unread:
messages.extend(self.get_unread_messages())
# 获取用户ID优先使用最新历史消息
user_id = None
if messages:
last_message = messages[-1]
if hasattr(last_message, "user_info") and last_message.user_info:
user_id = last_message.user_info.user_id
await energy_manager.calculate_focus_energy(
stream_id=self.stream_id,
messages=messages,
user_id=user_id,
)
except Exception as e:
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")

View File

@@ -81,7 +81,7 @@ class StreamLoopManager:
# 创建任务列表以便并发取消
cancel_tasks = []
for chat_stream in all_streams.values():
context = chat_stream.context_manager.context
context = chat_stream.context
if context.stream_loop_task and not context.stream_loop_task.done():
context.stream_loop_task.cancel()
cancel_tasks.append((chat_stream.stream_id, context.stream_loop_task))
@@ -309,7 +309,7 @@ class StreamLoopManager:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
return chat_stream.context_manager.context
return chat_stream.context
return None
except Exception as e:
logger.error(f"获取流上下文失败 {stream_id}: {e}")
@@ -463,7 +463,7 @@ class StreamLoopManager:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return
# 从 context_manager 获取消息(包括未读和历史消息)
# 从 context 获取消息(包括未读和历史消息)
# 合并未读消息和历史消息
all_messages = []
@@ -573,7 +573,7 @@ class StreamLoopManager:
if not chat_stream:
return False
unread = getattr(chat_stream.context_manager.context, "unread_messages", [])
unread = getattr(chat_stream.context, "unread_messages", [])
return len(unread) > self.force_dispatch_unread_threshold
except Exception as e:
logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}")
@@ -628,7 +628,7 @@ class StreamLoopManager:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
await chat_stream.context_manager.refresh_focus_energy_from_history()
await chat_stream.context.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e:
logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}")

View File

@@ -41,7 +41,7 @@ class MessageManager:
self.action_manager = ChatterActionManager()
self.chatter_manager = ChatterManager(self.action_manager)
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context_manager
# 不再需要全局上下文管理器,直接通过 ChatManager 访问各个 ChatStream 的 context
# 全局Notice管理器
self.notice_manager = global_notice_manager
@@ -115,7 +115,7 @@ class MessageManager:
# 启动steam loop任务如果尚未启动
await stream_loop_manager.start_stream_loop(stream_id)
await self._check_and_handle_interruption(chat_stream, message)
await chat_stream.context_manager.add_message(message)
await chat_stream.context.add_message(message)
except Exception as e:
logger.error(f"添加消息到聊天流 {stream_id} 时发生错误: {e}")
@@ -143,7 +143,7 @@ class MessageManager:
if should_reply is not None:
updates["should_reply"] = should_reply
if updates:
success = await chat_stream.context_manager.update_message(message_id, updates)
success = await chat_stream.context.update_message(message_id, updates)
if success:
logger.debug(f"更新消息 {message_id} 成功")
else:
@@ -160,7 +160,7 @@ class MessageManager:
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
success = await chat_stream.context_manager.update_message(message_id, {"actions": [action]})
success = await chat_stream.context.update_message(message_id, {"actions": [action]})
if success:
logger.debug(f"为消息 {message_id} 添加动作 {action} 成功")
else:
@@ -178,7 +178,7 @@ class MessageManager:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.context_manager.context
context = chat_stream.context
context.is_active = False
# 取消处理任务
@@ -200,7 +200,7 @@ class MessageManager:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.context_manager.context
context = chat_stream.context
context.is_active = True
logger.debug(f"激活聊天流: {stream_id}")
@@ -216,8 +216,8 @@ class MessageManager:
if not chat_stream:
return None
context = chat_stream.context_manager.context
unread_count = len(chat_stream.context_manager.get_unread_messages())
context = chat_stream.context
unread_count = len(chat_stream.context.get_unread_messages())
return StreamStats(
stream_id=stream_id,
@@ -265,7 +265,7 @@ class MessageManager:
logger.debug(f"聊天流 {stream_id} 在清理时已不存在,跳过")
continue
await chat_stream.context_manager.clear_context()
await chat_stream.context.clear_context()
# 安全删除流(若已被其他地方删除则捕获)
try:
@@ -289,7 +289,7 @@ class MessageManager:
return
# 检查是否正在回复,以及是否允许在回复时打断
if chat_stream.context_manager.context.is_replying:
if chat_stream.context.is_replying:
if not global_config.chat.allow_reply_interruption:
logger.debug(f"聊天流 {chat_stream.stream_id} 正在回复中,且配置不允许回复时打断,跳过打断检查")
return
@@ -302,7 +302,7 @@ class MessageManager:
return
# 检查上下文
context = chat_stream.context_manager.context
context = chat_stream.context
# 只有当 Chatter 真正在处理时才检查打断
if not context.is_chatter_processing:
@@ -379,7 +379,7 @@ class MessageManager:
await asyncio.sleep(0.1)
# 获取当前的stream context
context = chat_stream.context_manager.context
context = chat_stream.context
# 确保有未读消息需要处理
unread_messages = context.get_unread_messages()
@@ -411,7 +411,7 @@ class MessageManager:
return
# 获取未读消息
unread_messages = chat_stream.context_manager.get_unread_messages()
unread_messages = chat_stream.context.get_unread_messages()
if not unread_messages:
logger.info(f"🧹 [清除未读] stream={stream_id[:8]}, 无未读消息需要清除")
return
@@ -423,7 +423,7 @@ class MessageManager:
# 将所有未读消息标记为已读
message_ids = [msg.message_id for msg in unread_messages]
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
success = chat_stream.context.mark_messages_as_read(message_ids)
if success:
self.stats.total_processed_messages += len(unread_messages)
@@ -443,7 +443,7 @@ class MessageManager:
logger.warning(f"clear_stream_unread_messages: 聊天流 {stream_id} 不存在")
return
context = chat_stream.context_manager.context
context = chat_stream.context
if hasattr(context, "unread_messages") and context.unread_messages:
unread_count = len(context.unread_messages)
@@ -453,7 +453,7 @@ class MessageManager:
message_ids = [msg.message_id for msg in context.unread_messages]
# 标记为已读(会移到历史消息)
success = chat_stream.context_manager.mark_messages_as_read(message_ids)
success = chat_stream.context.mark_messages_as_read(message_ids)
if success:
logger.debug(f"✅ stream={stream_id[:8]}, 成功标记 {unread_count} 条消息为已读")
@@ -481,8 +481,8 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
chat_stream.context_manager.context.is_chatter_processing = is_processing
if chat_stream and hasattr(chat_stream.context, "is_chatter_processing"):
chat_stream.context.is_chatter_processing = is_processing
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
except Exception as e:
logger.debug(f"更新StreamContext状态失败: stream={stream_id}, error={e}")
@@ -517,8 +517,8 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
return chat_stream.context_manager.context.is_chatter_processing
if chat_stream and hasattr(chat_stream.context, "is_chatter_processing"):
return chat_stream.context.is_chatter_processing
except Exception:
pass
return False

View File

@@ -8,6 +8,8 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
from src.common.database.api.crud import CRUDBase
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.compatibility import get_db_session
@@ -41,18 +43,10 @@ class ChatStream:
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
self.saved = False
# 创建单流上下文管理器(包含StreamContext
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
self.context: StreamContext = StreamContext(
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
chat_mode=ChatMode.FOCUS,
),
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
chat_mode=ChatMode.FOCUS,
)
# 基础参数
@@ -73,11 +67,11 @@ class ChatStream:
"focus_energy": self.focus_energy,
# 基础兴趣度
"base_interest_energy": self.base_interest_energy,
# stream_context基本信息通过context_manager访问
"stream_context_chat_type": self.context_manager.context.chat_type.value,
"stream_context_chat_mode": self.context_manager.context.chat_mode.value,
# stream_context基本信息
"stream_context_chat_type": self.context.chat_type.value,
"stream_context_chat_mode": self.context.chat_mode.value,
# 统计信息
"interruption_count": self.context_manager.context.interruption_count,
"interruption_count": self.context.interruption_count,
}
@classmethod
@@ -94,19 +88,19 @@ class ChatStream:
data=data,
)
# 恢复stream_context信息通过context_manager访问
# 恢复stream_context信息
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
instance.context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
instance.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息
if "interruption_count" in data:
instance.context_manager.context.interruption_count = data["interruption_count"]
instance.context.interruption_count = data["interruption_count"]
return instance
@@ -131,15 +125,7 @@ class ChatStream:
message: DatabaseMessages 对象,直接使用不需要转换
"""
# 直接使用传入的 DatabaseMessages设置到上下文中
self.context_manager.context.set_current_message(message)
# 设置优先级信息(如果存在)
priority_mode = getattr(message, "priority_mode", None)
priority_info = getattr(message, "priority_info", None)
if priority_mode:
self.context_manager.context.priority_mode = priority_mode
if priority_info:
self.context_manager.context.priority_info = priority_info
self.context.set_current_message(message)
# 调试日志
logger.debug(
@@ -253,7 +239,7 @@ class ChatStream:
"""异步计算focus_energy"""
try:
# 使用单流上下文管理器获取消息
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
all_messages = self.context.get_messages(limit=global_config.chat.max_context_size)
# 获取用户ID
user_id = None
@@ -318,7 +304,6 @@ class ChatManager:
def __init__(self):
if not self._initialized:
from src.common.data_models.database_data_model import DatabaseMessages
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
@@ -409,135 +394,87 @@ class ChatManager:
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream:
"""获取或创建聊天流 - 优化版本使用缓存管理器
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
"""获取或创建聊天流 - 优化版本使用缓存机制"""
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
if user_info.platform and user_info.user_id:
stream.user_info = user_info
if group_info:
stream.group_info = group_info
# 检查是否有最后一条消息(现在使用 DatabaseMessages
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream
# 使用优化后的API查询带缓存
current_time = time.time()
model_instance, _ = await get_or_create_chat_stream(
stream_id=stream_id,
platform=platform,
defaults={
"create_time": current_time,
"last_active_time": current_time,
"user_platform": user_info.platform if user_info else platform,
"user_id": user_info.user_id if user_info else "",
"user_nickname": user_info.user_nickname if user_info else "",
"user_cardname": user_info.user_cardname if user_info else "",
"group_platform": group_info.platform if group_info else None,
"group_id": group_info.group_id if group_info else None,
"group_name": group_info.group_name if group_info else None,
}
)
if model_instance:
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance and getattr(model_instance, "group_id", None):
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
"energy_value": model_instance.energy_value,
"sleep_pressure": model_instance.sleep_pressure,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
current_time = time.time()
model_instance, _ = await get_or_create_chat_stream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
defaults={
"create_time": current_time,
"last_active_time": current_time,
"user_platform": user_info.platform if user_info else platform,
"user_id": user_info.user_id if user_info else "",
"user_nickname": user_info.user_nickname if user_info else "",
"user_cardname": user_info.user_cardname if user_info else "",
"group_platform": group_info.platform if group_info else None,
"group_id": group_info.group_id if group_info else None,
"group_name": group_info.group_name if group_info else None,
},
)
if model_instance:
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if getattr(model_instance, "group_id", None):
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
"energy_value": model_instance.energy_value,
"sleep_pressure": model_instance.sleep_pressure,
}
stream = ChatStream.from_dict(data_for_from_dict)
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
)
except Exception as e:
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager") or stream.context_manager is None:
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
logger.info(f"为 stream {stream_id} 创建新的 context_manager")
stream.context_manager = SingleStreamContextManager(
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.FOCUS,
),
)
else:
logger.info(f"stream {stream_id} 已有 context_manager跳过创建")
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
from src.common.data_models.database_data_model import DatabaseMessages
stream = self.streams.get(stream_id)
if not stream:
return None
@@ -765,23 +702,6 @@ class ChatManager:
# if stream.stream_id in self.last_messages:
# await stream.set_context(self.last_messages[stream.stream_id])
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager") or stream.context_manager is None:
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
logger.debug(f"为加载的 stream {stream.stream_id} 创建新的 context_manager")
stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id,
context=StreamContext(
stream_id=stream.stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.FOCUS,
),
)
else:
logger.debug(f"加载的 stream {stream.stream_id} 已有 context_manager")
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)

View File

@@ -103,8 +103,8 @@ class HeartFCSender:
try:
# 将MessageSending转换为DatabaseMessages
db_message = await self._convert_to_database_message(message)
if db_message and message.chat_stream.context_manager:
context = message.chat_stream.context_manager.context
if db_message and message.chat_stream.context:
context = message.chat_stream.context
# 应用历史消息长度限制
from src.config.config import global_config

View File

@@ -183,7 +183,7 @@ class ChatterActionManager:
}
# 设置正在回复的状态
chat_stream.context_manager.context.is_replying = True
chat_stream.context.is_replying = True
if action_name == "no_action":
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
@@ -342,7 +342,7 @@ class ChatterActionManager:
finally:
# 确保重置正在回复的状态
if chat_stream:
chat_stream.context_manager.context.is_replying = False
chat_stream.context.is_replying = False
async def _record_action_to_message(self, chat_stream, action_name, target_message, action_data):
"""
@@ -387,7 +387,7 @@ class ChatterActionManager:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
context = chat_stream.context_manager
context = chat_stream.context
if context.context.interruption_count > 0:
old_count = context.context.interruption_count
# old_afc_adjustment = context.context.get_afc_threshold_adjustment()

View File

@@ -139,7 +139,7 @@ class ActionModifier:
if not self.chat_stream:
logger.error(f"{self.log_prefix} chat_stream 未初始化,无法执行第二阶段")
return
chat_context = self.chat_stream.context_manager.context
chat_context = self.chat_stream.context
current_actions_s2 = self.action_manager.get_using_actions()
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)

View File

@@ -396,7 +396,7 @@ class DefaultReplyer:
try:
# 设置正在回复的状态
self.chat_stream.context_manager.context.is_replying = True
self.chat_stream.context.is_replying = True
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
logger.debug(f"replyer生成内容: {content}")
llm_response = {
@@ -413,7 +413,7 @@ class DefaultReplyer:
return False, None, prompt # LLM 调用失败则无法生成回复
finally:
# 重置正在回复的状态
self.chat_stream.context_manager.context.is_replying = False
self.chat_stream.context.is_replying = False
# 触发 AFTER_LLM 事件
if not from_plugin:
@@ -910,7 +910,7 @@ class DefaultReplyer:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
stream_context = chat_stream.context
# 确保历史消息已从数据库加载
await stream_context.ensure_history_initialized()
@@ -1140,7 +1140,7 @@ class DefaultReplyer:
chat_stream_obj = await chat_manager.get_stream(chat_id)
if chat_stream_obj:
unread_messages = chat_stream_obj.context_manager.get_unread_messages()
unread_messages = chat_stream_obj.context.get_unread_messages()
if unread_messages:
# 使用最后一条未读消息作为参考
last_msg = unread_messages[-1]
@@ -1262,12 +1262,12 @@ class DefaultReplyer:
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
await chat_stream_obj.context.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
chat_stream_obj.context.history_messages +
chat_stream_obj.context.get_unread_messages()
)
# 转换为字典格式
@@ -1639,12 +1639,12 @@ class DefaultReplyer:
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
await chat_stream_obj.context.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
chat_stream_obj.context.history_messages +
chat_stream_obj.context.get_unread_messages()
)
# 转换为字典格式,限制数量
@@ -2071,12 +2071,12 @@ class DefaultReplyer:
if chat_stream_obj:
# 确保历史消息已初始化
await chat_stream_obj.context_manager.ensure_history_initialized()
await chat_stream_obj.context.ensure_history_initialized()
# 获取所有消息(历史+未读)
all_messages = (
chat_stream_obj.context_manager.context.history_messages +
chat_stream_obj.context_manager.get_unread_messages()
chat_stream_obj.context.history_messages +
chat_stream_obj.context.get_unread_messages()
)
# 转换为字典格式,限制数量

View File

@@ -8,9 +8,10 @@ import time
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.component_types import ChatMode, ChatType
from . import BaseDataModel
@@ -20,6 +21,23 @@ if TYPE_CHECKING:
logger = get_logger("stream_context")
_background_tasks: set[asyncio.Task] = set()
_unified_memory_manager = None
def _get_unified_memory_manager():
"""获取记忆体系单例"""
global _unified_memory_manager
if _unified_memory_manager is None:
try:
from src.memory_graph.manager_singleton import get_unified_memory_manager
_unified_memory_manager = get_unified_memory_manager()
except Exception as e:
logger.warning(f"获取统一记忆管理器失败,可能未实现: {e}")
_unified_memory_manager = False # <20><><EFBFBD>Ϊ<EFBFBD><CEAA><EFBFBD>ã<EFBFBD><C3A3><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ظ<EFBFBD><D8B8><EFBFBD><EFBFBD><EFBFBD>
return _unified_memory_manager if _unified_memory_manager is not False else None
class MessageStatus(Enum):
"""消息状态枚举"""
@@ -44,6 +62,7 @@ class StreamContext(BaseDataModel):
stream_id: str
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
chat_mode: ChatMode = ChatMode.FOCUS # 聊天模式,默认为专注模式
max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100))
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
history_messages: list["DatabaseMessages"] = field(default_factory=list)
last_check_time: float = field(default_factory=time.time)
@@ -54,22 +73,15 @@ class StreamContext(BaseDataModel):
interruption_count: int = 0 # 打断计数器
last_interruption_time: float = 0.0 # 上次打断时间
# 独立分发周期字段
next_check_time: float = field(default_factory=time.time) # 下次检查时间
distribution_interval: float = 5.0 # 当前分发周期(秒)
# 新增字段以替代ChatMessageContext功能
current_message: Optional["DatabaseMessages"] = None
priority_mode: str | None = None
priority_info: dict | None = None
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
is_replying: bool = False # 是否正在生成回复
triggering_user_id: str | None = None # 记录当前触发的用户ID
is_replying: bool = False # 是否正在进行回复
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID用于防止重复回复
decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史
# 消息缓存系统相关字段
message_cache: deque["DatabaseMessages"] = field(default_factory=deque) # 消息缓存队列
is_cache_enabled: bool = False # 是否为此流启用缓存
is_cache_enabled: bool = False # 是否为当前用户启用缓存
cache_stats: dict = field(default_factory=lambda: {
"total_cached_messages": 0,
"total_flushed_messages": 0,
@@ -77,6 +89,117 @@ class StreamContext(BaseDataModel):
"cache_misses": 0
}) # 缓存统计信息
created_time: float = field(default_factory=time.time)
last_access_time: float = field(default_factory=time.time)
access_count: int = 0
total_messages: int = 0
_history_initialized: bool = field(default=False, init=False)
def __post_init__(self):
"""初始化历史消息异步加载"""
if not self.max_context_size or self.max_context_size <= 0:
self.max_context_size = getattr(global_config.chat, "max_context_size", 100)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
task = asyncio.create_task(self._initialize_history_from_db())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
# 事件循环未运行时await ensure_history_initialized 进行初始化
pass
def _update_access_stats(self):
"""更新访问统计信息,记录最后访问时间"""
self.last_access_time = time.time()
self.access_count += 1
async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool:
"""添加消息到上下文,支持跳过能量更新的选项"""
try:
cache_enabled = global_config.chat.enable_message_cache
if cache_enabled and not self.is_cache_enabled:
self.enable_cache(True)
logger.debug(f"为StreamContext {self.stream_id} 启用消息缓存系统")
if message.interest_value is None:
message.interest_value = 0.3
message.should_reply = False
message.should_act = False
message.interest_calculated = False
message.semantic_embedding = None
message.is_read = False
success = self.add_message_with_cache_check(message, force_direct=not cache_enabled)
if not success:
logger.error(f"StreamContext消息添加失败: {self.stream_id}")
return False
self._detect_chat_type(message)
self.total_messages += 1
self._update_access_stats()
if cache_enabled and self.is_cache_enabled:
if self.is_chatter_processing:
logger.debug(f"消息已缓存到StreamContext等待处理: stream={self.stream_id}")
else:
logger.debug(f"消息直接添加到StreamContext未处理列表: stream={self.stream_id}")
else:
logger.debug(f"消息添加到StreamContext成功: {self.stream_id}")
# ͬ<><CDAC><EFBFBD><EFBFBD><EFBFBD>ݵ<EFBFBD>ͳһ<CDB3><D2BB><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
try:
if global_config.memory and global_config.memory.enable:
unified_manager = _get_unified_memory_manager()
if unified_manager:
message_dict = {
"message_id": str(message.message_id),
"sender_id": message.user_info.user_id,
"sender_name": message.user_info.user_nickname,
"content": message.processed_plain_text or message.display_message or "",
"timestamp": message.time,
"platform": message.chat_info.platform,
"stream_id": self.stream_id,
}
await unified_manager.add_message(message_dict)
logger.debug(f"<EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ӵ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ϵͳ: {message.message_id}")
except Exception as e:
logger.error(f"<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ϵͳʧ<EFBFBD><EFBFBD>: {e}", exc_info=True)
return True
except Exception as e:
logger.error(f"<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʧ<EFBFBD><EFBFBD> {self.stream_id}: {e}", exc_info=True)
return False
async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
"""<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>е<EFBFBD><EFBFBD><EFBFBD>Ϣ"""
try:
for message in self.unread_messages:
if str(message.message_id) == str(message_id):
if "interest_value" in updates:
message.interest_value = updates["interest_value"]
if "actions" in updates:
message.actions = updates["actions"]
if "should_reply" in updates:
message.should_reply = updates["should_reply"]
break
for message in self.history_messages:
if str(message.message_id) == str(message_id):
if "interest_value" in updates:
message.interest_value = updates["interest_value"]
if "actions" in updates:
message.actions = updates["actions"]
if "should_reply" in updates:
message.should_reply = updates["should_reply"]
break
logger.debug(f"<EFBFBD><EFBFBD><EFBFBD>µ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣ: {self.stream_id}/{message_id}")
return True
except Exception as e:
logger.error(f"<EFBFBD><EFBFBD><EFBFBD>µ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϣʧ<EFBFBD><EFBFBD> {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
def add_action_to_message(self, message_id: str, action: str):
"""
向指定消息添加执行的动作
@@ -113,9 +236,7 @@ class StreamContext(BaseDataModel):
# 应用历史消息长度限制
if max_history_size is None:
# 从全局配置获取最大历史消息数量
from src.config.config import global_config
max_history_size = getattr(global_config.chat, "max_context_size", 40)
max_history_size = self.max_context_size
# 如果历史消息已达到最大长度,移除最旧的消息
if len(self.history_messages) >= max_history_size:
@@ -136,6 +257,44 @@ class StreamContext(BaseDataModel):
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
return recent_history
def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list["DatabaseMessages"]:
"""获取上下文中的消息集合"""
try:
messages: list["DatabaseMessages"] = []
if include_unread:
messages.extend(self.get_unread_messages())
if limit:
messages.extend(self.get_history_messages(limit=limit))
else:
messages.extend(self.get_history_messages())
messages.sort(key=lambda msg: getattr(msg, "time", 0))
if limit and len(messages) > limit:
messages = messages[-limit:]
self._update_access_stats()
return messages
except Exception as e:
logger.error(f"获取上下文消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def mark_messages_as_read(self, message_ids: list[str]) -> bool:
"""批量标记消息为已读"""
try:
marked_count = 0
for message_id in message_ids:
try:
self.mark_message_as_read(message_id, max_history_size=self.max_context_size)
marked_count += 1
except Exception as e:
logger.warning(f"标记消息已读失败 {message_id}: {e}")
return marked_count > 0
except Exception as e:
logger.error(f"批量标记消息已读失败 {self.stream_id}: {e}", exc_info=True)
return False
def calculate_interruption_probability(self, max_limit: int, min_probability: float = 0.1, probability_factor: float | None = None) -> float:
"""计算打断概率 - 使用反比例函数模型
@@ -175,6 +334,75 @@ class StreamContext(BaseDataModel):
probability = max(min_probability, probability)
return max(0.0, min(1.0, probability))
async def clear_context(self) -> bool:
"""清空上下文的未读与历史消息并重置状态"""
try:
self.unread_messages.clear()
self.history_messages.clear()
for attr in ["interruption_count", "afc_threshold_adjustment", "last_check_time"]:
if hasattr(self, attr):
if attr in ["interruption_count", "afc_threshold_adjustment"]:
setattr(self, attr, 0)
else:
setattr(self, attr, time.time())
await self._update_stream_energy()
logger.debug(f"清空上下文成功: {self.stream_id}")
return True
except Exception as e:
logger.error(f"清空上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def get_statistics(self) -> dict[str, Any]:
"""获取上下文统计信息"""
try:
current_time = time.time()
uptime = current_time - self.created_time
stats = {
"stream_id": self.stream_id,
"context_type": type(self).__name__,
"total_messages": len(self.history_messages) + len(self.unread_messages),
"unread_messages": len(self.unread_messages),
"history_messages": len(self.history_messages),
"is_active": self.is_active,
"last_check_time": self.last_check_time,
"interruption_count": self.interruption_count,
"afc_threshold_adjustment": getattr(self, "afc_threshold_adjustment", 0.0),
"created_time": self.created_time,
"last_access_time": self.last_access_time,
"access_count": self.access_count,
"uptime_seconds": uptime,
"idle_seconds": current_time - self.last_access_time,
}
stats["cache_stats"] = self.get_cache_stats()
return stats
except Exception as e:
logger.error(f"获取上下文统计失败 {self.stream_id}: {e}", exc_info=True)
return {}
def validate_integrity(self) -> bool:
"""校验上下文结构完整性"""
try:
required_attrs = ["stream_id", "unread_messages", "history_messages"]
for attr in required_attrs:
if not hasattr(self, attr):
logger.warning(f"上下文缺少必要属性: {attr}")
return False
all_messages = self.unread_messages + self.history_messages
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
if len(message_ids) != len(set(message_ids)):
logger.warning(f"上下文中存在重复的消息ID: {self.stream_id}")
return False
return True
except Exception as e:
logger.error(f"校验上下文完整性失败 {self.stream_id}: {e}")
return False
async def increment_interruption_count(self):
"""增加打断计数"""
self.interruption_count += 1
@@ -239,6 +467,131 @@ class StreamContext(BaseDataModel):
return self.history_messages[-1]
return None
async def ensure_history_initialized(self):
"""初始化历史消息异步加载"""
if not self._history_initialized:
await self._initialize_history_from_db()
async def refresh_focus_energy_from_history(self) -> None:
"""根据历史消息刷新关注能量"""
await self._update_stream_energy(include_unread=False)
async def _update_stream_energy(self, include_unread: bool = False) -> None:
"""使用当前上下文消息更新关注能量"""
try:
history_messages = self.get_history_messages(limit=self.max_context_size)
messages: list["DatabaseMessages"] = list(history_messages)
if include_unread:
messages.extend(self.get_unread_messages())
user_id = None
if messages:
last_message = messages[-1]
if hasattr(last_message, "user_info") and last_message.user_info:
user_id = last_message.user_info.user_id
from src.chat.energy_system import energy_manager
await energy_manager.calculate_focus_energy(
stream_id=self.stream_id,
messages=messages,
user_id=user_id,
)
except Exception as e:
logger.error(f"更新能量体系失败 {self.stream_id}: {e}")
async def _initialize_history_from_db(self):
"""Load history messages from database into context."""
if self._history_initialized:
logger.debug(f"历史信息已初始化,stream={self.stream_id}, 当前条数={len(self.history_messages)}")
return
logger.info(f"?? [历史加载] 开始从数据库读取历史消息: {self.stream_id}")
self._history_initialized = True
try:
logger.debug(f"开始加载数据库历史消息: {self.stream_id}")
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
db_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=self.max_context_size,
)
if db_messages:
logger.info(f"[历史加载] 从数据库获取到 {len(db_messages)} 条历史消息")
loaded_count = 0
for msg_dict in db_messages:
try:
db_msg = DatabaseMessages(**msg_dict)
db_msg.is_read = True
self.history_messages.append(db_msg)
loaded_count += 1
except Exception as e:
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
continue
if len(self.history_messages) > self.max_context_size:
removed_count = len(self.history_messages) - self.max_context_size
self.history_messages = self.history_messages[-self.max_context_size :]
logger.debug(f"[历史加载] 移除了 {removed_count} 条最早的消息以适配当前容量限制")
logger.info(f"[历史加载] 成功加载 {loaded_count} 条历史消息到内存: {self.stream_id}")
else:
logger.debug(f"无历史消息需要加载: {self.stream_id}")
except Exception as e:
logger.error(f"从数据库加载历史消息失败: {self.stream_id}, {e}", exc_info=True)
self._history_initialized = False
def _detect_chat_type(self, message: "DatabaseMessages"):
"""基于消息内容检测聊天类型"""
if len(self.unread_messages) == 1:
if message.chat_info.group_info:
self.chat_type = ChatType.GROUP
else:
self.chat_type = ChatType.PRIVATE
async def _calculate_message_interest(self, message: "DatabaseMessages") -> float:
"""调用兴趣系统计算消息兴趣值"""
try:
from src.chat.interest_system.interest_manager import get_interest_manager
interest_manager = get_interest_manager()
if interest_manager.has_calculator():
result = await interest_manager.calculate_interest(message)
if result.success:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
logger.debug(
f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
)
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
message.interest_calculated = False
return 0.5
else:
logger.debug("未找到兴趣计算器,使用默认兴趣值")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣时出现异常: {e}", exc_info=True)
if hasattr(message, "interest_calculated"):
message.interest_calculated = False
return 0.5
def check_types(self, types: list) -> bool:
"""
检查当前消息是否支持指定的类型
@@ -332,14 +685,6 @@ class StreamContext(BaseDataModel):
logger.debug("[check_types] ✅ 备用方案通过所有类型检查")
return True
def get_priority_mode(self) -> str | None:
"""获取优先级模式"""
return self.priority_mode
def get_priority_info(self) -> dict | None:
"""获取优先级信息"""
return self.priority_info
# ==================== 消息缓存系统方法 ====================
def enable_cache(self, enabled: bool = True):

View File

@@ -477,7 +477,7 @@ class ChatterPlanExecutor:
)
# 添加到chat_stream的已读消息中
chat_stream.context_manager.context.history_messages.append(bot_message)
chat_stream.context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
except Exception as e:

View File

@@ -169,7 +169,7 @@ class ChatterPlanFilter:
logger.debug("尝试添加空的决策历史,已跳过")
return
context = chat_stream.context_manager.context
context = chat_stream.context
new_record = DecisionRecord(thought=thought, action=action)
# 添加新记录
@@ -204,7 +204,7 @@ class ChatterPlanFilter:
if not chat_stream:
return ""
context = chat_stream.context_manager.context
context = chat_stream.context
if not context.decision_history:
return ""
@@ -344,7 +344,7 @@ class ChatterPlanFilter:
logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在")
return "最近没有聊天内容。", "没有未读消息。", []
stream_context = chat_stream.context_manager
stream_context = chat_stream.context
# 获取真正的已读和未读消息
read_messages = (

View File

@@ -599,7 +599,7 @@ class ChatterActionPlanner:
if chat_manager:
chat_stream = await chat_manager.get_stream(context.stream_id)
if chat_stream:
chat_stream.context_manager.context.chat_mode = context.chat_mode
chat_stream.context.chat_mode = context.chat_mode
chat_stream.saved = False # 标记需要保存
logger.debug(f"已同步chat_mode {context.chat_mode.value} 到ChatStream {context.stream_id}")
except Exception as e:

View File

@@ -564,7 +564,7 @@ async def execute_proactive_thinking(stream_id: str):
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and chat_stream.context_manager.context.is_chatter_processing:
if chat_stream and chat_stream.context.is_chatter_processing:
logger.warning(f"⚠️ 主动思考跳过:聊天流 {stream_id} 的 chatter 正在处理消息")
return
except Exception as e:

View File

@@ -61,7 +61,7 @@ class ReminderTask(AsyncTask):
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}"
last_message = self.chat_stream.context_manager.context.get_last_message()
last_message = self.chat_stream.context.get_last_message()
reply_message_dict = last_message.flatten() if last_message else None
success, reply_set, _ = await generator_api.generate_reply(
chat_stream=self.chat_stream,
@@ -523,7 +523,7 @@ class RemindAction(BaseAction):
# 4. 生成并发送确认消息
extra_info = f"你已经成功设置了一个提醒,请以一种符合你人设的、俏皮的方式回复用户。\n提醒时间: {target_time.strftime('%Y-%m-%d %H:%M:%S')}\n提醒对象: {user_name_to_remind}\n提醒内容: {event_details}"
last_message = self.chat_stream.context_manager.context.get_last_message()
last_message = self.chat_stream.context.get_last_message()
reply_message_dict = last_message.flatten() if last_message else None
success, reply_set, _ = await generator_api.generate_reply(
chat_stream=self.chat_stream,

View File

@@ -54,7 +54,7 @@ class TTSAction(BaseAction):
success, response_set, _ = await generate_reply(
chat_stream=self.chat_stream,
reply_message=self.chat_stream.context_manager.context.get_last_message(),
reply_message=self.chat_stream.context.get_last_message(),
enable_tool=global_config.tool.enable_tool,
request_type="chat.tts",
from_plugin=False,