refactor(chat): 优化消息管理与打断系统,添加打断计数与历史消息加载功能

This commit is contained in:
Windpicker-owo
2025-09-26 19:17:24 +08:00
parent 701e523823
commit 5962b44294
4 changed files with 334 additions and 76 deletions

View File

@@ -48,10 +48,9 @@ class ChatStream:
# 使用StreamContext替代ChatMessageContext
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatType, ChatMode
self.stream_context: StreamContext = StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
)
# 基础参数
@@ -59,6 +58,8 @@ class ChatStream:
self._focus_energy = 0.5 # 内部存储的focus_energy值
self.no_reply_consecutive = 0
# 自动加载历史消息
self._load_history_messages()
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -76,6 +77,8 @@ class ChatStream:
# 新增stream_context信息
"stream_context_chat_type": self.stream_context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value,
# 新增interruption_count信息
"interruption_count": self.stream_context.interruption_count,
}
@classmethod
@@ -95,11 +98,17 @@ class ChatStream:
# 恢复stream_context信息
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatType, ChatMode
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatType, ChatMode
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息
if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"]
return instance
def update_active_time(self):
@@ -114,19 +123,28 @@ class ChatStream:
# 简化转换,实际可能需要更完整的转换逻辑
db_message = DatabaseMessages(
message_id=getattr(message, 'message_id', ''),
time=getattr(message, 'time', time.time()),
chat_id=getattr(message, 'chat_id', ''),
user_id=str(getattr(message.message_info, 'user_info', {}).user_id) if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '',
user_nickname=getattr(message.message_info, 'user_info', {}).user_nickname if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '',
user_platform=getattr(message.message_info, 'user_info', {}).platform if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '',
priority_mode=getattr(message, 'priority_mode', None),
priority_info=str(getattr(message, 'priority_info', None)) if hasattr(message, 'priority_info') and message.priority_info else None,
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=getattr(message, "chat_id", ""),
user_id=str(getattr(message.message_info, "user_info", {}).user_id)
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
user_nickname=getattr(message.message_info, "user_info", {}).user_nickname
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
user_platform=getattr(message.message_info, "user_info", {}).platform
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
priority_mode=getattr(message, "priority_mode", None),
priority_info=str(getattr(message, "priority_info", None))
if hasattr(message, "priority_info") and message.priority_info
else None,
additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None),
)
self.stream_context.set_current_message(db_message)
self.stream_context.priority_mode = getattr(message, 'priority_mode', None)
self.stream_context.priority_info = getattr(message, 'priority_info', None)
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
self.stream_context.priority_info = getattr(message, "priority_info", None)
@property
def focus_energy(self) -> float:
@@ -150,16 +168,20 @@ class ChatStream:
# 计算基于历史消息的统计数据
if all_messages:
# 基础分:平均消息兴趣度
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, 'interest_degree')]
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3
# 动作参与度:有动作的消息比例
messages_with_actions = [msg for msg in all_messages if hasattr(msg, 'actions') and msg.actions]
messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions]
action_rate = len(messages_with_actions) / len(all_messages)
# 回复活跃度:应该回复且已回复的消息比例
should_reply_messages = [msg for msg in all_messages if hasattr(msg, 'should_reply') and msg.should_reply]
replied_messages = [msg for msg in should_reply_messages if hasattr(msg, 'actions') and 'reply' in (msg.actions or [])]
should_reply_messages = [
msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply
]
replied_messages = [
msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or [])
]
reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0
# 获取最后交互时间
@@ -169,8 +191,8 @@ class ChatStream:
# 连续无回复计算:从最近的未回复消息计数
consecutive_no_reply = 0
for msg in reversed(all_messages):
if hasattr(msg, 'should_reply') and msg.should_reply:
if not (hasattr(msg, 'actions') and 'reply' in (msg.actions or [])):
if hasattr(msg, "should_reply") and msg.should_reply:
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
consecutive_no_reply += 1
else:
break
@@ -187,7 +209,7 @@ class ChatStream:
# 时间衰减因子:最近活跃度
current_time = time.time()
if not hasattr(self, 'last_interaction_time') or not self.last_interaction_time:
if not hasattr(self, "last_interaction_time") or not self.last_interaction_time:
self.last_interaction_time = current_time
time_since_interaction = current_time - self.last_interaction_time
time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减
@@ -196,20 +218,24 @@ class ChatStream:
no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1)
# 获取AFC系统阈值添加None值检查
reply_threshold = getattr(global_config.affinity_flow, 'reply_action_interest_threshold', 0.4)
non_reply_threshold = getattr(global_config.affinity_flow, 'non_reply_action_interest_threshold', 0.2)
high_match_threshold = getattr(global_config.affinity_flow, 'high_match_interest_threshold', 0.8)
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
# 计算与不同阈值的差距比例
reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold)))
non_reply_gap_ratio = max(0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold)))
high_match_gap_ratio = max(0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold)))
non_reply_gap_ratio = max(
0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))
)
high_match_gap_ratio = max(
0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))
)
# 基于阈值差距比例的基础分计算
threshold_based_score = (
reply_gap_ratio * 0.6 + # 回复阈值差距权重60%
non_reply_gap_ratio * 0.2 + # 非回复阈值差距权重20%
high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
reply_gap_ratio * 0.6 # 回复阈值差距权重60%
+ non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20%
+ high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
)
# 动态权重调整:根据平均兴趣度水平调整权重分配
@@ -230,15 +256,19 @@ class ChatStream:
relationship_weight = 0.2
# 计算活跃度得分
activity_score = (action_rate * 0.6 + reply_rate * 0.4)
activity_score = action_rate * 0.6 + reply_rate * 0.4
# 综合计算:基于阈值的动态加权
focus_energy = (
threshold_based_score * threshold_weight + # 阈值差距基础分
activity_score * activity_weight + # 活跃度得
relationship_factor * relationship_weight + # 关系得分
self.base_interest_energy * 0.05 # 基础兴趣微调
) * time_decay * no_reply_penalty
(
threshold_based_score * threshold_weight # 阈值差距基础
+ activity_score * activity_weight # 活跃度得分
+ relationship_factor * relationship_weight # 关系得分
+ self.base_interest_energy * 0.05 # 基础兴趣微调
)
* time_decay
* no_reply_penalty
)
# 确保在合理范围内
focus_energy = max(0.1, min(1.0, focus_energy))
@@ -268,7 +298,7 @@ class ChatStream:
chatter_interest_scoring_system,
)
if self.user_info and hasattr(self.user_info, 'user_id'):
if self.user_info and hasattr(self.user_info, "user_id"):
return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id))
except Exception:
pass
@@ -276,8 +306,102 @@ class ChatStream:
# 默认基础分
return 0.3
def _load_history_messages(self):
"""从数据库加载历史消息到StreamContext"""
try:
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.data_models.database_data_model import DatabaseMessages
from sqlalchemy import select, desc
import asyncio
async def _load_messages():
def _db_query():
with get_db_session() as session:
# 查询该stream_id的最近20条消息
stmt = (
select(Messages)
.where(Messages.chat_info_stream_id == self.stream_id)
.order_by(desc(Messages.time))
.limit(global_config.chat.max_context_size)
)
results = session.execute(stmt).scalars().all()
return results
# 在线程中执行数据库查询
db_messages = await asyncio.to_thread(_db_query)
# 转换为DatabaseMessages对象并添加到StreamContext
for db_msg in db_messages:
try:
# 从SQLAlchemy模型转换为DatabaseMessages数据模型
import orjson
# 解析actions字段JSON格式
actions = None
if db_msg.actions:
try:
actions = orjson.loads(db_msg.actions)
except (orjson.JSONDecodeError, TypeError):
actions = None
db_message = DatabaseMessages(
message_id=db_msg.message_id,
time=db_msg.time,
chat_id=db_msg.chat_id,
reply_to=db_msg.reply_to,
interest_value=db_msg.interest_value,
key_words=db_msg.key_words,
key_words_lite=db_msg.key_words_lite,
is_mentioned=db_msg.is_mentioned,
processed_plain_text=db_msg.processed_plain_text,
display_message=db_msg.display_message,
priority_mode=db_msg.priority_mode,
priority_info=db_msg.priority_info,
additional_config=db_msg.additional_config,
is_emoji=db_msg.is_emoji,
is_picid=db_msg.is_picid,
is_command=db_msg.is_command,
is_notify=db_msg.is_notify,
user_id=db_msg.user_id,
user_nickname=db_msg.user_nickname,
user_cardname=db_msg.user_cardname,
user_platform=db_msg.user_platform,
chat_info_group_id=db_msg.chat_info_group_id,
chat_info_group_name=db_msg.chat_info_group_name,
chat_info_group_platform=db_msg.chat_info_group_platform,
chat_info_user_id=db_msg.chat_info_user_id,
chat_info_user_nickname=db_msg.chat_info_user_nickname,
chat_info_user_cardname=db_msg.chat_info_user_cardname,
chat_info_user_platform=db_msg.chat_info_user_platform,
chat_info_stream_id=db_msg.chat_info_stream_id,
chat_info_platform=db_msg.chat_info_platform,
chat_info_create_time=db_msg.chat_info_create_time,
chat_info_last_active_time=db_msg.chat_info_last_active_time,
# 新增的兴趣度系统字段
interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0,
actions=actions,
should_reply=getattr(db_msg, "should_reply", False) or False,
)
# 标记为已读并添加到历史消息
db_message.is_read = True
self.stream_context.history_messages.append(db_message)
except Exception as e:
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}")
continue
if self.stream_context.history_messages:
logger.info(
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
)
# 创建任务来加载历史消息
asyncio.create_task(_load_messages())
except Exception as e:
logger.error(f"加载历史消息失败: {e}")
class ChatManager:
@@ -524,6 +648,7 @@ class ChatManager:
"reply_count": s_data_dict.get("reply_count", 0),
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
"interruption_count": s_data_dict.get("interruption_count", 0),
}
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
@@ -590,6 +715,7 @@ class ChatManager:
"last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()),
"relationship_score": getattr(model_instance, "relationship_score", 0.3),
"consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0),
"interruption_count": getattr(model_instance, "interruption_count", 0),
}
loaded_streams_data.append(data_for_from_dict)
await session.commit()