refactor(chat): 优化消息管理与打断系统,添加打断计数与历史消息加载功能
This commit is contained in:
@@ -89,7 +89,14 @@ class MessageManager:
|
||||
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
|
||||
def update_message_and_refresh_energy(self, stream_id: str, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None):
|
||||
def update_message_and_refresh_energy(
|
||||
self,
|
||||
stream_id: str,
|
||||
message_id: str,
|
||||
interest_degree: float = None,
|
||||
actions: list = None,
|
||||
should_reply: bool = None,
|
||||
):
|
||||
"""更新消息信息"""
|
||||
if stream_id in self.stream_contexts:
|
||||
context = self.stream_contexts[stream_id]
|
||||
@@ -288,6 +295,13 @@ class MessageManager:
|
||||
global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor
|
||||
)
|
||||
|
||||
# 检查是否已达到最大打断次数
|
||||
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
logger.debug(
|
||||
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查"
|
||||
)
|
||||
return
|
||||
|
||||
# 根据概率决定是否打断
|
||||
if random.random() < interruption_probability:
|
||||
logger.info(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
|
||||
@@ -302,9 +316,16 @@ class MessageManager:
|
||||
# 增加打断计数并应用afc阈值降低
|
||||
context.increment_interruption_count()
|
||||
context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
|
||||
logger.info(
|
||||
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}"
|
||||
)
|
||||
|
||||
# 检查是否已达到最大次数
|
||||
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||
logger.warning(
|
||||
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
||||
|
||||
@@ -312,9 +333,10 @@ class MessageManager:
|
||||
"""计算单个聊天流的分发周期 - 基于阈值感知的focus_energy"""
|
||||
if not global_config.chat.dynamic_distribution_enabled:
|
||||
return self.check_interval # 使用固定间隔
|
||||
|
||||
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
# 获取该流的focus_energy(新的阈值感知版本)
|
||||
focus_energy = 0.5 # 默认值
|
||||
avg_message_interest = 0.5 # 默认平均兴趣度
|
||||
@@ -327,13 +349,13 @@ class MessageManager:
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
if all_messages:
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, 'interest_degree')]
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
|
||||
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.5
|
||||
|
||||
# 获取AFC阈值用于参考,添加None值检查
|
||||
reply_threshold = getattr(global_config.affinity_flow, 'reply_action_interest_threshold', 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, 'non_reply_action_interest_threshold', 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, 'high_match_interest_threshold', 0.8)
|
||||
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
|
||||
# 使用配置参数
|
||||
base_interval = global_config.chat.dynamic_distribution_base_interval
|
||||
@@ -364,6 +386,7 @@ class MessageManager:
|
||||
|
||||
# 添加随机扰动避免同步
|
||||
import random
|
||||
|
||||
jitter = random.uniform(1.0 - jitter_factor, 1.0 + jitter_factor)
|
||||
final_interval = interval * jitter
|
||||
|
||||
@@ -395,7 +418,7 @@ class MessageManager:
|
||||
def _calculate_next_manager_delay(self) -> float:
|
||||
"""计算管理器下次检查的延迟时间"""
|
||||
current_time = time.time()
|
||||
min_delay = float('inf')
|
||||
min_delay = float("inf")
|
||||
|
||||
# 找到最近需要检查的流
|
||||
for context in self.stream_contexts.values():
|
||||
@@ -410,7 +433,7 @@ class MessageManager:
|
||||
break
|
||||
|
||||
# 如果没有活跃流,使用默认间隔
|
||||
if min_delay == float('inf'):
|
||||
if min_delay == float("inf"):
|
||||
return self.check_interval
|
||||
|
||||
# 确保最小延迟
|
||||
@@ -448,7 +471,8 @@ class MessageManager:
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
|
||||
|
||||
# 根据优先级记录日志
|
||||
@@ -473,10 +497,7 @@ class MessageManager:
|
||||
self.stats.active_streams = active_count
|
||||
|
||||
if processed_streams > 0:
|
||||
logger.debug(
|
||||
f"本次循环处理了 {processed_streams} 个流 | "
|
||||
f"活跃流总数: {active_count}"
|
||||
)
|
||||
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
|
||||
|
||||
async def _check_all_streams_with_priority(self):
|
||||
"""按优先级检查所有聊天流,高focus_energy的流优先处理"""
|
||||
@@ -491,7 +512,8 @@ class MessageManager:
|
||||
|
||||
# 获取focus_energy,如果不存在则使用默认值
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
focus_energy = 0.5
|
||||
if chat_stream:
|
||||
focus_energy = chat_stream.focus_energy
|
||||
@@ -534,7 +556,8 @@ class MessageManager:
|
||||
def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float:
|
||||
"""计算聊天流的优先级分数"""
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
# 基础优先级:focus_energy
|
||||
base_priority = focus_energy
|
||||
|
||||
@@ -553,8 +576,8 @@ class MessageManager:
|
||||
consecutive_no_reply = 0
|
||||
all_messages = context.get_history_messages(limit=50) + context.get_unread_messages()
|
||||
for msg in reversed(all_messages):
|
||||
if hasattr(msg, 'should_reply') and msg.should_reply:
|
||||
if not (hasattr(msg, 'actions') and 'reply' in (msg.actions or [])):
|
||||
if hasattr(msg, "should_reply") and msg.should_reply:
|
||||
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
|
||||
consecutive_no_reply += 1
|
||||
else:
|
||||
break
|
||||
@@ -564,10 +587,10 @@ class MessageManager:
|
||||
|
||||
# 综合优先级计算
|
||||
final_priority = (
|
||||
base_priority * 0.6 + # 基础兴趣度权重60%
|
||||
message_count_bonus * 0.2 + # 消息数量权重20%
|
||||
time_penalty * 0.1 + # 时间权重10%
|
||||
no_reply_penalty * 0.1 # 回复状态权重10%
|
||||
base_priority * 0.6 # 基础兴趣度权重60%
|
||||
+ message_count_bonus * 0.2 # 消息数量权重20%
|
||||
+ time_penalty * 0.1 # 时间权重10%
|
||||
+ no_reply_penalty * 0.1 # 回复状态权重10%
|
||||
)
|
||||
|
||||
return max(0.0, min(1.0, final_priority))
|
||||
|
||||
@@ -48,10 +48,9 @@ class ChatStream:
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
stream_id=stream_id,
|
||||
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.NORMAL
|
||||
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
|
||||
)
|
||||
|
||||
# 基础参数
|
||||
@@ -59,6 +58,8 @@ class ChatStream:
|
||||
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
||||
self.no_reply_consecutive = 0
|
||||
|
||||
# 自动加载历史消息
|
||||
self._load_history_messages()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -76,6 +77,8 @@ class ChatStream:
|
||||
# 新增stream_context信息
|
||||
"stream_context_chat_type": self.stream_context.chat_type.value,
|
||||
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
||||
# 新增interruption_count信息
|
||||
"interruption_count": self.stream_context.interruption_count,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -95,11 +98,17 @@ class ChatStream:
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
# 恢复interruption_count信息
|
||||
if "interruption_count" in data:
|
||||
instance.stream_context.interruption_count = data["interruption_count"]
|
||||
|
||||
return instance
|
||||
|
||||
def update_active_time(self):
|
||||
@@ -114,19 +123,28 @@ class ChatStream:
|
||||
|
||||
# 简化转换,实际可能需要更完整的转换逻辑
|
||||
db_message = DatabaseMessages(
|
||||
message_id=getattr(message, 'message_id', ''),
|
||||
time=getattr(message, 'time', time.time()),
|
||||
chat_id=getattr(message, 'chat_id', ''),
|
||||
user_id=str(getattr(message.message_info, 'user_info', {}).user_id) if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '',
|
||||
user_nickname=getattr(message.message_info, 'user_info', {}).user_nickname if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '',
|
||||
user_platform=getattr(message.message_info, 'user_info', {}).platform if hasattr(message, 'message_info') and hasattr(message.message_info, 'user_info') else '',
|
||||
priority_mode=getattr(message, 'priority_mode', None),
|
||||
priority_info=str(getattr(message, 'priority_info', None)) if hasattr(message, 'priority_info') and message.priority_info else None,
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
time=getattr(message, "time", time.time()),
|
||||
chat_id=getattr(message, "chat_id", ""),
|
||||
user_id=str(getattr(message.message_info, "user_info", {}).user_id)
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
user_nickname=getattr(message.message_info, "user_info", {}).user_nickname
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
user_platform=getattr(message.message_info, "user_info", {}).platform
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=str(getattr(message, "priority_info", None))
|
||||
if hasattr(message, "priority_info") and message.priority_info
|
||||
else None,
|
||||
additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None),
|
||||
)
|
||||
|
||||
self.stream_context.set_current_message(db_message)
|
||||
self.stream_context.priority_mode = getattr(message, 'priority_mode', None)
|
||||
self.stream_context.priority_info = getattr(message, 'priority_info', None)
|
||||
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
@property
|
||||
def focus_energy(self) -> float:
|
||||
@@ -150,16 +168,20 @@ class ChatStream:
|
||||
# 计算基于历史消息的统计数据
|
||||
if all_messages:
|
||||
# 基础分:平均消息兴趣度
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, 'interest_degree')]
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
|
||||
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3
|
||||
|
||||
# 动作参与度:有动作的消息比例
|
||||
messages_with_actions = [msg for msg in all_messages if hasattr(msg, 'actions') and msg.actions]
|
||||
messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions]
|
||||
action_rate = len(messages_with_actions) / len(all_messages)
|
||||
|
||||
# 回复活跃度:应该回复且已回复的消息比例
|
||||
should_reply_messages = [msg for msg in all_messages if hasattr(msg, 'should_reply') and msg.should_reply]
|
||||
replied_messages = [msg for msg in should_reply_messages if hasattr(msg, 'actions') and 'reply' in (msg.actions or [])]
|
||||
should_reply_messages = [
|
||||
msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply
|
||||
]
|
||||
replied_messages = [
|
||||
msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or [])
|
||||
]
|
||||
reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0
|
||||
|
||||
# 获取最后交互时间
|
||||
@@ -169,8 +191,8 @@ class ChatStream:
|
||||
# 连续无回复计算:从最近的未回复消息计数
|
||||
consecutive_no_reply = 0
|
||||
for msg in reversed(all_messages):
|
||||
if hasattr(msg, 'should_reply') and msg.should_reply:
|
||||
if not (hasattr(msg, 'actions') and 'reply' in (msg.actions or [])):
|
||||
if hasattr(msg, "should_reply") and msg.should_reply:
|
||||
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
|
||||
consecutive_no_reply += 1
|
||||
else:
|
||||
break
|
||||
@@ -187,7 +209,7 @@ class ChatStream:
|
||||
|
||||
# 时间衰减因子:最近活跃度
|
||||
current_time = time.time()
|
||||
if not hasattr(self, 'last_interaction_time') or not self.last_interaction_time:
|
||||
if not hasattr(self, "last_interaction_time") or not self.last_interaction_time:
|
||||
self.last_interaction_time = current_time
|
||||
time_since_interaction = current_time - self.last_interaction_time
|
||||
time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减
|
||||
@@ -196,20 +218,24 @@ class ChatStream:
|
||||
no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1)
|
||||
|
||||
# 获取AFC系统阈值,添加None值检查
|
||||
reply_threshold = getattr(global_config.affinity_flow, 'reply_action_interest_threshold', 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, 'non_reply_action_interest_threshold', 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, 'high_match_interest_threshold', 0.8)
|
||||
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
|
||||
# 计算与不同阈值的差距比例
|
||||
reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold)))
|
||||
non_reply_gap_ratio = max(0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold)))
|
||||
high_match_gap_ratio = max(0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold)))
|
||||
non_reply_gap_ratio = max(
|
||||
0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))
|
||||
)
|
||||
high_match_gap_ratio = max(
|
||||
0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))
|
||||
)
|
||||
|
||||
# 基于阈值差距比例的基础分计算
|
||||
threshold_based_score = (
|
||||
reply_gap_ratio * 0.6 + # 回复阈值差距权重60%
|
||||
non_reply_gap_ratio * 0.2 + # 非回复阈值差距权重20%
|
||||
high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
|
||||
reply_gap_ratio * 0.6 # 回复阈值差距权重60%
|
||||
+ non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20%
|
||||
+ high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
|
||||
)
|
||||
|
||||
# 动态权重调整:根据平均兴趣度水平调整权重分配
|
||||
@@ -230,15 +256,19 @@ class ChatStream:
|
||||
relationship_weight = 0.2
|
||||
|
||||
# 计算活跃度得分
|
||||
activity_score = (action_rate * 0.6 + reply_rate * 0.4)
|
||||
activity_score = action_rate * 0.6 + reply_rate * 0.4
|
||||
|
||||
# 综合计算:基于阈值的动态加权
|
||||
focus_energy = (
|
||||
threshold_based_score * threshold_weight + # 阈值差距基础分
|
||||
activity_score * activity_weight + # 活跃度得分
|
||||
relationship_factor * relationship_weight + # 关系得分
|
||||
self.base_interest_energy * 0.05 # 基础兴趣微调
|
||||
) * time_decay * no_reply_penalty
|
||||
(
|
||||
threshold_based_score * threshold_weight # 阈值差距基础分
|
||||
+ activity_score * activity_weight # 活跃度得分
|
||||
+ relationship_factor * relationship_weight # 关系得分
|
||||
+ self.base_interest_energy * 0.05 # 基础兴趣微调
|
||||
)
|
||||
* time_decay
|
||||
* no_reply_penalty
|
||||
)
|
||||
|
||||
# 确保在合理范围内
|
||||
focus_energy = max(0.1, min(1.0, focus_energy))
|
||||
@@ -268,7 +298,7 @@ class ChatStream:
|
||||
chatter_interest_scoring_system,
|
||||
)
|
||||
|
||||
if self.user_info and hasattr(self.user_info, 'user_id'):
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id))
|
||||
except Exception:
|
||||
pass
|
||||
@@ -276,8 +306,102 @@ class ChatStream:
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
|
||||
def _load_history_messages(self):
|
||||
"""从数据库加载历史消息到StreamContext"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from sqlalchemy import select, desc
|
||||
import asyncio
|
||||
|
||||
async def _load_messages():
|
||||
def _db_query():
|
||||
with get_db_session() as session:
|
||||
# 查询该stream_id的最近20条消息
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.chat_info_stream_id == self.stream_id)
|
||||
.order_by(desc(Messages.time))
|
||||
.limit(global_config.chat.max_context_size)
|
||||
)
|
||||
results = session.execute(stmt).scalars().all()
|
||||
return results
|
||||
|
||||
# 在线程中执行数据库查询
|
||||
db_messages = await asyncio.to_thread(_db_query)
|
||||
|
||||
# 转换为DatabaseMessages对象并添加到StreamContext
|
||||
for db_msg in db_messages:
|
||||
try:
|
||||
# 从SQLAlchemy模型转换为DatabaseMessages数据模型
|
||||
import orjson
|
||||
|
||||
# 解析actions字段(JSON格式)
|
||||
actions = None
|
||||
if db_msg.actions:
|
||||
try:
|
||||
actions = orjson.loads(db_msg.actions)
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
actions = None
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=db_msg.message_id,
|
||||
time=db_msg.time,
|
||||
chat_id=db_msg.chat_id,
|
||||
reply_to=db_msg.reply_to,
|
||||
interest_value=db_msg.interest_value,
|
||||
key_words=db_msg.key_words,
|
||||
key_words_lite=db_msg.key_words_lite,
|
||||
is_mentioned=db_msg.is_mentioned,
|
||||
processed_plain_text=db_msg.processed_plain_text,
|
||||
display_message=db_msg.display_message,
|
||||
priority_mode=db_msg.priority_mode,
|
||||
priority_info=db_msg.priority_info,
|
||||
additional_config=db_msg.additional_config,
|
||||
is_emoji=db_msg.is_emoji,
|
||||
is_picid=db_msg.is_picid,
|
||||
is_command=db_msg.is_command,
|
||||
is_notify=db_msg.is_notify,
|
||||
user_id=db_msg.user_id,
|
||||
user_nickname=db_msg.user_nickname,
|
||||
user_cardname=db_msg.user_cardname,
|
||||
user_platform=db_msg.user_platform,
|
||||
chat_info_group_id=db_msg.chat_info_group_id,
|
||||
chat_info_group_name=db_msg.chat_info_group_name,
|
||||
chat_info_group_platform=db_msg.chat_info_group_platform,
|
||||
chat_info_user_id=db_msg.chat_info_user_id,
|
||||
chat_info_user_nickname=db_msg.chat_info_user_nickname,
|
||||
chat_info_user_cardname=db_msg.chat_info_user_cardname,
|
||||
chat_info_user_platform=db_msg.chat_info_user_platform,
|
||||
chat_info_stream_id=db_msg.chat_info_stream_id,
|
||||
chat_info_platform=db_msg.chat_info_platform,
|
||||
chat_info_create_time=db_msg.chat_info_create_time,
|
||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||
# 新增的兴趣度系统字段
|
||||
interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0,
|
||||
actions=actions,
|
||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||
)
|
||||
|
||||
# 标记为已读并添加到历史消息
|
||||
db_message.is_read = True
|
||||
self.stream_context.history_messages.append(db_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}")
|
||||
continue
|
||||
|
||||
if self.stream_context.history_messages:
|
||||
logger.info(
|
||||
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
|
||||
)
|
||||
|
||||
# 创建任务来加载历史消息
|
||||
asyncio.create_task(_load_messages())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载历史消息失败: {e}")
|
||||
|
||||
|
||||
class ChatManager:
|
||||
@@ -524,6 +648,7 @@ class ChatManager:
|
||||
"reply_count": s_data_dict.get("reply_count", 0),
|
||||
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
|
||||
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
|
||||
"interruption_count": s_data_dict.get("interruption_count", 0),
|
||||
}
|
||||
|
||||
# 根据数据库类型选择插入语句
|
||||
@@ -595,6 +720,7 @@ class ChatManager:
|
||||
"last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()),
|
||||
"relationship_score": getattr(model_instance, "relationship_score", 0.3),
|
||||
"consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0),
|
||||
"interruption_count": getattr(model_instance, "interruption_count", 0),
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
session.commit()
|
||||
|
||||
@@ -60,7 +60,9 @@ class StreamContext(BaseDataModel):
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None):
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
@@ -166,11 +168,15 @@ class StreamContext(BaseDataModel):
|
||||
# 计算打断比例
|
||||
interruption_ratio = self.interruption_count / max_limit
|
||||
|
||||
# 如果已达到或超过最大次数,完全禁止打断
|
||||
if self.interruption_count >= max_limit:
|
||||
return 0.0
|
||||
|
||||
# 如果超过概率因子,概率下降
|
||||
if interruption_ratio > probability_factor:
|
||||
# 使用指数衰减,超过限制越多,概率越低
|
||||
excess_ratio = interruption_ratio - probability_factor
|
||||
probability = 1.0 * (0.5**excess_ratio) # 基础概率0.5,指数衰减
|
||||
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
|
||||
else:
|
||||
# 在限制内,保持较高概率
|
||||
probability = 0.8
|
||||
@@ -182,12 +188,18 @@ class StreamContext(BaseDataModel):
|
||||
self.interruption_count += 1
|
||||
self.last_interruption_time = time.time()
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def reset_interruption_count(self):
|
||||
"""重置打断计数和afc阈值调整"""
|
||||
self.interruption_count = 0
|
||||
self.last_interruption_time = 0.0
|
||||
self.afc_threshold_adjustment = 0.0
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def apply_interruption_afc_reduction(self, reduction_value: float):
|
||||
"""应用打断导致的afc阈值降低"""
|
||||
self.afc_threshold_adjustment += reduction_value
|
||||
@@ -197,18 +209,40 @@ class StreamContext(BaseDataModel):
|
||||
"""获取当前的afc阈值调整量"""
|
||||
return self.afc_threshold_adjustment
|
||||
|
||||
def _sync_interruption_count_to_stream(self):
|
||||
"""同步打断计数到ChatStream"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "interruption_count"):
|
||||
# 在这里我们只是标记需要保存,实际的保存会在下次save时进行
|
||||
chat_stream.saved = False
|
||||
logger.debug(
|
||||
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"同步打断计数到ChatStream失败: {e}")
|
||||
|
||||
def set_current_message(self, message: "DatabaseMessages"):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if self.current_message and hasattr(self.current_message, 'additional_config') and self.current_message.additional_config:
|
||||
if (
|
||||
self.current_message
|
||||
and hasattr(self.current_message, "additional_config")
|
||||
and self.current_message.additional_config
|
||||
):
|
||||
try:
|
||||
import json
|
||||
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if config.get('template_info') and not config.get('template_default', True):
|
||||
return config.get('template_name')
|
||||
if config.get("template_info") and not config.get("template_default", True):
|
||||
return config.get("template_name")
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
@@ -224,25 +258,83 @@ class StreamContext(BaseDataModel):
|
||||
return None
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
"""检查消息类型"""
|
||||
"""
|
||||
检查当前消息是否支持指定的类型
|
||||
|
||||
Args:
|
||||
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
|
||||
|
||||
Returns:
|
||||
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||
"""
|
||||
if not self.current_message:
|
||||
return False
|
||||
|
||||
# 检查消息是否支持指定的类型
|
||||
# 这里简化处理,实际应该根据消息的格式信息检查
|
||||
if hasattr(self.current_message, 'additional_config') and self.current_message.additional_config:
|
||||
if not types:
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
try:
|
||||
import json
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if 'format_info' in config and 'accept_format' in config['format_info']:
|
||||
accept_format = config['format_info']['accept_format']
|
||||
for t in types:
|
||||
if t not in accept_format:
|
||||
return False
|
||||
return True
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return False
|
||||
import orjson
|
||||
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
accept_format = format_info["accept_format"]
|
||||
# 确保accept_format是列表类型
|
||||
if isinstance(accept_format, str):
|
||||
accept_format = [accept_format]
|
||||
elif isinstance(accept_format, list):
|
||||
pass
|
||||
else:
|
||||
# 如果accept_format不是字符串或列表,尝试转换为列表
|
||||
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
elif "content_format" in format_info:
|
||||
content_format = format_info["content_format"]
|
||||
# 确保content_format是列表类型
|
||||
if isinstance(content_format, str):
|
||||
content_format = [content_format]
|
||||
elif isinstance(content_format, list):
|
||||
pass
|
||||
else:
|
||||
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.debug(f"解析消息格式信息失败: {e}")
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
"""获取优先级模式"""
|
||||
|
||||
@@ -62,6 +62,8 @@ class ChatStreams(Base):
|
||||
reply_count = Column(Integer, nullable=True, default=0)
|
||||
last_interaction_time = Column(Float, nullable=True, default=None)
|
||||
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
||||
# 消息打断系统字段
|
||||
interruption_count = Column(Integer, nullable=True, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||
@@ -171,11 +173,18 @@ class Messages(Base):
|
||||
is_command = Column(Boolean, nullable=False, default=False)
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# 兴趣度系统字段
|
||||
interest_degree = Column(Float, nullable=True, default=0.0)
|
||||
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
||||
should_reply = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_messages_message_id", "message_id"),
|
||||
Index("idx_messages_chat_id", "chat_id"),
|
||||
Index("idx_messages_time", "time"),
|
||||
Index("idx_messages_user_id", "user_id"),
|
||||
Index("idx_messages_interest_degree", "interest_degree"),
|
||||
Index("idx_messages_should_reply", "should_reply"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user