Refactor ChatStream and StreamContext: remove context_manager references

- Replace all uses of context_manager in ChatStream and related classes with direct context access.
- Update methods to manage chat state and messages through the new context structure.
- Extend StreamContext with methods for message handling, statistics, and history management.
- Improve error handling and logging throughout the refactor.
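A minimal sketch of the new call pattern. The `stream.context` attribute and the `summarize_stream` helper are assumptions for illustration; only the StreamContext methods used below appear in this commit.

    async def summarize_stream(stream) -> dict:
        # Assumed: ChatStream exposes its StreamContext directly (no context_manager lookup)
        context = stream.context
        await context.ensure_history_initialized()          # load history from DB if not yet loaded
        recent = context.get_messages(limit=20)              # unread + history, time-ordered
        context.mark_messages_as_read([m.message_id for m in recent])
        return context.get_statistics()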
@@ -8,9 +8,10 @@ import time
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.component_types import ChatMode, ChatType

from . import BaseDataModel
@@ -20,6 +21,23 @@ if TYPE_CHECKING:

logger = get_logger("stream_context")

# Strong references keep background tasks from being garbage-collected before they finish
_background_tasks: set[asyncio.Task] = set()
_unified_memory_manager = None


def _get_unified_memory_manager():
    """Get the unified memory system singleton."""
    global _unified_memory_manager
    if _unified_memory_manager is None:
        try:
            from src.memory_graph.manager_singleton import get_unified_memory_manager

            _unified_memory_manager = get_unified_memory_manager()
        except Exception as e:
            logger.warning(f"Failed to get the unified memory manager (it may not be implemented): {e}")
            _unified_memory_manager = False  # Mark as unavailable so the import is not retried
    return _unified_memory_manager if _unified_memory_manager is not False else None


class MessageStatus(Enum):
    """Message status enum"""
@@ -44,6 +62,7 @@ class StreamContext(BaseDataModel):
    stream_id: str
    chat_type: ChatType = ChatType.PRIVATE  # Chat type, defaults to private chat
    chat_mode: ChatMode = ChatMode.FOCUS  # Chat mode, defaults to focus mode
    max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100))
    unread_messages: list["DatabaseMessages"] = field(default_factory=list)
    history_messages: list["DatabaseMessages"] = field(default_factory=list)
    last_check_time: float = field(default_factory=time.time)
@@ -54,22 +73,15 @@ class StreamContext(BaseDataModel):
    interruption_count: int = 0  # Interruption counter
    last_interruption_time: float = 0.0  # Time of the last interruption

    # Independent distribution cycle fields
    next_check_time: float = field(default_factory=time.time)  # Next check time
    distribution_interval: float = 5.0  # Current distribution interval (seconds)

    # New fields that replace the ChatMessageContext functionality
    current_message: Optional["DatabaseMessages"] = None
    priority_mode: str | None = None
    priority_info: dict | None = None
    triggering_user_id: str | None = None  # User ID that triggered the current chat stream
    is_replying: bool = False  # Whether a reply is currently being generated
    triggering_user_id: str | None = None  # Records the user ID of the current trigger
    is_replying: bool = False  # Whether a reply is in progress
    processing_message_id: str | None = None  # ID of the message currently being planned/processed, used to prevent duplicate replies
    decision_history: list["DecisionRecord"] = field(default_factory=list)  # Decision history

    # Message cache system fields
    message_cache: deque["DatabaseMessages"] = field(default_factory=deque)  # Message cache queue
    is_cache_enabled: bool = False  # Whether caching is enabled for this stream
    is_cache_enabled: bool = False  # Whether caching is enabled for the current user
    cache_stats: dict = field(default_factory=lambda: {
        "total_cached_messages": 0,
        "total_flushed_messages": 0,
@@ -77,6 +89,117 @@ class StreamContext(BaseDataModel):
        "cache_misses": 0
    })  # Cache statistics

    created_time: float = field(default_factory=time.time)
    last_access_time: float = field(default_factory=time.time)
    access_count: int = 0
    total_messages: int = 0
    _history_initialized: bool = field(default=False, init=False)

    def __post_init__(self):
        """Kick off asynchronous loading of history messages."""
        if not self.max_context_size or self.max_context_size <= 0:
            self.max_context_size = getattr(global_config.chat, "max_context_size", 100)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                task = asyncio.create_task(self._initialize_history_from_db())
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)
        except RuntimeError:
            # No running event loop; callers initialize later via `await ensure_history_initialized()`
            pass

    def _update_access_stats(self):
        """Update access statistics and record the last access time."""
        self.last_access_time = time.time()
        self.access_count += 1

    async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool:
        """Add a message to the context, optionally skipping the energy update."""
        try:
            cache_enabled = global_config.chat.enable_message_cache
            if cache_enabled and not self.is_cache_enabled:
                self.enable_cache(True)
                logger.debug(f"Enabled the message cache system for StreamContext {self.stream_id}")

            if message.interest_value is None:
                message.interest_value = 0.3
                message.should_reply = False
                message.should_act = False
                message.interest_calculated = False
                message.semantic_embedding = None
                message.is_read = False

            success = self.add_message_with_cache_check(message, force_direct=not cache_enabled)
            if not success:
                logger.error(f"Failed to add message to StreamContext: {self.stream_id}")
                return False

            self._detect_chat_type(message)
            self.total_messages += 1
            self._update_access_stats()

            if cache_enabled and self.is_cache_enabled:
                if self.is_chatter_processing:
                    logger.debug(f"Message cached in StreamContext awaiting processing: stream={self.stream_id}")
                else:
                    logger.debug(f"Message added directly to the StreamContext unread list: stream={self.stream_id}")
            else:
                logger.debug(f"Message added to StreamContext: {self.stream_id}")
            # Sync the message to the unified memory system
            try:
                if global_config.memory and global_config.memory.enable:
                    unified_manager = _get_unified_memory_manager()
                    if unified_manager:
                        message_dict = {
                            "message_id": str(message.message_id),
                            "sender_id": message.user_info.user_id,
                            "sender_name": message.user_info.user_nickname,
                            "content": message.processed_plain_text or message.display_message or "",
                            "timestamp": message.time,
                            "platform": message.chat_info.platform,
                            "stream_id": self.stream_id,
                        }
                        await unified_manager.add_message(message_dict)
                        logger.debug(f"Message added to the memory system: {message.message_id}")
            except Exception as e:
                logger.error(f"Failed to add message to the memory system: {e}", exc_info=True)

            return True
        except Exception as e:
            logger.error(f"Failed to add message to context {self.stream_id}: {e}", exc_info=True)
            return False

    async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
        """Update a message in the context."""
        try:
            for message in self.unread_messages:
                if str(message.message_id) == str(message_id):
                    if "interest_value" in updates:
                        message.interest_value = updates["interest_value"]
                    if "actions" in updates:
                        message.actions = updates["actions"]
                    if "should_reply" in updates:
                        message.should_reply = updates["should_reply"]
                    break

            for message in self.history_messages:
                if str(message.message_id) == str(message_id):
                    if "interest_value" in updates:
                        message.interest_value = updates["interest_value"]
                    if "actions" in updates:
                        message.actions = updates["actions"]
                    if "should_reply" in updates:
                        message.should_reply = updates["should_reply"]
                    break

            logger.debug(f"Updated message in context: {self.stream_id}/{message_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to update message in context {self.stream_id}/{message_id}: {e}", exc_info=True)
            return False

    def add_action_to_message(self, message_id: str, action: str):
        """
        Add an executed action to the specified message
@@ -113,9 +236,7 @@ class StreamContext(BaseDataModel):

        # Apply the history length limit
        if max_history_size is None:
            # Get the maximum number of history messages from the global config
            from src.config.config import global_config
            max_history_size = getattr(global_config.chat, "max_context_size", 40)
            max_history_size = self.max_context_size

        # If the history has reached the maximum length, remove the oldest messages
        if len(self.history_messages) >= max_history_size:
@@ -136,6 +257,44 @@ class StreamContext(BaseDataModel):
        recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
        return recent_history

    def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list["DatabaseMessages"]:
        """Get the combined set of messages in this context."""
        try:
            messages: list["DatabaseMessages"] = []
            if include_unread:
                messages.extend(self.get_unread_messages())

            if limit:
                messages.extend(self.get_history_messages(limit=limit))
            else:
                messages.extend(self.get_history_messages())

            messages.sort(key=lambda msg: getattr(msg, "time", 0))

            if limit and len(messages) > limit:
                messages = messages[-limit:]

            self._update_access_stats()
            return messages
        except Exception as e:
            logger.error(f"Failed to get context messages {self.stream_id}: {e}", exc_info=True)
            return []

    def mark_messages_as_read(self, message_ids: list[str]) -> bool:
        """Mark a batch of messages as read."""
        try:
            marked_count = 0
            for message_id in message_ids:
                try:
                    self.mark_message_as_read(message_id, max_history_size=self.max_context_size)
                    marked_count += 1
                except Exception as e:
                    logger.warning(f"Failed to mark message as read {message_id}: {e}")
            return marked_count > 0
        except Exception as e:
            logger.error(f"Failed to mark messages as read in batch {self.stream_id}: {e}", exc_info=True)
            return False

    def calculate_interruption_probability(self, max_limit: int, min_probability: float = 0.1, probability_factor: float | None = None) -> float:
        """Calculate the interruption probability using an inverse-proportional model

@@ -175,6 +334,75 @@ class StreamContext(BaseDataModel):
        probability = max(min_probability, probability)
        return max(0.0, min(1.0, probability))

    async def clear_context(self) -> bool:
        """Clear unread and history messages and reset the context state."""
        try:
            self.unread_messages.clear()
            self.history_messages.clear()
            for attr in ["interruption_count", "afc_threshold_adjustment", "last_check_time"]:
                if hasattr(self, attr):
                    if attr in ["interruption_count", "afc_threshold_adjustment"]:
                        setattr(self, attr, 0)
                    else:
                        setattr(self, attr, time.time())
            await self._update_stream_energy()
            logger.debug(f"Context cleared: {self.stream_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to clear context {self.stream_id}: {e}", exc_info=True)
            return False

    def get_statistics(self) -> dict[str, Any]:
        """Get context statistics."""
        try:
            current_time = time.time()
            uptime = current_time - self.created_time

            stats = {
                "stream_id": self.stream_id,
                "context_type": type(self).__name__,
                "total_messages": len(self.history_messages) + len(self.unread_messages),
                "unread_messages": len(self.unread_messages),
                "history_messages": len(self.history_messages),
                "is_active": self.is_active,
                "last_check_time": self.last_check_time,
                "interruption_count": self.interruption_count,
                "afc_threshold_adjustment": getattr(self, "afc_threshold_adjustment", 0.0),
                "created_time": self.created_time,
                "last_access_time": self.last_access_time,
                "access_count": self.access_count,
                "uptime_seconds": uptime,
                "idle_seconds": current_time - self.last_access_time,
            }

            stats["cache_stats"] = self.get_cache_stats()
            return stats
        except Exception as e:
            logger.error(f"Failed to get context statistics {self.stream_id}: {e}", exc_info=True)
            return {}

    def validate_integrity(self) -> bool:
        """Validate the structural integrity of the context."""
        try:
            required_attrs = ["stream_id", "unread_messages", "history_messages"]
            for attr in required_attrs:
                if not hasattr(self, attr):
                    logger.warning(f"Context is missing a required attribute: {attr}")
                    return False

            all_messages = self.unread_messages + self.history_messages
            message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
            if len(message_ids) != len(set(message_ids)):
                logger.warning(f"Duplicate message IDs found in context: {self.stream_id}")
                return False

            return True

        except Exception as e:
            logger.error(f"Failed to validate context integrity {self.stream_id}: {e}")
            return False


    async def increment_interruption_count(self):
        """Increment the interruption counter."""
        self.interruption_count += 1
@@ -239,6 +467,131 @@ class StreamContext(BaseDataModel):
            return self.history_messages[-1]
        return None

    async def ensure_history_initialized(self):
        """Ensure history messages have been loaded asynchronously."""
        if not self._history_initialized:
            await self._initialize_history_from_db()

    async def refresh_focus_energy_from_history(self) -> None:
        """Refresh the focus energy based on history messages."""
        await self._update_stream_energy(include_unread=False)

    async def _update_stream_energy(self, include_unread: bool = False) -> None:
        """Update the focus energy using the current context messages."""
        try:
            history_messages = self.get_history_messages(limit=self.max_context_size)
            messages: list["DatabaseMessages"] = list(history_messages)

            if include_unread:
                messages.extend(self.get_unread_messages())

            user_id = None
            if messages:
                last_message = messages[-1]
                if hasattr(last_message, "user_info") and last_message.user_info:
                    user_id = last_message.user_info.user_id

            from src.chat.energy_system import energy_manager

            await energy_manager.calculate_focus_energy(
                stream_id=self.stream_id,
                messages=messages,
                user_id=user_id,
            )

        except Exception as e:
            logger.error(f"Failed to update the energy system {self.stream_id}: {e}")

    async def _initialize_history_from_db(self):
        """Load history messages from database into context."""
        if self._history_initialized:
            logger.debug(f"History already initialized, stream={self.stream_id}, current count={len(self.history_messages)}")
            return

        logger.info(f"[History load] Start reading history messages from the database: {self.stream_id}")
        self._history_initialized = True

        try:
            logger.debug(f"Start loading history messages from the database: {self.stream_id}")

            from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat

            db_messages = await get_raw_msg_before_timestamp_with_chat(
                chat_id=self.stream_id,
                timestamp=time.time(),
                limit=self.max_context_size,
            )

            if db_messages:
                logger.info(f"[History load] Fetched {len(db_messages)} history messages from the database")
                loaded_count = 0
                for msg_dict in db_messages:
                    try:
                        db_msg = DatabaseMessages(**msg_dict)
                        db_msg.is_read = True
                        self.history_messages.append(db_msg)
                        loaded_count += 1

                    except Exception as e:
                        logger.warning(f"Failed to convert history message (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
                        continue

                if len(self.history_messages) > self.max_context_size:
                    removed_count = len(self.history_messages) - self.max_context_size
                    self.history_messages = self.history_messages[-self.max_context_size :]
                    logger.debug(f"[History load] Removed {removed_count} oldest messages to fit the current capacity limit")

                logger.info(f"[History load] Loaded {loaded_count} history messages into memory: {self.stream_id}")
            else:
                logger.debug(f"No history messages to load: {self.stream_id}")

        except Exception as e:
            logger.error(f"Failed to load history messages from the database: {self.stream_id}, {e}", exc_info=True)
            self._history_initialized = False

    def _detect_chat_type(self, message: "DatabaseMessages"):
        """Detect the chat type from the message content."""
        if len(self.unread_messages) == 1:
            if message.chat_info.group_info:
                self.chat_type = ChatType.GROUP
            else:
                self.chat_type = ChatType.PRIVATE

    async def _calculate_message_interest(self, message: "DatabaseMessages") -> float:
        """Calculate the message interest value via the interest system."""
        try:
            from src.chat.interest_system.interest_manager import get_interest_manager

            interest_manager = get_interest_manager()

            if interest_manager.has_calculator():
                result = await interest_manager.calculate_interest(message)

                if result.success:
                    message.interest_value = result.interest_value
                    message.should_reply = result.should_reply
                    message.should_act = result.should_act
                    message.interest_calculated = True

                    logger.debug(
                        f"Message {message.message_id} interest value updated: {result.interest_value:.3f}, "
                        f"should_reply: {result.should_reply}, should_act: {result.should_act}"
                    )
                    return result.interest_value
                else:
                    logger.warning(f"Interest calculation failed for message {message.message_id}: {result.error_message}")
                    message.interest_calculated = False
                    return 0.5
            else:
                logger.debug("No interest calculator found, using the default interest value")
                return 0.5

        except Exception as e:
            logger.error(f"Unexpected error while calculating message interest: {e}", exc_info=True)
            if hasattr(message, "interest_calculated"):
                message.interest_calculated = False
            return 0.5

    def check_types(self, types: list) -> bool:
        """
        Check whether the current message supports the given types
@@ -332,14 +685,6 @@ class StreamContext(BaseDataModel):
            logger.debug("[check_types] ✅ Fallback path passed all type checks")
        return True

    def get_priority_mode(self) -> str | None:
        """Get the priority mode."""
        return self.priority_mode

    def get_priority_info(self) -> dict | None:
        """Get the priority info."""
        return self.priority_info

    # ==================== Message cache system methods ====================

    def enable_cache(self, enabled: bool = True):
