refactor(chat): 优化消息管理与打断系统,添加打断计数与历史消息加载功能

This commit is contained in:
Windpicker-owo
2025-09-26 19:17:24 +08:00
parent 7718a9b956
commit 0478be7d2a
4 changed files with 330 additions and 80 deletions

View File

@@ -60,7 +60,9 @@ class StreamContext(BaseDataModel):
# 自动检测和更新chat type
self._detect_chat_type(message)
def update_message_info(self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None):
def update_message_info(
self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None
):
"""
更新消息信息
@@ -166,11 +168,15 @@ class StreamContext(BaseDataModel):
# 计算打断比例
interruption_ratio = self.interruption_count / max_limit
# 如果已达到或超过最大次数,完全禁止打断
if self.interruption_count >= max_limit:
return 0.0
# 如果超过概率因子,概率下降
if interruption_ratio > probability_factor:
# 使用指数衰减,超过限制越多,概率越低
excess_ratio = interruption_ratio - probability_factor
probability = 1.0 * (0.5**excess_ratio) # 基础概率0.5,指数衰减
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
else:
# 在限制内,保持较高概率
probability = 0.8
@@ -182,12 +188,18 @@ class StreamContext(BaseDataModel):
self.interruption_count += 1
self.last_interruption_time = time.time()
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
def reset_interruption_count(self):
"""重置打断计数和afc阈值调整"""
self.interruption_count = 0
self.last_interruption_time = 0.0
self.afc_threshold_adjustment = 0.0
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
def apply_interruption_afc_reduction(self, reduction_value: float):
"""应用打断导致的afc阈值降低"""
self.afc_threshold_adjustment += reduction_value
@@ -197,18 +209,40 @@ class StreamContext(BaseDataModel):
"""获取当前的afc阈值调整量"""
return self.afc_threshold_adjustment
def _sync_interruption_count_to_stream(self):
"""同步打断计数到ChatStream"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if chat_manager:
chat_stream = chat_manager.get_stream(self.stream_id)
if chat_stream and hasattr(chat_stream, "interruption_count"):
# 在这里我们只是标记需要保存实际的保存会在下次save时进行
chat_stream.saved = False
logger.debug(
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
)
except Exception as e:
logger.warning(f"同步打断计数到ChatStream失败: {e}")
def set_current_message(self, message: "DatabaseMessages"):
"""设置当前消息"""
self.current_message = message
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.current_message and hasattr(self.current_message, 'additional_config') and self.current_message.additional_config:
if (
self.current_message
and hasattr(self.current_message, "additional_config")
and self.current_message.additional_config
):
try:
import json
config = json.loads(self.current_message.additional_config)
if config.get('template_info') and not config.get('template_default', True):
return config.get('template_name')
if config.get("template_info") and not config.get("template_default", True):
return config.get("template_name")
except (json.JSONDecodeError, AttributeError):
pass
return None
@@ -224,25 +258,83 @@ class StreamContext(BaseDataModel):
return None
def check_types(self, types: list) -> bool:
"""检查消息类型"""
"""
检查当前消息是否支持指定的类型
Args:
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
Returns:
bool: 如果消息支持所有指定的类型则返回True否则返回False
"""
if not self.current_message:
return False
# 检查消息是否支持指定的类型
# 这里简化处理,实际应该根据消息的格式信息检查
if hasattr(self.current_message, 'additional_config') and self.current_message.additional_config:
if not types:
# 如果没有指定类型要求,默认为支持
return True
# 优先从additional_config中获取format_info
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
try:
import json
config = json.loads(self.current_message.additional_config)
if 'format_info' in config and 'accept_format' in config['format_info']:
accept_format = config['format_info']['accept_format']
for t in types:
if t not in accept_format:
return False
return True
except (json.JSONDecodeError, AttributeError):
pass
return False
import orjson
config = orjson.loads(self.current_message.additional_config)
# 检查format_info结构
if "format_info" in config:
format_info = config["format_info"]
# 方法1: 直接检查accept_format字段
if "accept_format" in format_info:
accept_format = format_info["accept_format"]
# 确保accept_format是列表类型
if isinstance(accept_format, str):
accept_format = [accept_format]
elif isinstance(accept_format, list):
pass
else:
# 如果accept_format不是字符串或列表尝试转换为列表
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
# 检查所有请求的类型是否都被支持
for requested_type in types:
if requested_type not in accept_format:
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
return False
return True
# 方法2: 检查content_format字段向后兼容
elif "content_format" in format_info:
content_format = format_info["content_format"]
# 确保content_format是列表类型
if isinstance(content_format, str):
content_format = [content_format]
elif isinstance(content_format, list):
pass
else:
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
# 检查所有请求的类型是否都被支持
for requested_type in types:
if requested_type not in content_format:
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
return False
return True
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
logger.debug(f"解析消息格式信息失败: {e}")
# 备用方案如果无法从additional_config获取格式信息使用默认支持的类型
# 大多数消息至少支持text类型
default_supported_types = ["text", "emoji"]
for requested_type in types:
if requested_type not in default_supported_types:
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
# 对于非基础类型返回False以避免错误
if requested_type not in ["text", "emoji", "reply"]:
return False
return True
def get_priority_mode(self) -> Optional[str]:
"""获取优先级模式"""