diff --git a/src/chat/chat_loop/energy_manager.py b/src/chat/chat_loop/energy_manager.py index 5e5e5eea5..0507dc84f 100644 --- a/src/chat/chat_loop/energy_manager.py +++ b/src/chat/chat_loop/energy_manager.py @@ -102,6 +102,7 @@ class EnergyManager: self.context.sleep_pressure -= decay_per_10s self.context.sleep_pressure = max(self.context.sleep_pressure, 0) self._log_sleep_pressure_change("睡眠压力释放") + self.context.save_context_state() else: # 清醒时:处理能量衰减 is_group_chat = self.context.chat_stream.group_info is not None @@ -122,6 +123,7 @@ class EnergyManager: self.context.energy_value = max(self.context.energy_value, 0.3) self._log_energy_change("能量值衰减") + self.context.save_context_state() def _should_log_energy(self) -> bool: """ @@ -149,6 +151,7 @@ class EnergyManager: self.context.sleep_pressure += increment self.context.sleep_pressure = min(self.context.sleep_pressure, 100.0) # 设置一个100的上限 self._log_sleep_pressure_change("执行动作,睡眠压力累积") + self.context.save_context_state() def _log_energy_change(self, action: str, reason: str = ""): """ diff --git a/src/chat/chat_loop/hfc_context.py b/src/chat/chat_loop/hfc_context.py index 1920c5417..1f5c3e1df 100644 --- a/src/chat/chat_loop/hfc_context.py +++ b/src/chat/chat_loop/hfc_context.py @@ -2,7 +2,6 @@ from typing import List, Optional, TYPE_CHECKING import time from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.common.logger import get_logger -from src.manager.local_store_manager import local_storage from src.person_info.relationship_builder_manager import RelationshipBuilder from src.chat.express.expression_learner import ExpressionLearner from src.plugin_system.base.component_types import ChatMode @@ -42,8 +41,8 @@ class HfcContext: self.expression_learner: Optional[ExpressionLearner] = None self.loop_mode = ChatMode.NORMAL - self.energy_value = 5.0 - self.sleep_pressure = 0.0 + self.energy_value = self.chat_stream.energy_value + self.sleep_pressure = self.chat_stream.sleep_pressure self.was_sleeping = False # 用于检测睡眠状态的切换 self.last_message_time = time.time() @@ -61,30 +60,8 @@ class HfcContext: self.wakeup_manager: Optional['WakeUpManager'] = None self.energy_manager: Optional['EnergyManager'] = None - self._load_context_state() - - def _get_storage_key(self) -> str: - """获取当前聊天流的本地存储键""" - return f"hfc_context_state_{self.stream_id}" - - def _load_context_state(self): - """从本地存储加载状态""" - state = local_storage[self._get_storage_key()] - if state and isinstance(state, dict): - self.energy_value = state.get("energy_value", 5.0) - self.sleep_pressure = state.get("sleep_pressure", 0.0) - logger = get_logger("hfc_context") - logger.info(f"{self.log_prefix} 成功从本地存储加载HFC上下文状态: {state}") - else: - logger = get_logger("hfc_context") - logger.info(f"{self.log_prefix} 未找到本地HFC上下文状态,将使用默认值初始化。") - def save_context_state(self): - """将当前状态保存到本地存储""" - state = { - "energy_value": self.energy_value, - "sleep_pressure": self.sleep_pressure, - } - local_storage[self._get_storage_key()] = state - logger = get_logger("hfc_context") - logger.debug(f"{self.log_prefix} 已将HFC上下文状态保存到本地存储: {state}") \ No newline at end of file + """将当前状态保存到聊天流""" + if self.chat_stream: + self.chat_stream.energy_value = self.energy_value + self.chat_stream.sleep_pressure = self.sleep_pressure \ No newline at end of file diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index d7c103222..6989eca9a 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -77,6 +77,8 @@ class ChatStream: self.group_info = group_info self.create_time = data.get("create_time", time.time()) if data else time.time() self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time + self.energy_value = data.get("energy_value", 5.0) if data else 5.0 + self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.saved = False self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 @@ -89,6 +91,8 @@ class ChatStream: "group_info": self.group_info.to_dict() if self.group_info else None, "create_time": self.create_time, "last_active_time": self.last_active_time, + "energy_value": self.energy_value, + "sleep_pressure": self.sleep_pressure, } @classmethod @@ -249,7 +253,7 @@ class ChatManager: "user_cardname": model_instance.user_cardname or "", } group_info_data = None - if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息 + if model_instance.group_id: group_info_data = { "platform": model_instance.group_platform, "group_id": model_instance.group_id, @@ -263,6 +267,8 @@ class ChatManager: "group_info": group_info_data, "create_time": model_instance.create_time, "last_active_time": model_instance.last_active_time, + "energy_value": model_instance.energy_value, + "sleep_pressure": model_instance.sleep_pressure, } stream = ChatStream.from_dict(data_for_from_dict) # 更新用户信息和群组信息 @@ -346,6 +352,8 @@ class ChatManager: "group_platform": group_info_d["platform"] if group_info_d else "", "group_id": group_info_d["group_id"] if group_info_d else "", "group_name": group_info_d["group_name"] if group_info_d else "", + "energy_value": s_data_dict.get("energy_value", 5.0), + "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0), } # 根据数据库类型选择插入语句 @@ -411,6 +419,8 @@ class ChatManager: "group_info": group_info_data, "create_time": model_instance.create_time, "last_active_time": model_instance.last_active_time, + "energy_value": model_instance.energy_value, + "sleep_pressure": model_instance.sleep_pressure, } loaded_streams_data.append(data_for_from_dict) session.commit() diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 779179ff9..31e77a2be 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -48,6 +48,8 @@ class ChatStreams(Base): user_id = Column(get_string_field(100), nullable=False, index=True) user_nickname = Column(Text, nullable=False) user_cardname = Column(Text, nullable=True) + energy_value = Column(Float, nullable=True, default=5.0) + sleep_pressure = Column(Float, nullable=True, default=0.0) __table_args__ = ( Index('idx_chatstreams_stream_id', 'stream_id'),