refactor(chat): 优化消息管理与打断系统,添加打断计数与历史消息加载功能
This commit is contained in:
@@ -60,7 +60,9 @@ class StreamContext(BaseDataModel):
|
||||
# 自动检测和更新chat type
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None):
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
@@ -166,11 +168,15 @@ class StreamContext(BaseDataModel):
|
||||
# 计算打断比例
|
||||
interruption_ratio = self.interruption_count / max_limit
|
||||
|
||||
# 如果已达到或超过最大次数,完全禁止打断
|
||||
if self.interruption_count >= max_limit:
|
||||
return 0.0
|
||||
|
||||
# 如果超过概率因子,概率下降
|
||||
if interruption_ratio > probability_factor:
|
||||
# 使用指数衰减,超过限制越多,概率越低
|
||||
excess_ratio = interruption_ratio - probability_factor
|
||||
probability = 1.0 * (0.5**excess_ratio) # 基础概率0.5,指数衰减
|
||||
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
|
||||
else:
|
||||
# 在限制内,保持较高概率
|
||||
probability = 0.8
|
||||
@@ -182,12 +188,18 @@ class StreamContext(BaseDataModel):
|
||||
self.interruption_count += 1
|
||||
self.last_interruption_time = time.time()
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def reset_interruption_count(self):
|
||||
"""重置打断计数和afc阈值调整"""
|
||||
self.interruption_count = 0
|
||||
self.last_interruption_time = 0.0
|
||||
self.afc_threshold_adjustment = 0.0
|
||||
|
||||
# 同步打断计数到ChatStream
|
||||
self._sync_interruption_count_to_stream()
|
||||
|
||||
def apply_interruption_afc_reduction(self, reduction_value: float):
|
||||
"""应用打断导致的afc阈值降低"""
|
||||
self.afc_threshold_adjustment += reduction_value
|
||||
@@ -197,18 +209,40 @@ class StreamContext(BaseDataModel):
|
||||
"""获取当前的afc阈值调整量"""
|
||||
return self.afc_threshold_adjustment
|
||||
|
||||
def _sync_interruption_count_to_stream(self):
|
||||
"""同步打断计数到ChatStream"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
if chat_manager:
|
||||
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "interruption_count"):
|
||||
# 在这里我们只是标记需要保存,实际的保存会在下次save时进行
|
||||
chat_stream.saved = False
|
||||
logger.debug(
|
||||
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"同步打断计数到ChatStream失败: {e}")
|
||||
|
||||
def set_current_message(self, message: "DatabaseMessages"):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if self.current_message and hasattr(self.current_message, 'additional_config') and self.current_message.additional_config:
|
||||
if (
|
||||
self.current_message
|
||||
and hasattr(self.current_message, "additional_config")
|
||||
and self.current_message.additional_config
|
||||
):
|
||||
try:
|
||||
import json
|
||||
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if config.get('template_info') and not config.get('template_default', True):
|
||||
return config.get('template_name')
|
||||
if config.get("template_info") and not config.get("template_default", True):
|
||||
return config.get("template_name")
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
@@ -224,25 +258,83 @@ class StreamContext(BaseDataModel):
|
||||
return None
|
||||
|
||||
def check_types(self, types: list) -> bool:
|
||||
"""检查消息类型"""
|
||||
"""
|
||||
检查当前消息是否支持指定的类型
|
||||
|
||||
Args:
|
||||
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
|
||||
|
||||
Returns:
|
||||
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||
"""
|
||||
if not self.current_message:
|
||||
return False
|
||||
|
||||
# 检查消息是否支持指定的类型
|
||||
# 这里简化处理,实际应该根据消息的格式信息检查
|
||||
if hasattr(self.current_message, 'additional_config') and self.current_message.additional_config:
|
||||
if not types:
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
try:
|
||||
import json
|
||||
config = json.loads(self.current_message.additional_config)
|
||||
if 'format_info' in config and 'accept_format' in config['format_info']:
|
||||
accept_format = config['format_info']['accept_format']
|
||||
for t in types:
|
||||
if t not in accept_format:
|
||||
return False
|
||||
return True
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
return False
|
||||
import orjson
|
||||
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
accept_format = format_info["accept_format"]
|
||||
# 确保accept_format是列表类型
|
||||
if isinstance(accept_format, str):
|
||||
accept_format = [accept_format]
|
||||
elif isinstance(accept_format, list):
|
||||
pass
|
||||
else:
|
||||
# 如果accept_format不是字符串或列表,尝试转换为列表
|
||||
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
elif "content_format" in format_info:
|
||||
content_format = format_info["content_format"]
|
||||
# 确保content_format是列表类型
|
||||
if isinstance(content_format, str):
|
||||
content_format = [content_format]
|
||||
elif isinstance(content_format, list):
|
||||
pass
|
||||
else:
|
||||
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
|
||||
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
return True
|
||||
|
||||
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.debug(f"解析消息格式信息失败: {e}")
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
"""获取优先级模式"""
|
||||
|
||||
Reference in New Issue
Block a user