From 5962b442949c49c948749e7f5bc65c1f13fa150b Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 26 Sep 2025 19:17:24 +0800 Subject: [PATCH] =?UTF-8?q?refactor(chat):=20=E4=BC=98=E5=8C=96=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E7=AE=A1=E7=90=86=E4=B8=8E=E6=89=93=E6=96=AD=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=EF=BC=8C=E6=B7=BB=E5=8A=A0=E6=89=93=E6=96=AD=E8=AE=A1?= =?UTF-8?q?=E6=95=B0=E4=B8=8E=E5=8E=86=E5=8F=B2=E6=B6=88=E6=81=AF=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_manager/message_manager.py | 73 +++++-- src/chat/message_receive/chat_stream.py | 196 ++++++++++++++---- .../data_models/message_manager_data_model.py | 132 ++++++++++-- src/common/database/sqlalchemy_models.py | 9 + 4 files changed, 334 insertions(+), 76 deletions(-) diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 773f22f16..7671e55f1 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -88,7 +88,14 @@ class MessageManager: logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") - def update_message_and_refresh_energy(self, stream_id: str, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None): + def update_message_and_refresh_energy( + self, + stream_id: str, + message_id: str, + interest_degree: float = None, + actions: list = None, + should_reply: bool = None, + ): """更新消息信息""" if stream_id in self.stream_contexts: context = self.stream_contexts[stream_id] @@ -287,6 +294,13 @@ class MessageManager: global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor ) + # 检查是否已达到最大打断次数 + if context.interruption_count >= global_config.chat.interruption_max_limit: + logger.debug( + f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查" + ) + return + # 根据概率决定是否打断 if random.random() < interruption_probability: logger.info(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}") @@ -301,9 +315,16 @@ class MessageManager: # 增加打断计数并应用afc阈值降低 context.increment_interruption_count() context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction) - logger.info( - f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}" - ) + + # 检查是否已达到最大次数 + if context.interruption_count >= global_config.chat.interruption_max_limit: + logger.warning( + f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断" + ) + else: + logger.info( + f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}" + ) else: logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}") @@ -312,6 +333,9 @@ class MessageManager: if not global_config.chat.dynamic_distribution_enabled: return self.check_interval # 使用固定间隔 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_stream = get_chat_manager().get_stream(context.stream_id) # 获取该流的focus_energy(新的阈值感知版本) focus_energy = 0.5 # 默认值 avg_message_interest = 0.5 # 默认平均兴趣度 @@ -324,13 +348,13 @@ class MessageManager: all_messages = history_messages + unread_messages if all_messages: - message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, 'interest_degree')] + message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")] avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.5 # 获取AFC阈值用于参考,添加None值检查 - reply_threshold = getattr(global_config.affinity_flow, 'reply_action_interest_threshold', 0.4) - non_reply_threshold = getattr(global_config.affinity_flow, 'non_reply_action_interest_threshold', 0.2) - high_match_threshold = getattr(global_config.affinity_flow, 'high_match_interest_threshold', 0.8) + reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) + non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) + high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) # 使用配置参数 base_interval = global_config.chat.dynamic_distribution_base_interval @@ -361,6 +385,7 @@ class MessageManager: # 添加随机扰动避免同步 import random + jitter = random.uniform(1.0 - jitter_factor, 1.0 + jitter_factor) final_interval = interval * jitter @@ -392,7 +417,7 @@ class MessageManager: def _calculate_next_manager_delay(self) -> float: """计算管理器下次检查的延迟时间""" current_time = time.time() - min_delay = float('inf') + min_delay = float("inf") # 找到最近需要检查的流 for context in self.stream_contexts.values(): @@ -407,7 +432,7 @@ class MessageManager: break # 如果没有活跃流,使用默认间隔 - if min_delay == float('inf'): + if min_delay == float("inf"): return self.check_interval # 确保最小延迟 @@ -444,7 +469,10 @@ class MessageManager: # 如果没有处理任务,创建一个 if not context.processing_task or context.processing_task.done(): - focus_energy = context.chat_stream.focus_energy if hasattr(context, 'chat_stream') and context.chat_stream else 0.5 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_stream = get_chat_manager().get_stream(context.stream_id) + focus_energy = chat_stream.focus_energy if chat_stream else 0.5 # 根据优先级记录日志 if focus_energy >= 0.7: @@ -468,10 +496,7 @@ class MessageManager: self.stats.active_streams = active_count if processed_streams > 0: - logger.debug( - f"本次循环处理了 {processed_streams} 个流 | " - f"活跃流总数: {active_count}" - ) + logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}") async def _check_all_streams_with_priority(self): """按优先级检查所有聊天流,高focus_energy的流优先处理""" @@ -485,6 +510,9 @@ class MessageManager: continue # 获取focus_energy,如果不存在则使用默认值 + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_stream = get_chat_manager().get_stream(context.stream_id) focus_energy = 0.5 if hasattr(context, 'chat_stream') and context.chat_stream: focus_energy = context.chat_stream.focus_energy @@ -526,6 +554,9 @@ class MessageManager: def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float: """计算聊天流的优先级分数""" + from src.plugin_system.apis.chat_api import get_chat_manager + + chat_stream = get_chat_manager().get_stream(context.stream_id) # 基础优先级:focus_energy base_priority = focus_energy @@ -544,8 +575,8 @@ class MessageManager: consecutive_no_reply = 0 all_messages = context.get_history_messages(limit=50) + context.get_unread_messages() for msg in reversed(all_messages): - if hasattr(msg, 'should_reply') and msg.should_reply: - if not (hasattr(msg, 'actions') and 'reply' in (msg.actions or [])): + if hasattr(msg, "should_reply") and msg.should_reply: + if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])): consecutive_no_reply += 1 else: break @@ -555,10 +586,10 @@ class MessageManager: # 综合优先级计算 final_priority = ( - base_priority * 0.6 + # 基础兴趣度权重60% - message_count_bonus * 0.2 + # 消息数量权重20% - time_penalty * 0.1 + # 时间权重10% - no_reply_penalty * 0.1 # 回复状态权重10% + base_priority * 0.6 # 基础兴趣度权重60% + + message_count_bonus * 0.2 # 消息数量权重20% + + time_penalty * 0.1 # 时间权重10% + + no_reply_penalty * 0.1 # 回复状态权重10% ) return max(0.0, min(1.0, final_priority)) diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index ac23f0d8c..8ef9772c2 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -48,10 +48,9 @@ class ChatStream: # 使用StreamContext替代ChatMessageContext from src.common.data_models.message_manager_data_model import StreamContext from src.plugin_system.base.component_types import ChatType, ChatMode + self.stream_context: StreamContext = StreamContext( - stream_id=stream_id, - chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, - chat_mode=ChatMode.NORMAL + stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL ) # 基础参数 @@ -59,6 +58,8 @@ class ChatStream: self._focus_energy = 0.5 # 内部存储的focus_energy值 self.no_reply_consecutive = 0 + # 自动加载历史消息 + self._load_history_messages() def to_dict(self) -> dict: """转换为字典格式""" @@ -76,6 +77,8 @@ class ChatStream: # 新增stream_context信息 "stream_context_chat_type": self.stream_context.chat_type.value, "stream_context_chat_mode": self.stream_context.chat_mode.value, + # 新增interruption_count信息 + "interruption_count": self.stream_context.interruption_count, } @classmethod @@ -95,11 +98,17 @@ class ChatStream: # 恢复stream_context信息 if "stream_context_chat_type" in data: from src.plugin_system.base.component_types import ChatType, ChatMode + instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) if "stream_context_chat_mode" in data: from src.plugin_system.base.component_types import ChatType, ChatMode + instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) + # 恢复interruption_count信息 + if "interruption_count" in data: + instance.stream_context.interruption_count = data["interruption_count"] + return instance def update_active_time(self): @@ -114,19 +123,28 @@ class ChatStream: # 简化转换,实际可能需要更完整的转换逻辑 db_message = DatabaseMessages( - message_id=getattr(message, 'message_id', ''), - time=getattr(message, 'time', time.time()), - chat_id=getattr(message, 'chat_id', ''), - user_id=str(getattr(message.message_info, 'user_info', {}).user_id) if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '', - user_nickname=getattr(message.message_info, 'user_info', {}).user_nickname if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '', - user_platform=getattr(message.message_info, 'user_info', {}).platform if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '', - priority_mode=getattr(message, 'priority_mode', None), - priority_info=str(getattr(message, 'priority_info', None)) if hasattr(message, 'priority_info') and message.priority_info else None, + message_id=getattr(message, "message_id", ""), + time=getattr(message, "time", time.time()), + chat_id=getattr(message, "chat_id", ""), + user_id=str(getattr(message.message_info, "user_info", {}).user_id) + if hasattr(message, "message_info") and hasattr(message.message_info, "user_info") + else "", + user_nickname=getattr(message.message_info, "user_info", {}).user_nickname + if hasattr(message, "message_info") and hasattr(message.message_info, "user_info") + else "", + user_platform=getattr(message.message_info, "user_info", {}).platform + if hasattr(message, "message_info") and hasattr(message.message_info, "user_info") + else "", + priority_mode=getattr(message, "priority_mode", None), + priority_info=str(getattr(message, "priority_info", None)) + if hasattr(message, "priority_info") and message.priority_info + else None, + additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None), ) self.stream_context.set_current_message(db_message) - self.stream_context.priority_mode = getattr(message, 'priority_mode', None) - self.stream_context.priority_info = getattr(message, 'priority_info', None) + self.stream_context.priority_mode = getattr(message, "priority_mode", None) + self.stream_context.priority_info = getattr(message, "priority_info", None) @property def focus_energy(self) -> float: @@ -150,16 +168,20 @@ class ChatStream: # 计算基于历史消息的统计数据 if all_messages: # 基础分:平均消息兴趣度 - message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, 'interest_degree')] + message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")] avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3 # 动作参与度:有动作的消息比例 - messages_with_actions = [msg for msg in all_messages if hasattr(msg, 'actions') and msg.actions] + messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions] action_rate = len(messages_with_actions) / len(all_messages) # 回复活跃度:应该回复且已回复的消息比例 - should_reply_messages = [msg for msg in all_messages if hasattr(msg, 'should_reply') and msg.should_reply] - replied_messages = [msg for msg in should_reply_messages if hasattr(msg, 'actions') and 'reply' in (msg.actions or [])] + should_reply_messages = [ + msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply + ] + replied_messages = [ + msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or []) + ] reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0 # 获取最后交互时间 @@ -169,8 +191,8 @@ class ChatStream: # 连续无回复计算:从最近的未回复消息计数 consecutive_no_reply = 0 for msg in reversed(all_messages): - if hasattr(msg, 'should_reply') and msg.should_reply: - if not (hasattr(msg, 'actions') and 'reply' in (msg.actions or [])): + if hasattr(msg, "should_reply") and msg.should_reply: + if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])): consecutive_no_reply += 1 else: break @@ -187,7 +209,7 @@ class ChatStream: # 时间衰减因子:最近活跃度 current_time = time.time() - if not hasattr(self, 'last_interaction_time') or not self.last_interaction_time: + if not hasattr(self, "last_interaction_time") or not self.last_interaction_time: self.last_interaction_time = current_time time_since_interaction = current_time - self.last_interaction_time time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减 @@ -196,20 +218,24 @@ class ChatStream: no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1) # 获取AFC系统阈值,添加None值检查 - reply_threshold = getattr(global_config.affinity_flow, 'reply_action_interest_threshold', 0.4) - non_reply_threshold = getattr(global_config.affinity_flow, 'non_reply_action_interest_threshold', 0.2) - high_match_threshold = getattr(global_config.affinity_flow, 'high_match_interest_threshold', 0.8) + reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) + non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) + high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) # 计算与不同阈值的差距比例 reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold))) - non_reply_gap_ratio = max(0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))) - high_match_gap_ratio = max(0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))) + non_reply_gap_ratio = max( + 0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold)) + ) + high_match_gap_ratio = max( + 0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold)) + ) # 基于阈值差距比例的基础分计算 threshold_based_score = ( - reply_gap_ratio * 0.6 + # 回复阈值差距权重60% - non_reply_gap_ratio * 0.2 + # 非回复阈值差距权重20% - high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20% + reply_gap_ratio * 0.6 # 回复阈值差距权重60% + + non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20% + + high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20% ) # 动态权重调整:根据平均兴趣度水平调整权重分配 @@ -230,15 +256,19 @@ class ChatStream: relationship_weight = 0.2 # 计算活跃度得分 - activity_score = (action_rate * 0.6 + reply_rate * 0.4) + activity_score = action_rate * 0.6 + reply_rate * 0.4 # 综合计算:基于阈值的动态加权 focus_energy = ( - threshold_based_score * threshold_weight + # 阈值差距基础分 - activity_score * activity_weight + # 活跃度得分 - relationship_factor * relationship_weight + # 关系得分 - self.base_interest_energy * 0.05 # 基础兴趣微调 - ) * time_decay * no_reply_penalty + ( + threshold_based_score * threshold_weight # 阈值差距基础分 + + activity_score * activity_weight # 活跃度得分 + + relationship_factor * relationship_weight # 关系得分 + + self.base_interest_energy * 0.05 # 基础兴趣微调 + ) + * time_decay + * no_reply_penalty + ) # 确保在合理范围内 focus_energy = max(0.1, min(1.0, focus_energy)) @@ -268,7 +298,7 @@ class ChatStream: chatter_interest_scoring_system, ) - if self.user_info and hasattr(self.user_info, 'user_id'): + if self.user_info and hasattr(self.user_info, "user_id"): return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id)) except Exception: pass @@ -276,8 +306,102 @@ class ChatStream: # 默认基础分 return 0.3 + def _load_history_messages(self): + """从数据库加载历史消息到StreamContext""" + try: + from src.common.database.sqlalchemy_models import Messages + from src.common.database.sqlalchemy_database_api import get_db_session + from src.common.data_models.database_data_model import DatabaseMessages + from sqlalchemy import select, desc + import asyncio + async def _load_messages(): + def _db_query(): + with get_db_session() as session: + # 查询该stream_id的最近20条消息 + stmt = ( + select(Messages) + .where(Messages.chat_info_stream_id == self.stream_id) + .order_by(desc(Messages.time)) + .limit(global_config.chat.max_context_size) + ) + results = session.execute(stmt).scalars().all() + return results + # 在线程中执行数据库查询 + db_messages = await asyncio.to_thread(_db_query) + + # 转换为DatabaseMessages对象并添加到StreamContext + for db_msg in db_messages: + try: + # 从SQLAlchemy模型转换为DatabaseMessages数据模型 + import orjson + + # 解析actions字段(JSON格式) + actions = None + if db_msg.actions: + try: + actions = orjson.loads(db_msg.actions) + except (orjson.JSONDecodeError, TypeError): + actions = None + + db_message = DatabaseMessages( + message_id=db_msg.message_id, + time=db_msg.time, + chat_id=db_msg.chat_id, + reply_to=db_msg.reply_to, + interest_value=db_msg.interest_value, + key_words=db_msg.key_words, + key_words_lite=db_msg.key_words_lite, + is_mentioned=db_msg.is_mentioned, + processed_plain_text=db_msg.processed_plain_text, + display_message=db_msg.display_message, + priority_mode=db_msg.priority_mode, + priority_info=db_msg.priority_info, + additional_config=db_msg.additional_config, + is_emoji=db_msg.is_emoji, + is_picid=db_msg.is_picid, + is_command=db_msg.is_command, + is_notify=db_msg.is_notify, + user_id=db_msg.user_id, + user_nickname=db_msg.user_nickname, + user_cardname=db_msg.user_cardname, + user_platform=db_msg.user_platform, + chat_info_group_id=db_msg.chat_info_group_id, + chat_info_group_name=db_msg.chat_info_group_name, + chat_info_group_platform=db_msg.chat_info_group_platform, + chat_info_user_id=db_msg.chat_info_user_id, + chat_info_user_nickname=db_msg.chat_info_user_nickname, + chat_info_user_cardname=db_msg.chat_info_user_cardname, + chat_info_user_platform=db_msg.chat_info_user_platform, + chat_info_stream_id=db_msg.chat_info_stream_id, + chat_info_platform=db_msg.chat_info_platform, + chat_info_create_time=db_msg.chat_info_create_time, + chat_info_last_active_time=db_msg.chat_info_last_active_time, + # 新增的兴趣度系统字段 + interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0, + actions=actions, + should_reply=getattr(db_msg, "should_reply", False) or False, + ) + + # 标记为已读并添加到历史消息 + db_message.is_read = True + self.stream_context.history_messages.append(db_message) + + except Exception as e: + logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}") + continue + + if self.stream_context.history_messages: + logger.info( + f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}" + ) + + # 创建任务来加载历史消息 + asyncio.create_task(_load_messages()) + + except Exception as e: + logger.error(f"加载历史消息失败: {e}") class ChatManager: @@ -524,6 +648,7 @@ class ChatManager: "reply_count": s_data_dict.get("reply_count", 0), "last_interaction_time": s_data_dict.get("last_interaction_time", time.time()), "consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0), + "interruption_count": s_data_dict.get("interruption_count", 0), } if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) @@ -590,6 +715,7 @@ class ChatManager: "last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()), "relationship_score": getattr(model_instance, "relationship_score", 0.3), "consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0), + "interruption_count": getattr(model_instance, "interruption_count", 0), } loaded_streams_data.append(data_for_from_dict) await session.commit() diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index 9bc4002b5..268328c77 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -60,7 +60,9 @@ class StreamContext(BaseDataModel): # 自动检测和更新chat type self._detect_chat_type(message) - def update_message_info(self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None): + def update_message_info( + self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None + ): """ 更新消息信息 @@ -166,11 +168,15 @@ class StreamContext(BaseDataModel): # 计算打断比例 interruption_ratio = self.interruption_count / max_limit + # 如果已达到或超过最大次数,完全禁止打断 + if self.interruption_count >= max_limit: + return 0.0 + # 如果超过概率因子,概率下降 if interruption_ratio > probability_factor: # 使用指数衰减,超过限制越多,概率越低 excess_ratio = interruption_ratio - probability_factor - probability = 1.0 * (0.5**excess_ratio) # 基础概率0.5,指数衰减 + probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减 else: # 在限制内,保持较高概率 probability = 0.8 @@ -182,12 +188,18 @@ class StreamContext(BaseDataModel): self.interruption_count += 1 self.last_interruption_time = time.time() + # 同步打断计数到ChatStream + self._sync_interruption_count_to_stream() + def reset_interruption_count(self): """重置打断计数和afc阈值调整""" self.interruption_count = 0 self.last_interruption_time = 0.0 self.afc_threshold_adjustment = 0.0 + # 同步打断计数到ChatStream + self._sync_interruption_count_to_stream() + def apply_interruption_afc_reduction(self, reduction_value: float): """应用打断导致的afc阈值降低""" self.afc_threshold_adjustment += reduction_value @@ -197,18 +209,40 @@ class StreamContext(BaseDataModel): """获取当前的afc阈值调整量""" return self.afc_threshold_adjustment + def _sync_interruption_count_to_stream(self): + """同步打断计数到ChatStream""" + try: + from src.chat.message_receive.chat_stream import get_chat_manager + + chat_manager = get_chat_manager() + if chat_manager: + chat_stream = chat_manager.get_stream(self.stream_id) + if chat_stream and hasattr(chat_stream, "interruption_count"): + # 在这里我们只是标记需要保存,实际的保存会在下次save时进行 + chat_stream.saved = False + logger.debug( + f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream" + ) + except Exception as e: + logger.warning(f"同步打断计数到ChatStream失败: {e}") + def set_current_message(self, message: "DatabaseMessages"): """设置当前消息""" self.current_message = message def get_template_name(self) -> Optional[str]: """获取模板名称""" - if self.current_message and hasattr(self.current_message, 'additional_config') and self.current_message.additional_config: + if ( + self.current_message + and hasattr(self.current_message, "additional_config") + and self.current_message.additional_config + ): try: import json + config = json.loads(self.current_message.additional_config) - if config.get('template_info') and not config.get('template_default', True): - return config.get('template_name') + if config.get("template_info") and not config.get("template_default", True): + return config.get("template_name") except (json.JSONDecodeError, AttributeError): pass return None @@ -224,25 +258,83 @@ class StreamContext(BaseDataModel): return None def check_types(self, types: list) -> bool: - """检查消息类型""" + """ + 检查当前消息是否支持指定的类型 + + Args: + types: 需要检查的消息类型列表,如 ["text", "image", "emoji"] + + Returns: + bool: 如果消息支持所有指定的类型则返回True,否则返回False + """ if not self.current_message: return False - # 检查消息是否支持指定的类型 - # 这里简化处理,实际应该根据消息的格式信息检查 - if hasattr(self.current_message, 'additional_config') and self.current_message.additional_config: + if not types: + # 如果没有指定类型要求,默认为支持 + return True + + # 优先从additional_config中获取format_info + if hasattr(self.current_message, "additional_config") and self.current_message.additional_config: try: - import json - config = json.loads(self.current_message.additional_config) - if 'format_info' in config and 'accept_format' in config['format_info']: - accept_format = config['format_info']['accept_format'] - for t in types: - if t not in accept_format: - return False - return True - except (json.JSONDecodeError, AttributeError): - pass - return False + import orjson + + config = orjson.loads(self.current_message.additional_config) + + # 检查format_info结构 + if "format_info" in config: + format_info = config["format_info"] + + # 方法1: 直接检查accept_format字段 + if "accept_format" in format_info: + accept_format = format_info["accept_format"] + # 确保accept_format是列表类型 + if isinstance(accept_format, str): + accept_format = [accept_format] + elif isinstance(accept_format, list): + pass + else: + # 如果accept_format不是字符串或列表,尝试转换为列表 + accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else [] + + # 检查所有请求的类型是否都被支持 + for requested_type in types: + if requested_type not in accept_format: + logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}") + return False + return True + + # 方法2: 检查content_format字段(向后兼容) + elif "content_format" in format_info: + content_format = format_info["content_format"] + # 确保content_format是列表类型 + if isinstance(content_format, str): + content_format = [content_format] + elif isinstance(content_format, list): + pass + else: + content_format = list(content_format) if hasattr(content_format, "__iter__") else [] + + # 检查所有请求的类型是否都被支持 + for requested_type in types: + if requested_type not in content_format: + logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}") + return False + return True + + except (orjson.JSONDecodeError, AttributeError, TypeError) as e: + logger.debug(f"解析消息格式信息失败: {e}") + + # 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型 + # 大多数消息至少支持text类型 + default_supported_types = ["text", "emoji"] + for requested_type in types: + if requested_type not in default_supported_types: + logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'") + # 对于非基础类型,返回False以避免错误 + if requested_type not in ["text", "emoji", "reply"]: + return False + return True def get_priority_mode(self) -> Optional[str]: """获取优先级模式""" diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 53505a4ee..726da7b5e 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -64,6 +64,8 @@ class ChatStreams(Base): reply_count = Column(Integer, nullable=True, default=0) last_interaction_time = Column(Float, nullable=True, default=None) consecutive_no_reply = Column(Integer, nullable=True, default=0) + # 消息打断系统字段 + interruption_count = Column(Integer, nullable=True, default=0) __table_args__ = ( Index("idx_chatstreams_stream_id", "stream_id"), @@ -173,11 +175,18 @@ class Messages(Base): is_command = Column(Boolean, nullable=False, default=False) is_notify = Column(Boolean, nullable=False, default=False) + # 兴趣度系统字段 + interest_degree = Column(Float, nullable=True, default=0.0) + actions = Column(Text, nullable=True) # JSON格式存储动作列表 + should_reply = Column(Boolean, nullable=True, default=False) + __table_args__ = ( Index("idx_messages_message_id", "message_id"), Index("idx_messages_chat_id", "chat_id"), Index("idx_messages_time", "time"), Index("idx_messages_user_id", "user_id"), + Index("idx_messages_interest_degree", "interest_degree"), + Index("idx_messages_should_reply", "should_reply"), )