refactor(chat): 重构消息管理器以使用集中式上下文管理和能量系统
- 将流上下文管理从MessageManager迁移到专门的ContextManager - 使用统一的能量系统计算focus_energy和分发间隔 - 重构ChatStream的消息数据转换逻辑,支持更完整的数据字段 - 更新数据库模型,移除interest_degree字段,统一使用interest_value - 集成新的兴趣度管理系统替代原有的评分系统 - 添加消息存储的interest_value修复功能
This commit is contained in:
@@ -120,186 +120,209 @@ class ChatStream:
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
# 简化转换,实际可能需要更完整的转换逻辑
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
group_info = getattr(message_info, "group_info", {})
|
||||
|
||||
# 提取reply_to信息(从message_segment中查找reply类型的段)
|
||||
reply_to = None
|
||||
if hasattr(message, "message_segment") and message.message_segment:
|
||||
reply_to = self._extract_reply_from_segment(message.message_segment)
|
||||
|
||||
# 完整的数据转移逻辑
|
||||
db_message = DatabaseMessages(
|
||||
# 基础消息信息
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
time=getattr(message, "time", time.time()),
|
||||
chat_id=getattr(message, "chat_id", ""),
|
||||
user_id=str(getattr(message.message_info, "user_info", {}).user_id)
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
user_nickname=getattr(message.message_info, "user_info", {}).user_nickname
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
user_platform=getattr(message.message_info, "user_info", {}).platform
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=str(getattr(message, "priority_info", None))
|
||||
if hasattr(message, "priority_info") and message.priority_info
|
||||
chat_id=self._generate_chat_id(message_info),
|
||||
reply_to=reply_to,
|
||||
# 兴趣度相关
|
||||
interest_value=getattr(message, "interest_value", 0.0),
|
||||
# 关键词
|
||||
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words", None)
|
||||
else None,
|
||||
additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None),
|
||||
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words_lite", None)
|
||||
else None,
|
||||
# 消息状态标记
|
||||
is_mentioned=getattr(message, "is_mentioned", None),
|
||||
is_at=getattr(message, "is_at", False),
|
||||
is_emoji=getattr(message, "is_emoji", False),
|
||||
is_picid=getattr(message, "is_picid", False),
|
||||
is_voice=getattr(message, "is_voice", False),
|
||||
is_video=getattr(message, "is_video", False),
|
||||
is_command=getattr(message, "is_command", False),
|
||||
is_notify=getattr(message, "is_notify", False),
|
||||
# 消息内容
|
||||
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
||||
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
|
||||
# 优先级信息
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||
if getattr(message, "priority_info", None)
|
||||
else None,
|
||||
# 额外配置
|
||||
additional_config=getattr(message_info, "additional_config", None),
|
||||
# 用户信息
|
||||
user_id=str(getattr(user_info, "user_id", "")),
|
||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
user_cardname=getattr(user_info, "user_cardname", None),
|
||||
user_platform=getattr(user_info, "platform", ""),
|
||||
# 群组信息
|
||||
chat_info_group_id=getattr(group_info, "group_id", None),
|
||||
chat_info_group_name=getattr(group_info, "group_name", None),
|
||||
chat_info_group_platform=getattr(group_info, "platform", None),
|
||||
# 聊天流信息
|
||||
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
||||
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
||||
chat_info_user_platform=getattr(user_info, "platform", ""),
|
||||
chat_info_stream_id=self.stream_id,
|
||||
chat_info_platform=self.platform,
|
||||
chat_info_create_time=self.create_time,
|
||||
chat_info_last_active_time=self.last_active_time,
|
||||
# 新增兴趣度系统字段 - 添加安全处理
|
||||
actions=self._safe_get_actions(message),
|
||||
should_reply=getattr(message, "should_reply", False),
|
||||
)
|
||||
|
||||
self.stream_context.set_current_message(db_message)
|
||||
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
# 调试日志:记录数据转移情况
|
||||
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"is_mentioned: {db_message.is_mentioned}, "
|
||||
f"is_emoji: {db_message.is_emoji}, "
|
||||
f"is_picid: {db_message.is_picid}, "
|
||||
f"interest_value: {db_message.interest_value}")
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
return None
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
return None
|
||||
|
||||
# 确保返回列表类型
|
||||
if isinstance(actions, list):
|
||||
# 过滤掉空值和非字符串元素
|
||||
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
|
||||
return filtered_actions if filtered_actions else None
|
||||
else:
|
||||
logger.warning(f"actions字段类型不支持: {type(actions)}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
# 递归搜索seglist中的reply段
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = self._extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
# 找到reply段,返回message_id
|
||||
return str(segment.data) if segment.data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
return None
|
||||
|
||||
def _generate_chat_id(self, message_info) -> str:
|
||||
"""生成chat_id,基于群组或用户信息"""
|
||||
try:
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
|
||||
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
|
||||
# 群聊:使用群组ID
|
||||
return f"{self.platform}_{group_info.group_id}"
|
||||
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
|
||||
# 私聊:使用用户ID
|
||||
return f"{self.platform}_{user_info.user_id}_private"
|
||||
else:
|
||||
# 默认:使用stream_id
|
||||
return self.stream_id
|
||||
except Exception as e:
|
||||
logger.warning(f"生成chat_id失败: {e}")
|
||||
return self.stream_id
|
||||
|
||||
@property
|
||||
def focus_energy(self) -> float:
|
||||
"""动态计算的聊天流总体兴趣度,访问时自动更新"""
|
||||
self._focus_energy = self._calculate_dynamic_focus_energy()
|
||||
return self._focus_energy
|
||||
"""使用重构后的能量管理器计算focus_energy"""
|
||||
try:
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
# 获取所有消息
|
||||
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
unread_messages = self.stream_context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
user_id = str(self.user_info.user_id)
|
||||
|
||||
# 使用能量管理器计算
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id,
|
||||
messages=all_messages,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新内部存储
|
||||
self._focus_energy = energy
|
||||
|
||||
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
|
||||
return energy
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
|
||||
# 返回缓存的值或默认值
|
||||
if hasattr(self, '_focus_energy'):
|
||||
return self._focus_energy
|
||||
else:
|
||||
return 0.5
|
||||
|
||||
@focus_energy.setter
|
||||
def focus_energy(self, value: float):
|
||||
"""设置focus_energy值(主要用于初始化或特殊场景)"""
|
||||
self._focus_energy = max(0.0, min(1.0, value))
|
||||
|
||||
def _calculate_dynamic_focus_energy(self) -> float:
|
||||
"""动态计算聊天流的总体兴趣度,使用StreamContext历史消息"""
|
||||
try:
|
||||
# 从StreamContext获取历史消息计算统计数据
|
||||
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
unread_messages = self.stream_context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
# 计算基于历史消息的统计数据
|
||||
if all_messages:
|
||||
# 基础分:平均消息兴趣度
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
|
||||
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3
|
||||
|
||||
# 动作参与度:有动作的消息比例
|
||||
messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions]
|
||||
action_rate = len(messages_with_actions) / len(all_messages)
|
||||
|
||||
# 回复活跃度:应该回复且已回复的消息比例
|
||||
should_reply_messages = [
|
||||
msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply
|
||||
]
|
||||
replied_messages = [
|
||||
msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or [])
|
||||
]
|
||||
reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0
|
||||
|
||||
# 获取最后交互时间
|
||||
if all_messages:
|
||||
self.last_interaction_time = max(msg.time for msg in all_messages)
|
||||
|
||||
# 连续无回复计算:从最近的未回复消息计数
|
||||
consecutive_no_reply = 0
|
||||
for msg in reversed(all_messages):
|
||||
if hasattr(msg, "should_reply") and msg.should_reply:
|
||||
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
|
||||
consecutive_no_reply += 1
|
||||
else:
|
||||
break
|
||||
else:
|
||||
# 没有历史消息时的默认值
|
||||
avg_message_interest = 0.3
|
||||
action_rate = 0.0
|
||||
reply_rate = 0.0
|
||||
consecutive_no_reply = 0
|
||||
self.last_interaction_time = time.time()
|
||||
|
||||
# 获取用户关系分(对于私聊,群聊无效)
|
||||
relationship_factor = self._get_user_relationship_score()
|
||||
|
||||
# 时间衰减因子:最近活跃度
|
||||
current_time = time.time()
|
||||
if not hasattr(self, "last_interaction_time") or not self.last_interaction_time:
|
||||
self.last_interaction_time = current_time
|
||||
time_since_interaction = current_time - self.last_interaction_time
|
||||
time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减
|
||||
|
||||
# 连续无回复惩罚
|
||||
no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1)
|
||||
|
||||
# 获取AFC系统阈值,添加None值检查
|
||||
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
|
||||
# 计算与不同阈值的差距比例
|
||||
reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold)))
|
||||
non_reply_gap_ratio = max(
|
||||
0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))
|
||||
)
|
||||
high_match_gap_ratio = max(
|
||||
0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))
|
||||
)
|
||||
|
||||
# 基于阈值差距比例的基础分计算
|
||||
threshold_based_score = (
|
||||
reply_gap_ratio * 0.6 # 回复阈值差距权重60%
|
||||
+ non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20%
|
||||
+ high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
|
||||
)
|
||||
|
||||
# 动态权重调整:根据平均兴趣度水平调整权重分配
|
||||
if avg_message_interest >= high_match_threshold:
|
||||
# 高兴趣度:更注重阈值差距
|
||||
threshold_weight = 0.7
|
||||
activity_weight = 0.2
|
||||
relationship_weight = 0.1
|
||||
elif avg_message_interest >= reply_threshold:
|
||||
# 中等兴趣度:平衡权重
|
||||
threshold_weight = 0.5
|
||||
activity_weight = 0.3
|
||||
relationship_weight = 0.2
|
||||
else:
|
||||
# 低兴趣度:更注重活跃度提升
|
||||
threshold_weight = 0.3
|
||||
activity_weight = 0.5
|
||||
relationship_weight = 0.2
|
||||
|
||||
# 计算活跃度得分
|
||||
activity_score = action_rate * 0.6 + reply_rate * 0.4
|
||||
|
||||
# 综合计算:基于阈值的动态加权
|
||||
focus_energy = (
|
||||
(
|
||||
threshold_based_score * threshold_weight # 阈值差距基础分
|
||||
+ activity_score * activity_weight # 活跃度得分
|
||||
+ relationship_factor * relationship_weight # 关系得分
|
||||
+ self.base_interest_energy * 0.05 # 基础兴趣微调
|
||||
)
|
||||
* time_decay
|
||||
* no_reply_penalty
|
||||
)
|
||||
|
||||
# 确保在合理范围内
|
||||
focus_energy = max(0.1, min(1.0, focus_energy))
|
||||
|
||||
# 应用非线性变换增强区分度
|
||||
if focus_energy >= 0.7:
|
||||
# 高兴趣度区域:指数增强,更敏感
|
||||
focus_energy = 0.7 + (focus_energy - 0.7) ** 0.8
|
||||
elif focus_energy >= 0.4:
|
||||
# 中等兴趣度区域:线性保持
|
||||
pass
|
||||
else:
|
||||
# 低兴趣度区域:对数压缩,减少区分度
|
||||
focus_energy = 0.4 * (focus_energy / 0.4) ** 1.2
|
||||
|
||||
return max(0.1, min(1.0, focus_energy))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算动态focus_energy失败: {e}")
|
||||
return self.base_interest_energy
|
||||
|
||||
def _get_user_relationship_score(self) -> float:
|
||||
"""从外部系统获取用户关系分"""
|
||||
"""从新的兴趣度管理系统获取用户关系分"""
|
||||
try:
|
||||
# 尝试从兴趣评分系统获取用户关系分
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system,
|
||||
)
|
||||
# 使用新的兴趣度管理系统
|
||||
from src.chat.interest_system import interest_manager
|
||||
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id))
|
||||
user_id = str(self.user_info.user_id)
|
||||
# 获取用户交互历史作为关系分的基础
|
||||
interaction_calc = interest_manager.calculators.get(
|
||||
interest_manager.InterestSourceType.USER_INTERACTION
|
||||
)
|
||||
if interaction_calc:
|
||||
return interaction_calc.calculate({"user_id": user_id})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -378,12 +401,13 @@ class ChatStream:
|
||||
chat_info_platform=db_msg.chat_info_platform,
|
||||
chat_info_create_time=db_msg.chat_info_create_time,
|
||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||
# 新增的兴趣度系统字段
|
||||
interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0,
|
||||
actions=actions,
|
||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||
)
|
||||
|
||||
# 添加调试日志:检查从数据库加载的interest_value
|
||||
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
|
||||
|
||||
# 标记为已读并添加到历史消息
|
||||
db_message.is_read = True
|
||||
self.stream_context.history_messages.append(db_message)
|
||||
|
||||
Reference in New Issue
Block a user