refactor(chat): 重构消息管理器以使用集中式上下文管理和能量系统

- 将流上下文管理从MessageManager迁移到专门的ContextManager
- 使用统一的能量系统计算focus_energy和分发间隔
- 重构ChatStream的消息数据转换逻辑,支持更完整的数据字段
- 更新数据库模型,移除interest_degree字段,统一使用interest_value
- 集成新的兴趣度管理系统替代原有的评分系统
- 添加消息存储的interest_value修复功能
This commit is contained in:
Windpicker-owo
2025-09-27 14:23:48 +08:00
parent 0478be7d2a
commit c49b3f3ac4
15 changed files with 3518 additions and 495 deletions

View File

@@ -120,186 +120,209 @@ class ChatStream:
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
from src.common.data_models.database_data_model import DatabaseMessages
import json
# 简化转换,实际可能需要更完整的转换逻辑
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
group_info = getattr(message_info, "group_info", {})
# 提取reply_to信息从message_segment中查找reply类型的段
reply_to = None
if hasattr(message, "message_segment") and message.message_segment:
reply_to = self._extract_reply_from_segment(message.message_segment)
# 完整的数据转移逻辑
db_message = DatabaseMessages(
# 基础消息信息
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=getattr(message, "chat_id", ""),
user_id=str(getattr(message.message_info, "user_info", {}).user_id)
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
user_nickname=getattr(message.message_info, "user_info", {}).user_nickname
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
user_platform=getattr(message.message_info, "user_info", {}).platform
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
priority_mode=getattr(message, "priority_mode", None),
priority_info=str(getattr(message, "priority_info", None))
if hasattr(message, "priority_info") and message.priority_info
chat_id=self._generate_chat_id(message_info),
reply_to=reply_to,
# 兴趣度相关
interest_value=getattr(message, "interest_value", 0.0),
# 关键词
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
if getattr(message, "key_words", None)
else None,
additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None),
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
if getattr(message, "key_words_lite", None)
else None,
# 消息状态标记
is_mentioned=getattr(message, "is_mentioned", None),
is_at=getattr(message, "is_at", False),
is_emoji=getattr(message, "is_emoji", False),
is_picid=getattr(message, "is_picid", False),
is_voice=getattr(message, "is_voice", False),
is_video=getattr(message, "is_video", False),
is_command=getattr(message, "is_command", False),
is_notify=getattr(message, "is_notify", False),
# 消息内容
processed_plain_text=getattr(message, "processed_plain_text", ""),
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
# 优先级信息
priority_mode=getattr(message, "priority_mode", None),
priority_info=json.dumps(getattr(message, "priority_info", None))
if getattr(message, "priority_info", None)
else None,
# 额外配置
additional_config=getattr(message_info, "additional_config", None),
# 用户信息
user_id=str(getattr(user_info, "user_id", "")),
user_nickname=getattr(user_info, "user_nickname", ""),
user_cardname=getattr(user_info, "user_cardname", None),
user_platform=getattr(user_info, "platform", ""),
# 群组信息
chat_info_group_id=getattr(group_info, "group_id", None),
chat_info_group_name=getattr(group_info, "group_name", None),
chat_info_group_platform=getattr(group_info, "platform", None),
# 聊天流信息
chat_info_user_id=str(getattr(user_info, "user_id", "")),
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
chat_info_user_platform=getattr(user_info, "platform", ""),
chat_info_stream_id=self.stream_id,
chat_info_platform=self.platform,
chat_info_create_time=self.create_time,
chat_info_last_active_time=self.last_active_time,
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
)
self.stream_context.set_current_message(db_message)
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}")
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
"""安全获取消息的actions字段"""
try:
actions = getattr(message, "actions", None)
if actions is None:
return None
# 如果是字符串尝试解析为JSON
if isinstance(actions, str):
try:
import json
actions = json.loads(actions)
except json.JSONDecodeError:
logger.warning(f"无法解析actions JSON字符串: {actions}")
return None
# 确保返回列表类型
if isinstance(actions, list):
# 过滤掉空值和非字符串元素
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
return filtered_actions if filtered_actions else None
else:
logger.warning(f"actions字段类型不支持: {type(actions)}")
return None
except Exception as e:
logger.warning(f"获取actions字段失败: {e}")
return None
def _extract_reply_from_segment(self, segment) -> Optional[str]:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = self._extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
def _generate_chat_id(self, message_info) -> str:
"""生成chat_id基于群组或用户信息"""
try:
group_info = getattr(message_info, "group_info", None)
user_info = getattr(message_info, "user_info", None)
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
# 群聊使用群组ID
return f"{self.platform}_{group_info.group_id}"
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
# 私聊使用用户ID
return f"{self.platform}_{user_info.user_id}_private"
else:
# 默认使用stream_id
return self.stream_id
except Exception as e:
logger.warning(f"生成chat_id失败: {e}")
return self.stream_id
@property
def focus_energy(self) -> float:
"""动态计算的聊天流总体兴趣度,访问时自动更新"""
self._focus_energy = self._calculate_dynamic_focus_energy()
return self._focus_energy
"""使用重构后的能量管理器计算focus_energy"""
try:
from src.chat.energy_system import energy_manager
# 获取所有消息
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
unread_messages = self.stream_context.get_unread_messages()
all_messages = history_messages + unread_messages
# 获取用户ID
user_id = None
if self.user_info and hasattr(self.user_info, "user_id"):
user_id = str(self.user_info.user_id)
# 使用能量管理器计算
energy = energy_manager.calculate_focus_energy(
stream_id=self.stream_id,
messages=all_messages,
user_id=user_id
)
# 更新内部存储
self._focus_energy = energy
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
return energy
except Exception as e:
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
# 返回缓存的值或默认值
if hasattr(self, '_focus_energy'):
return self._focus_energy
else:
return 0.5
@focus_energy.setter
def focus_energy(self, value: float):
"""设置focus_energy值主要用于初始化或特殊场景"""
self._focus_energy = max(0.0, min(1.0, value))
def _calculate_dynamic_focus_energy(self) -> float:
"""动态计算聊天流的总体兴趣度使用StreamContext历史消息"""
try:
# 从StreamContext获取历史消息计算统计数据
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
unread_messages = self.stream_context.get_unread_messages()
all_messages = history_messages + unread_messages
# 计算基于历史消息的统计数据
if all_messages:
# 基础分:平均消息兴趣度
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3
# 动作参与度:有动作的消息比例
messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions]
action_rate = len(messages_with_actions) / len(all_messages)
# 回复活跃度:应该回复且已回复的消息比例
should_reply_messages = [
msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply
]
replied_messages = [
msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or [])
]
reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0
# 获取最后交互时间
if all_messages:
self.last_interaction_time = max(msg.time for msg in all_messages)
# 连续无回复计算:从最近的未回复消息计数
consecutive_no_reply = 0
for msg in reversed(all_messages):
if hasattr(msg, "should_reply") and msg.should_reply:
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
consecutive_no_reply += 1
else:
break
else:
# 没有历史消息时的默认值
avg_message_interest = 0.3
action_rate = 0.0
reply_rate = 0.0
consecutive_no_reply = 0
self.last_interaction_time = time.time()
# 获取用户关系分(对于私聊,群聊无效)
relationship_factor = self._get_user_relationship_score()
# 时间衰减因子:最近活跃度
current_time = time.time()
if not hasattr(self, "last_interaction_time") or not self.last_interaction_time:
self.last_interaction_time = current_time
time_since_interaction = current_time - self.last_interaction_time
time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减
# 连续无回复惩罚
no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1)
# 获取AFC系统阈值添加None值检查
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
# 计算与不同阈值的差距比例
reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold)))
non_reply_gap_ratio = max(
0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))
)
high_match_gap_ratio = max(
0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))
)
# 基于阈值差距比例的基础分计算
threshold_based_score = (
reply_gap_ratio * 0.6 # 回复阈值差距权重60%
+ non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20%
+ high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
)
# 动态权重调整:根据平均兴趣度水平调整权重分配
if avg_message_interest >= high_match_threshold:
# 高兴趣度:更注重阈值差距
threshold_weight = 0.7
activity_weight = 0.2
relationship_weight = 0.1
elif avg_message_interest >= reply_threshold:
# 中等兴趣度:平衡权重
threshold_weight = 0.5
activity_weight = 0.3
relationship_weight = 0.2
else:
# 低兴趣度:更注重活跃度提升
threshold_weight = 0.3
activity_weight = 0.5
relationship_weight = 0.2
# 计算活跃度得分
activity_score = action_rate * 0.6 + reply_rate * 0.4
# 综合计算:基于阈值的动态加权
focus_energy = (
(
threshold_based_score * threshold_weight # 阈值差距基础分
+ activity_score * activity_weight # 活跃度得分
+ relationship_factor * relationship_weight # 关系得分
+ self.base_interest_energy * 0.05 # 基础兴趣微调
)
* time_decay
* no_reply_penalty
)
# 确保在合理范围内
focus_energy = max(0.1, min(1.0, focus_energy))
# 应用非线性变换增强区分度
if focus_energy >= 0.7:
# 高兴趣度区域:指数增强,更敏感
focus_energy = 0.7 + (focus_energy - 0.7) ** 0.8
elif focus_energy >= 0.4:
# 中等兴趣度区域:线性保持
pass
else:
# 低兴趣度区域:对数压缩,减少区分度
focus_energy = 0.4 * (focus_energy / 0.4) ** 1.2
return max(0.1, min(1.0, focus_energy))
except Exception as e:
logger.error(f"计算动态focus_energy失败: {e}")
return self.base_interest_energy
def _get_user_relationship_score(self) -> float:
"""外部系统获取用户关系分"""
"""新的兴趣度管理系统获取用户关系分"""
try:
# 尝试从兴趣评分系统获取用户关系分
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
# 使用新的兴趣度管理系统
from src.chat.interest_system import interest_manager
if self.user_info and hasattr(self.user_info, "user_id"):
return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id))
user_id = str(self.user_info.user_id)
# 获取用户交互历史作为关系分的基础
interaction_calc = interest_manager.calculators.get(
interest_manager.InterestSourceType.USER_INTERACTION
)
if interaction_calc:
return interaction_calc.calculate({"user_id": user_id})
except Exception:
pass
@@ -378,12 +401,13 @@ class ChatStream:
chat_info_platform=db_msg.chat_info_platform,
chat_info_create_time=db_msg.chat_info_create_time,
chat_info_last_active_time=db_msg.chat_info_last_active_time,
# 新增的兴趣度系统字段
interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0,
actions=actions,
should_reply=getattr(db_msg, "should_reply", False) or False,
)
# 添加调试日志检查从数据库加载的interest_value
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
# 标记为已读并添加到历史消息
db_message.is_read = True
self.stream_context.history_messages.append(db_message)