diff --git a/src/chat/energy_system/__init__.py b/src/chat/energy_system/__init__.py new file mode 100644 index 000000000..0addfd070 --- /dev/null +++ b/src/chat/energy_system/__init__.py @@ -0,0 +1,28 @@ +""" +能量系统模块 +提供稳定、高效的聊天流能量计算和管理功能 +""" + +from .energy_manager import ( + EnergyManager, + EnergyLevel, + EnergyComponent, + EnergyCalculator, + InterestEnergyCalculator, + ActivityEnergyCalculator, + RecencyEnergyCalculator, + RelationshipEnergyCalculator, + energy_manager +) + +__all__ = [ + "EnergyManager", + "EnergyLevel", + "EnergyComponent", + "EnergyCalculator", + "InterestEnergyCalculator", + "ActivityEnergyCalculator", + "RecencyEnergyCalculator", + "RelationshipEnergyCalculator", + "energy_manager" +] \ No newline at end of file diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py new file mode 100644 index 000000000..099d31c2c --- /dev/null +++ b/src/chat/energy_system/energy_manager.py @@ -0,0 +1,480 @@ +""" +重构后的 focus_energy 管理系统 +提供稳定、高效的聊天流能量计算和管理功能 +""" + +import time +from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict +from dataclasses import dataclass, field +from enum import Enum +from abc import ABC, abstractmethod + +from src.common.logger import get_logger +from src.config.config import global_config + +logger = get_logger("energy_system") + + +class EnergyLevel(Enum): + """能量等级""" + VERY_LOW = 0.1 # 非常低 + LOW = 0.3 # 低 + NORMAL = 0.5 # 正常 + HIGH = 0.7 # 高 + VERY_HIGH = 0.9 # 非常高 + + +@dataclass +class EnergyComponent: + """能量组件""" + name: str + value: float + weight: float = 1.0 + decay_rate: float = 0.05 # 衰减率 + last_updated: float = field(default_factory=time.time) + + def get_current_value(self) -> float: + """获取当前值(考虑时间衰减)""" + age = time.time() - self.last_updated + decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减 + return self.value * decay_factor + + def update_value(self, new_value: float) -> None: + """更新值""" + self.value = max(0.0, min(1.0, new_value)) + self.last_updated = time.time() + + +class EnergyContext(TypedDict): + """能量计算上下文""" + stream_id: str + messages: List[Any] + user_id: Optional[str] + + +class EnergyResult(TypedDict): + """能量计算结果""" + energy: float + level: EnergyLevel + distribution_interval: float + component_scores: Dict[str, float] + cached: bool + + +class EnergyCalculator(ABC): + """能量计算器抽象基类""" + + @abstractmethod + def calculate(self, context: Dict[str, Any]) -> float: + """计算能量值""" + pass + + @abstractmethod + def get_weight(self) -> float: + """获取权重""" + pass + + +class InterestEnergyCalculator(EnergyCalculator): + """兴趣度能量计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于消息兴趣度计算能量""" + messages = context.get("messages", []) + if not messages: + return 0.3 + + # 计算平均兴趣度 + total_interest = 0.0 + valid_messages = 0 + + for msg in messages: + interest_value = getattr(msg, "interest_value", None) + if interest_value is not None: + try: + interest_float = float(interest_value) + if 0.0 <= interest_float <= 1.0: + total_interest += interest_float + valid_messages += 1 + except (ValueError, TypeError): + continue + + if valid_messages > 0: + avg_interest = total_interest / valid_messages + logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)") + return avg_interest + else: + return 0.3 + + def get_weight(self) -> float: + return 0.5 + + +class ActivityEnergyCalculator(EnergyCalculator): + """活跃度能量计算器""" + + def __init__(self): + self.action_weights = { + "reply": 0.4, + "react": 0.3, + "mention": 0.2, + "other": 0.1 + } + + def calculate(self, context: Dict[str, Any]) -> float: + """基于活跃度计算能量""" + messages = context.get("messages", []) + if not messages: + return 0.2 + + total_score = 0.0 + max_possible_score = len(messages) * 0.4 # 最高可能分数 + + for msg in messages: + actions = getattr(msg, "actions", []) + if isinstance(actions, list) and actions: + for action in actions: + weight = self.action_weights.get(action, self.action_weights["other"]) + total_score += weight + + if max_possible_score > 0: + activity_score = min(1.0, total_score / max_possible_score) + logger.debug(f"活跃度分数: {activity_score:.3f}") + return activity_score + else: + return 0.2 + + def get_weight(self) -> float: + return 0.3 + + +class RecencyEnergyCalculator(EnergyCalculator): + """最近性能量计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于最近性计算能量""" + messages = context.get("messages", []) + if not messages: + return 0.1 + + # 获取最新消息时间 + latest_time = 0.0 + for msg in messages: + msg_time = getattr(msg, "time", None) + if msg_time and msg_time > latest_time: + latest_time = msg_time + + if latest_time == 0.0: + return 0.1 + + # 计算时间衰减 + current_time = time.time() + age = current_time - latest_time + + # 时间衰减策略: + # 1小时内:1.0 + # 1-6小时:0.8 + # 6-24小时:0.5 + # 1-7天:0.3 + # 7天以上:0.1 + if age < 3600: # 1小时内 + recency_score = 1.0 + elif age < 6 * 3600: # 6小时内 + recency_score = 0.8 + elif age < 24 * 3600: # 24小时内 + recency_score = 0.5 + elif age < 7 * 24 * 3600: # 7天内 + recency_score = 0.3 + else: + recency_score = 0.1 + + logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)") + return recency_score + + def get_weight(self) -> float: + return 0.2 + + +class RelationshipEnergyCalculator(EnergyCalculator): + """关系能量计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于关系计算能量""" + user_id = context.get("user_id") + if not user_id: + return 0.3 + + try: + # 使用新的兴趣度管理系统获取用户关系分 + from src.chat.interest_system import interest_manager + + # 获取用户交互历史作为关系分的基础 + interaction_calc = interest_manager.calculators.get( + interest_manager.InterestSourceType.USER_INTERACTION + ) + if interaction_calc: + relationship_score = interaction_calc.calculate({"user_id": user_id}) + logger.debug(f"用户关系分数: {relationship_score:.3f}") + return max(0.0, min(1.0, relationship_score)) + else: + # 默认基础分 + return 0.3 + except Exception: + # 默认基础分 + return 0.3 + + def get_weight(self) -> float: + return 0.1 + + +class EnergyManager: + """能量管理器 - 统一管理所有能量计算""" + + def __init__(self) -> None: + self.calculators: List[EnergyCalculator] = [ + InterestEnergyCalculator(), + ActivityEnergyCalculator(), + RecencyEnergyCalculator(), + RelationshipEnergyCalculator(), + ] + + # 能量缓存 + self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp) + self.cache_ttl: int = 60 # 1分钟缓存 + + # AFC阈值配置 + self.thresholds: Dict[str, float] = { + "high_match": 0.8, + "reply": 0.4, + "non_reply": 0.2 + } + + # 统计信息 + self.stats: Dict[str, Union[int, float, str]] = { + "total_calculations": 0, + "cache_hits": 0, + "cache_misses": 0, + "average_calculation_time": 0.0, + "last_threshold_update": time.time(), + } + + # 从配置加载阈值 + self._load_thresholds_from_config() + + logger.info("能量管理器初始化完成") + + def _load_thresholds_from_config(self) -> None: + """从配置加载AFC阈值""" + try: + if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None: + self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) + self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) + self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) + + # 确保阈值关系合理 + self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1) + self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1) + + self.stats["last_threshold_update"] = time.time() + logger.info(f"加载AFC阈值: {self.thresholds}") + except Exception as e: + logger.warning(f"加载AFC阈值失败,使用默认值: {e}") + + def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float: + """计算聊天流的focus_energy""" + start_time = time.time() + + # 更新统计 + self.stats["total_calculations"] += 1 + + # 检查缓存 + if stream_id in self.energy_cache: + cached_energy, cached_time = self.energy_cache[stream_id] + if time.time() - cached_time < self.cache_ttl: + self.stats["cache_hits"] += 1 + logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}") + return cached_energy + else: + self.stats["cache_misses"] += 1 + + # 构建计算上下文 + context: EnergyContext = { + "stream_id": stream_id, + "messages": messages, + "user_id": user_id, + } + + # 计算各组件能量 + component_scores: Dict[str, float] = {} + total_weight = 0.0 + + for calculator in self.calculators: + try: + score = calculator.calculate(context) + weight = calculator.get_weight() + + component_scores[calculator.__class__.__name__] = score + total_weight += weight + + logger.debug(f"{calculator.__class__.__name__} 能量: {score:.3f} (权重: {weight:.3f})") + + except Exception as e: + logger.warning(f"计算 {calculator.__class__.__name__} 能量失败: {e}") + + # 加权计算总能量 + if total_weight > 0: + total_energy = 0.0 + for calculator in self.calculators: + if calculator.__class__.__name__ in component_scores: + score = component_scores[calculator.__class__.__name__] + weight = calculator.get_weight() + total_energy += score * (weight / total_weight) + else: + total_energy = 0.5 + + # 应用阈值调整和变换 + final_energy = self._apply_threshold_adjustment(total_energy) + + # 缓存结果 + self.energy_cache[stream_id] = (final_energy, time.time()) + + # 清理过期缓存 + self._cleanup_cache() + + # 更新平均计算时间 + calculation_time = time.time() - start_time + total_calculations = self.stats["total_calculations"] + self.stats["average_calculation_time"] = ( + (self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time) + / total_calculations + ) + + logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)") + return final_energy + + def _apply_threshold_adjustment(self, energy: float) -> float: + """应用阈值调整和变换""" + # 获取参考阈值 + high_threshold = self.thresholds["high_match"] + reply_threshold = self.thresholds["reply"] + + # 计算与阈值的相对位置 + if energy >= high_threshold: + # 高能量区域:指数增强 + adjusted = 0.7 + (energy - 0.7) ** 0.8 + elif energy >= reply_threshold: + # 中等能量区域:线性保持 + adjusted = energy + else: + # 低能量区域:对数压缩 + adjusted = 0.4 * (energy / 0.4) ** 1.2 + + # 确保在合理范围内 + return max(0.1, min(1.0, adjusted)) + + def get_energy_level(self, energy: float) -> EnergyLevel: + """获取能量等级""" + if energy >= EnergyLevel.VERY_HIGH.value: + return EnergyLevel.VERY_HIGH + elif energy >= EnergyLevel.HIGH.value: + return EnergyLevel.HIGH + elif energy >= EnergyLevel.NORMAL.value: + return EnergyLevel.NORMAL + elif energy >= EnergyLevel.LOW.value: + return EnergyLevel.LOW + else: + return EnergyLevel.VERY_LOW + + def get_distribution_interval(self, energy: float) -> float: + """基于能量等级获取分发周期""" + energy_level = self.get_energy_level(energy) + + # 根据能量等级确定基础分发周期 + if energy_level == EnergyLevel.VERY_HIGH: + base_interval = 1.0 # 1秒 + elif energy_level == EnergyLevel.HIGH: + base_interval = 3.0 # 3秒 + elif energy_level == EnergyLevel.NORMAL: + base_interval = 8.0 # 8秒 + elif energy_level == EnergyLevel.LOW: + base_interval = 15.0 # 15秒 + else: + base_interval = 30.0 # 30秒 + + # 添加随机扰动避免同步 + import random + jitter = random.uniform(0.8, 1.2) + final_interval = base_interval * jitter + + # 确保在配置范围内 + min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0) + max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0) + + return max(min_interval, min(max_interval, final_interval)) + + def invalidate_cache(self, stream_id: str) -> None: + """失效指定流的缓存""" + if stream_id in self.energy_cache: + del self.energy_cache[stream_id] + logger.debug(f"已清除聊天流 {stream_id} 的能量缓存") + + def _cleanup_cache(self) -> None: + """清理过期缓存""" + current_time = time.time() + expired_keys = [ + stream_id for stream_id, (_, timestamp) in self.energy_cache.items() + if current_time - timestamp > self.cache_ttl + ] + + for key in expired_keys: + del self.energy_cache[key] + + if expired_keys: + logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存") + + def get_statistics(self) -> Dict[str, Any]: + """获取统计信息""" + return { + "cache_size": len(self.energy_cache), + "calculators": [calc.__class__.__name__ for calc in self.calculators], + "thresholds": self.thresholds, + "performance_stats": self.stats.copy(), + } + + def update_thresholds(self, new_thresholds: Dict[str, float]) -> None: + """更新阈值""" + self.thresholds.update(new_thresholds) + + # 确保阈值关系合理 + self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1) + self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1) + + self.stats["last_threshold_update"] = time.time() + logger.info(f"更新AFC阈值: {self.thresholds}") + + def add_calculator(self, calculator: EnergyCalculator) -> None: + """添加计算器""" + self.calculators.append(calculator) + logger.info(f"添加能量计算器: {calculator.__class__.__name__}") + + def remove_calculator(self, calculator: EnergyCalculator) -> None: + """移除计算器""" + if calculator in self.calculators: + self.calculators.remove(calculator) + logger.info(f"移除能量计算器: {calculator.__class__.__name__}") + + def clear_cache(self) -> None: + """清空缓存""" + self.energy_cache.clear() + logger.info("清空能量缓存") + + def get_cache_hit_rate(self) -> float: + """获取缓存命中率""" + total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0) + if total_requests == 0: + return 0.0 + return self.stats["cache_hits"] / total_requests + + +# 全局能量管理器实例 +energy_manager = EnergyManager() \ No newline at end of file diff --git a/src/chat/interest_system/__init__.py b/src/chat/interest_system/__init__.py index e64f25a2f..378f8b683 100644 --- a/src/chat/interest_system/__init__.py +++ b/src/chat/interest_system/__init__.py @@ -1,12 +1,30 @@ """ -机器人兴趣标签系统 -基于人设生成兴趣标签,使用embedding计算匹配度 +兴趣度系统模块 +提供统一、稳定的消息兴趣度计算和管理功能 """ +from .interest_manager import ( + InterestManager, + InterestSourceType, + InterestFactor, + InterestCalculator, + MessageContentInterestCalculator, + TopicInterestCalculator, + UserInteractionInterestCalculator, + interest_manager +) from .bot_interest_manager import BotInterestManager, bot_interest_manager from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult __all__ = [ + "InterestManager", + "InterestSourceType", + "InterestFactor", + "InterestCalculator", + "MessageContentInterestCalculator", + "TopicInterestCalculator", + "UserInteractionInterestCalculator", + "interest_manager", "BotInterestManager", "bot_interest_manager", "BotInterestTag", diff --git a/src/chat/interest_system/interest_manager.py b/src/chat/interest_system/interest_manager.py new file mode 100644 index 000000000..e25c9e96d --- /dev/null +++ b/src/chat/interest_system/interest_manager.py @@ -0,0 +1,430 @@ +""" +重构后的消息兴趣值计算系统 +提供稳定、可靠的消息兴趣度计算和管理功能 +""" + +import time +from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict +from dataclasses import dataclass, field +from enum import Enum +from abc import ABC, abstractmethod + +from src.common.logger import get_logger + +logger = get_logger("interest_system") + + +class InterestSourceType(Enum): + """兴趣度来源类型""" + MESSAGE_CONTENT = "message_content" # 消息内容 + USER_INTERACTION = "user_interaction" # 用户交互 + TOPIC_RELEVANCE = "topic_relevance" # 话题相关性 + RELATIONSHIP_SCORE = "relationship_score" # 关系分数 + HISTORICAL_PATTERN = "historical_pattern" # 历史模式 + + +@dataclass +class InterestFactor: + """兴趣度因子""" + source_type: InterestSourceType + value: float + weight: float = 1.0 + decay_rate: float = 0.1 # 衰减率 + last_updated: float = field(default_factory=time.time) + + def get_current_value(self) -> float: + """获取当前值(考虑时间衰减)""" + age = time.time() - self.last_updated + decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减 + return self.value * decay_factor + + def update_value(self, new_value: float) -> None: + """更新值""" + self.value = max(0.0, min(1.0, new_value)) + self.last_updated = time.time() + + +class InterestCalculator(ABC): + """兴趣度计算器抽象基类""" + + @abstractmethod + def calculate(self, context: Dict[str, Any]) -> float: + """计算兴趣度""" + pass + + @abstractmethod + def get_confidence(self) -> float: + """获取计算置信度""" + pass + + +class MessageData(TypedDict): + """消息数据类型定义""" + message_id: str + processed_plain_text: str + is_emoji: bool + is_picid: bool + is_mentioned: bool + is_command: bool + key_words: str + user_id: str + time: float + + +class InterestContext(TypedDict): + """兴趣度计算上下文""" + stream_id: str + user_id: Optional[str] + message: MessageData + + +class InterestResult(TypedDict): + """兴趣度计算结果""" + value: float + confidence: float + source_scores: Dict[InterestSourceType, float] + cached: bool + + +class MessageContentInterestCalculator(InterestCalculator): + """消息内容兴趣度计算器""" + + def calculate(self, context: Dict[str, Any]) -> float: + """基于消息内容计算兴趣度""" + message = context.get("message", {}) + if not message: + return 0.3 # 默认值 + + # 提取消息特征 + text_length = len(message.get("processed_plain_text", "")) + has_emoji = message.get("is_emoji", False) + has_image = message.get("is_picid", False) + is_mentioned = message.get("is_mentioned", False) + is_command = message.get("is_command", False) + + # 基础分数 + base_score = 0.3 + + # 文本长度加权 + if text_length > 0: + text_score = min(0.3, text_length / 200) # 200字符为满分 + base_score += text_score * 0.3 + + # 多媒体内容加权 + if has_emoji: + base_score += 0.1 + if has_image: + base_score += 0.2 + + # 交互特征加权 + if is_mentioned: + base_score += 0.2 + if is_command: + base_score += 0.1 + + return min(1.0, base_score) + + def get_confidence(self) -> float: + return 0.8 + + +class TopicInterestCalculator(InterestCalculator): + """话题兴趣度计算器""" + + def __init__(self): + self.topic_interests: Dict[str, float] = {} + self.topic_decay_rate = 0.05 # 话题兴趣度衰减率 + + def update_topic_interest(self, topic: str, interest_value: float): + """更新话题兴趣度""" + current_interest = self.topic_interests.get(topic, 0.3) + # 平滑更新 + new_interest = current_interest * 0.7 + interest_value * 0.3 + self.topic_interests[topic] = max(0.0, min(1.0, new_interest)) + + logger.debug(f"更新话题 '{topic}' 兴趣度: {current_interest:.3f} -> {new_interest:.3f}") + + def calculate(self, context: Dict[str, Any]) -> float: + """基于话题相关性计算兴趣度""" + message = context.get("message", {}) + keywords = message.get("key_words", "[]") + + try: + import json + keyword_list = json.loads(keywords) if keywords else [] + except (json.JSONDecodeError, TypeError): + keyword_list = [] + + if not keyword_list: + return 0.4 # 无关键词时的默认值 + + # 计算相关话题的平均兴趣度 + total_interest = 0.0 + relevant_topics = 0 + + for keyword in keyword_list[:5]: # 最多取前5个关键词 + # 查找相关话题 + for topic, interest in self.topic_interests.items(): + if keyword.lower() in topic.lower() or topic.lower() in keyword.lower(): + total_interest += interest + relevant_topics += 1 + break + + if relevant_topics > 0: + return min(1.0, total_interest / relevant_topics) + else: + # 新话题,给予基础兴趣度 + for keyword in keyword_list[:3]: + self.topic_interests[keyword] = 0.5 + return 0.5 + + def get_confidence(self) -> float: + return 0.7 + + +class UserInteractionInterestCalculator(InterestCalculator): + """用户交互兴趣度计算器""" + + def __init__(self): + self.interaction_history: List[Dict] = [] + self.max_history_size = 100 + + def add_interaction(self, user_id: str, interaction_type: str, value: float): + """添加交互记录""" + self.interaction_history.append({ + "user_id": user_id, + "type": interaction_type, + "value": value, + "timestamp": time.time() + }) + + # 保持历史记录大小 + if len(self.interaction_history) > self.max_history_size: + self.interaction_history = self.interaction_history[-self.max_history_size:] + + def calculate(self, context: Dict[str, Any]) -> float: + """基于用户交互历史计算兴趣度""" + user_id = context.get("user_id") + if not user_id: + return 0.3 + + # 获取该用户的最近交互记录 + user_interactions = [ + interaction for interaction in self.interaction_history + if interaction["user_id"] == user_id + ] + + if not user_interactions: + return 0.3 + + # 计算加权平均(最近的交互权重更高) + total_weight = 0.0 + weighted_sum = 0.0 + + for interaction in user_interactions[-20:]: # 最近20次交互 + age = time.time() - interaction["timestamp"] + weight = max(0.1, 1.0 - age / (7 * 24 * 3600)) # 7天内衰减 + + weighted_sum += interaction["value"] * weight + total_weight += weight + + if total_weight > 0: + return min(1.0, weighted_sum / total_weight) + else: + return 0.3 + + def get_confidence(self) -> float: + return 0.6 + + +class InterestManager: + """兴趣度管理器 - 统一管理所有兴趣度计算""" + + def __init__(self) -> None: + self.calculators: Dict[InterestSourceType, InterestCalculator] = { + InterestSourceType.MESSAGE_CONTENT: MessageContentInterestCalculator(), + InterestSourceType.TOPIC_RELEVANCE: TopicInterestCalculator(), + InterestSourceType.USER_INTERACTION: UserInteractionInterestCalculator(), + } + + # 权重配置 + self.source_weights: Dict[InterestSourceType, float] = { + InterestSourceType.MESSAGE_CONTENT: 0.4, + InterestSourceType.TOPIC_RELEVANCE: 0.3, + InterestSourceType.USER_INTERACTION: 0.3, + } + + # 兴趣度缓存 + self.interest_cache: Dict[str, Tuple[float, float]] = {} # message_id -> (value, timestamp) + self.cache_ttl: int = 300 # 5分钟缓存 + + # 统计信息 + self.stats: Dict[str, Union[int, float, List[str]]] = { + "total_calculations": 0, + "cache_hits": 0, + "cache_misses": 0, + "average_calculation_time": 0.0, + "calculator_usage": {calc_type.value: 0 for calc_type in InterestSourceType} + } + + logger.info("兴趣度管理器初始化完成") + + def calculate_message_interest(self, message: Dict[str, Any], context: Dict[str, Any]) -> float: + """计算消息兴趣度""" + start_time = time.time() + message_id = message.get("message_id", "") + + # 更新统计 + self.stats["total_calculations"] += 1 + + # 检查缓存 + if message_id in self.interest_cache: + cached_value, cached_time = self.interest_cache[message_id] + if time.time() - cached_time < self.cache_ttl: + self.stats["cache_hits"] += 1 + logger.debug(f"使用缓存兴趣度: {message_id} = {cached_value:.3f}") + return cached_value + else: + self.stats["cache_misses"] += 1 + + # 构建计算上下文 + calc_context: Dict[str, Any] = { + "message": message, + "user_id": message.get("user_id"), + **context + } + + # 计算各来源的兴趣度 + source_scores: Dict[InterestSourceType, float] = {} + total_confidence = 0.0 + + for source_type, calculator in self.calculators.items(): + try: + score = calculator.calculate(calc_context) + confidence = calculator.get_confidence() + + source_scores[source_type] = score + total_confidence += confidence + + # 更新计算器使用统计 + self.stats["calculator_usage"][source_type.value] += 1 + + logger.debug(f"{source_type.value} 兴趣度: {score:.3f} (置信度: {confidence:.3f})") + + except Exception as e: + logger.warning(f"计算 {source_type.value} 兴趣度失败: {e}") + source_scores[source_type] = 0.3 + + # 加权计算最终兴趣度 + final_interest = 0.0 + total_weight = 0.0 + + for source_type, score in source_scores.items(): + weight = self.source_weights.get(source_type, 0.0) + final_interest += score * weight + total_weight += weight + + if total_weight > 0: + final_interest /= total_weight + + # 确保在合理范围内 + final_interest = max(0.0, min(1.0, final_interest)) + + # 缓存结果 + self.interest_cache[message_id] = (final_interest, time.time()) + + # 清理过期缓存 + self._cleanup_cache() + + # 更新平均计算时间 + calculation_time = time.time() - start_time + total_calculations = self.stats["total_calculations"] + self.stats["average_calculation_time"] = ( + (self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time) + / total_calculations + ) + + logger.info(f"消息 {message_id} 最终兴趣度: {final_interest:.3f} (耗时: {calculation_time:.3f}s)") + return final_interest + + def update_topic_interest(self, message: Dict[str, Any], interest_value: float) -> None: + """更新话题兴趣度""" + topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE) + if isinstance(topic_calc, TopicInterestCalculator): + # 提取关键词作为话题 + keywords = message.get("key_words", "[]") + try: + import json + keyword_list: List[str] = json.loads(keywords) if keywords else [] + for keyword in keyword_list[:3]: # 更新前3个关键词 + topic_calc.update_topic_interest(keyword, interest_value) + except (json.JSONDecodeError, TypeError): + pass + + def add_user_interaction(self, user_id: str, interaction_type: str, value: float) -> None: + """添加用户交互记录""" + interaction_calc = self.calculators.get(InterestSourceType.USER_INTERACTION) + if isinstance(interaction_calc, UserInteractionInterestCalculator): + interaction_calc.add_interaction(user_id, interaction_type, value) + + def get_topic_interests(self) -> Dict[str, float]: + """获取所有话题兴趣度""" + topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE) + if isinstance(topic_calc, TopicInterestCalculator): + return topic_calc.topic_interests.copy() + return {} + + def _cleanup_cache(self) -> None: + """清理过期缓存""" + current_time = time.time() + expired_keys = [ + message_id for message_id, (_, timestamp) in self.interest_cache.items() + if current_time - timestamp > self.cache_ttl + ] + + for key in expired_keys: + del self.interest_cache[key] + + if expired_keys: + logger.debug(f"清理了 {len(expired_keys)} 个过期兴趣度缓存") + + def get_statistics(self) -> Dict[str, Any]: + """获取统计信息""" + return { + "cache_size": len(self.interest_cache), + "topic_count": len(self.get_topic_interests()), + "calculators": list(self.calculators.keys()), + "performance_stats": self.stats.copy(), + } + + def add_calculator(self, source_type: InterestSourceType, calculator: InterestCalculator) -> None: + """添加自定义计算器""" + self.calculators[source_type] = calculator + logger.info(f"添加计算器: {source_type.value}") + + def remove_calculator(self, source_type: InterestSourceType) -> None: + """移除计算器""" + if source_type in self.calculators: + del self.calculators[source_type] + logger.info(f"移除计算器: {source_type.value}") + + def set_source_weight(self, source_type: InterestSourceType, weight: float) -> None: + """设置来源权重""" + self.source_weights[source_type] = max(0.0, min(1.0, weight)) + logger.info(f"设置 {source_type.value} 权重: {weight}") + + def clear_cache(self) -> None: + """清空缓存""" + self.interest_cache.clear() + logger.info("清空兴趣度缓存") + + def get_cache_hit_rate(self) -> float: + """获取缓存命中率""" + total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0) + if total_requests == 0: + return 0.0 + return self.stats["cache_hits"] / total_requests + + +# 全局兴趣度管理器实例 +interest_manager = InterestManager() \ No newline at end of file diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index f909f720a..2f623fbd0 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -1,14 +1,26 @@ """ -消息管理模块 -管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息 +消息管理器模块 +提供统一的消息管理、上下文管理和分发调度功能 """ from .message_manager import MessageManager, message_manager -from src.common.data_models.message_manager_data_model import ( - StreamContext, - MessageStatus, - MessageManagerStats, - StreamStats, +from .context_manager import StreamContextManager, context_manager +from .distribution_manager import ( + DistributionManager, + DistributionPriority, + DistributionTask, + StreamDistributionState, + distribution_manager ) -__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"] +__all__ = [ + "MessageManager", + "message_manager", + "StreamContextManager", + "context_manager", + "DistributionManager", + "DistributionPriority", + "DistributionTask", + "StreamDistributionState", + "distribution_manager" +] \ No newline at end of file diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py new file mode 100644 index 000000000..80e65a33a --- /dev/null +++ b/src/chat/message_manager/context_manager.py @@ -0,0 +1,1072 @@ +""" +重构后的聊天上下文管理器 +提供统一、稳定的聊天上下文管理功能 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Any, Callable, Union, Tuple +from dataclasses import dataclass, field +from enum import Enum +from abc import ABC, abstractmethod + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.interest_system import interest_manager +from src.chat.energy_system import energy_manager +from . import distribution_manager + +logger = get_logger("context_manager") + + +class ContextEventType(Enum): + """上下文事件类型""" + MESSAGE_ADDED = "message_added" + MESSAGE_UPDATED = "message_updated" + ENERGY_CHANGED = "energy_changed" + STREAM_ACTIVATED = "stream_activated" + STREAM_DEACTIVATED = "stream_deactivated" + CONTEXT_CLEARED = "context_cleared" + VALIDATION_FAILED = "validation_failed" + CLEANUP_COMPLETED = "cleanup_completed" + INTEGRITY_CHECK = "integrity_check" + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"ContextEventType.{self.name}" + + +@dataclass +class ContextEvent: + """上下文事件""" + event_type: ContextEventType + stream_id: str + data: Dict[str, Any] = field(default_factory=dict) + timestamp: float = field(default_factory=time.time) + event_id: str = field(default_factory=lambda: f"event_{time.time()}_{id(object())}") + priority: int = 0 # 事件优先级,数字越大优先级越高 + source: str = "system" # 事件来源 + + def __str__(self) -> str: + return f"ContextEvent({self.event_type}, {self.stream_id}, ts={self.timestamp:.3f})" + + def __repr__(self) -> str: + return f"ContextEvent(event_type={self.event_type}, stream_id={self.stream_id}, timestamp={self.timestamp}, event_id={self.event_id})" + + def get_age(self) -> float: + """获取事件年龄(秒)""" + return time.time() - self.timestamp + + def is_expired(self, max_age: float = 3600.0) -> bool: + """检查事件是否已过期 + + Args: + max_age: 最大年龄(秒) + + Returns: + bool: 是否已过期 + """ + return self.get_age() > max_age + + +class ContextValidator(ABC): + """上下文验证器抽象基类""" + + @abstractmethod + def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]: + """验证上下文 + + Args: + stream_id: 流ID + context: 上下文对象 + + Returns: + Tuple[bool, Optional[str]]: (是否有效, 错误信息) + """ + pass + + +class DefaultContextValidator(ContextValidator): + """默认上下文验证器""" + + def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]: + """验证上下文基本完整性""" + if not hasattr(context, 'stream_id'): + return False, "缺少 stream_id 属性" + if not hasattr(context, 'unread_messages'): + return False, "缺少 unread_messages 属性" + if not hasattr(context, 'history_messages'): + return False, "缺少 history_messages 属性" + return True, None + + +class StreamContextManager: + """流上下文管理器 - 统一管理所有聊天流上下文""" + + def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None): + # 上下文存储 + self.stream_contexts: Dict[str, Any] = {} + self.context_metadata: Dict[str, Dict[str, Any]] = {} + + # 事件监听器 + self.event_listeners: Dict[ContextEventType, List[Callable]] = {} + self.event_history: List[ContextEvent] = [] + self.max_event_history = 1000 + + # 验证器 + self.validators: List[ContextValidator] = [DefaultContextValidator()] + + # 统计信息 + self.stats: Dict[str, Union[int, float, str, Dict]] = { + "total_messages": 0, + "total_streams": 0, + "active_streams": 0, + "inactive_streams": 0, + "last_activity": time.time(), + "creation_time": time.time(), + "validation_stats": { + "total_validations": 0, + "validation_failures": 0, + "last_validation_time": 0.0, + }, + "event_stats": { + "total_events": 0, + "events_by_type": {}, + "last_event_time": 0.0, + }, + } + + # 配置参数 + self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) + self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时 + self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时 + self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True) + self.enable_validation = getattr(global_config.chat, "enable_context_validation", True) + + # 清理任务 + self.cleanup_task: Optional[Any] = None + self.is_running = False + + logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)") + + def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool: + """添加流上下文 + + Args: + stream_id: 流ID + context: 上下文对象 + metadata: 上下文元数据 + + Returns: + bool: 是否成功添加 + """ + if stream_id in self.stream_contexts: + logger.warning(f"流上下文已存在: {stream_id}") + return False + + # 验证上下文 + if self.enable_validation: + is_valid, error_msg = self._validate_context(stream_id, context) + if not is_valid: + logger.error(f"上下文验证失败: {stream_id} - {error_msg}") + self._emit_event(ContextEventType.VALIDATION_FAILED, stream_id, { + "error": error_msg, + "context_type": type(context).__name__ + }) + return False + + # 添加上下文 + self.stream_contexts[stream_id] = context + + # 初始化元数据 + self.context_metadata[stream_id] = { + "created_time": time.time(), + "last_access_time": time.time(), + "access_count": 0, + "validation_errors": 0, + "last_validation_time": 0.0, + "custom_metadata": metadata or {}, + } + + # 更新统计 + self.stats["total_streams"] += 1 + self.stats["active_streams"] += 1 + self.stats["last_activity"] = time.time() + + # 触发事件 + self._emit_event(ContextEventType.STREAM_ACTIVATED, stream_id, { + "context": context, + "context_type": type(context).__name__, + "metadata": metadata + }) + + logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})") + return True + + def remove_stream_context(self, stream_id: str) -> bool: + """移除流上下文 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功移除 + """ + if stream_id in self.stream_contexts: + context = self.stream_contexts[stream_id] + metadata = self.context_metadata.get(stream_id, {}) + + del self.stream_contexts[stream_id] + if stream_id in self.context_metadata: + del self.context_metadata[stream_id] + + self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1) + self.stats["inactive_streams"] += 1 + self.stats["last_activity"] = time.time() + + # 触发事件 + self._emit_event(ContextEventType.STREAM_DEACTIVATED, stream_id, { + "context": context, + "context_type": type(context).__name__, + "metadata": metadata, + "uptime": time.time() - metadata.get("created_time", time.time()) + }) + + logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})") + return True + return False + + def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[Any]: + """获取流上下文 + + Args: + stream_id: 流ID + update_access: 是否更新访问统计 + + Returns: + Optional[Any]: 上下文对象 + """ + context = self.stream_contexts.get(stream_id) + if context and update_access: + # 更新访问统计 + if stream_id in self.context_metadata: + metadata = self.context_metadata[stream_id] + metadata["last_access_time"] = time.time() + metadata["access_count"] = metadata.get("access_count", 0) + 1 + return context + + def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取上下文元数据 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 元数据 + """ + return self.context_metadata.get(stream_id) + + def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool: + """更新上下文元数据 + + Args: + stream_id: 流ID + updates: 更新的元数据 + + Returns: + bool: 是否成功更新 + """ + if stream_id not in self.context_metadata: + return False + + self.context_metadata[stream_id].update(updates) + return True + + def add_message_to_context(self, stream_id: str, message: Any, skip_energy_update: bool = False) -> bool: + """添加消息到上下文 + + Args: + stream_id: 流ID + message: 消息对象 + skip_energy_update: 是否跳过能量更新 + + Returns: + bool: 是否成功添加 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + # 添加消息到上下文 + if hasattr(context, 'add_message'): + context.add_message(message) + else: + logger.error(f"上下文对象缺少 add_message 方法: {stream_id}") + return False + + # 计算消息兴趣度 + interest_value = self._calculate_message_interest(message) + if hasattr(message, 'interest_value'): + message.interest_value = interest_value + + # 更新统计 + self.stats["total_messages"] += 1 + self.stats["last_activity"] = time.time() + + # 触发事件 + event_data = { + "message": message, + "interest_value": interest_value, + "message_type": type(message).__name__, + "message_id": getattr(message, "message_id", None), + } + self._emit_event(ContextEventType.MESSAGE_ADDED, stream_id, event_data) + + # 更新能量和分发 + if not skip_energy_update: + self._update_stream_energy(stream_id) + distribution_manager.add_stream_message(stream_id, 1) + + logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})") + return True + + except Exception as e: + logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True) + return False + + def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool: + """更新上下文中的消息 + + Args: + stream_id: 流ID + message_id: 消息ID + updates: 更新的属性 + + Returns: + bool: 是否成功更新 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + # 更新消息信息 + if hasattr(context, 'update_message_info'): + context.update_message_info(message_id, **updates) + else: + logger.error(f"上下文对象缺少 update_message_info 方法: {stream_id}") + return False + + # 触发事件 + self._emit_event(ContextEventType.MESSAGE_UPDATED, stream_id, { + "message_id": message_id, + "updates": updates, + "update_time": time.time(), + }) + + # 如果更新了兴趣度,重新计算能量 + if "interest_value" in updates: + self._update_stream_energy(stream_id) + + logger.debug(f"更新上下文消息: {stream_id}/{message_id}") + return True + + except Exception as e: + logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True) + return False + + def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[Any]: + """获取上下文消息 + + Args: + stream_id: 流ID + limit: 消息数量限制 + include_unread: 是否包含未读消息 + + Returns: + List[Any]: 消息列表 + """ + context = self.get_stream_context(stream_id) + if not context: + return [] + + try: + messages = [] + if include_unread and hasattr(context, 'get_unread_messages'): + messages.extend(context.get_unread_messages()) + + if hasattr(context, 'get_history_messages'): + if limit: + messages.extend(context.get_history_messages(limit=limit)) + else: + messages.extend(context.get_history_messages()) + + # 按时间排序 + messages.sort(key=lambda msg: getattr(msg, 'time', 0)) + + # 应用限制 + if limit and len(messages) > limit: + messages = messages[-limit:] + + return messages + + except Exception as e: + logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True) + return [] + + def get_unread_messages(self, stream_id: str) -> List[Any]: + """获取未读消息 + + Args: + stream_id: 流ID + + Returns: + List[Any]: 未读消息列表 + """ + context = self.get_stream_context(stream_id) + if not context: + return [] + + try: + if hasattr(context, 'get_unread_messages'): + return context.get_unread_messages() + else: + logger.warning(f"上下文对象缺少 get_unread_messages 方法: {stream_id}") + return [] + except Exception as e: + logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True) + return [] + + def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool: + """标记消息为已读 + + Args: + stream_id: 流ID + message_ids: 消息ID列表 + + Returns: + bool: 是否成功标记 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + if not hasattr(context, 'mark_message_as_read'): + logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}") + return False + + marked_count = 0 + for message_id in message_ids: + try: + context.mark_message_as_read(message_id) + marked_count += 1 + except Exception as e: + logger.warning(f"标记消息已读失败 {message_id}: {e}") + + logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)") + return marked_count > 0 + + except Exception as e: + logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True) + return False + + def clear_context(self, stream_id: str) -> bool: + """清空上下文 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功清空 + """ + context = self.get_stream_context(stream_id) + if not context: + logger.warning(f"流上下文不存在: {stream_id}") + return False + + try: + # 清空消息 + if hasattr(context, 'unread_messages'): + context.unread_messages.clear() + if hasattr(context, 'history_messages'): + context.history_messages.clear() + + # 重置状态 + reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time'] + for attr in reset_attrs: + if hasattr(context, attr): + if attr in ['interruption_count', 'afc_threshold_adjustment']: + setattr(context, attr, 0) + else: + setattr(context, attr, time.time()) + + # 触发事件 + self._emit_event(ContextEventType.CONTEXT_CLEARED, stream_id, { + "clear_time": time.time(), + "reset_attributes": reset_attrs, + }) + + # 重新计算能量 + self._update_stream_energy(stream_id) + + logger.info(f"清空上下文: {stream_id}") + return True + + except Exception as e: + logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True) + return False + + def _calculate_message_interest(self, message: Any) -> float: + """计算消息兴趣度""" + try: + # 将消息转换为字典格式 + message_dict = self._message_to_dict(message) + + # 使用兴趣度管理器计算 + context = { + "stream_id": getattr(message, 'chat_info_stream_id', ''), + "user_id": getattr(message, 'user_id', ''), + } + + interest_value = interest_manager.calculate_message_interest(message_dict, context) + + # 更新话题兴趣度 + interest_manager.update_topic_interest(message_dict, interest_value) + + return interest_value + + except Exception as e: + logger.error(f"计算消息兴趣度失败: {e}") + return 0.5 + + def _message_to_dict(self, message: Any) -> Dict[str, Any]: + """将消息对象转换为字典""" + try: + return { + "message_id": getattr(message, "message_id", ""), + "processed_plain_text": getattr(message, "processed_plain_text", ""), + "is_emoji": getattr(message, "is_emoji", False), + "is_picid": getattr(message, "is_picid", False), + "is_mentioned": getattr(message, "is_mentioned", False), + "is_command": getattr(message, "is_command", False), + "key_words": getattr(message, "key_words", "[]"), + "user_id": getattr(message, "user_id", ""), + "time": getattr(message, "time", time.time()), + } + except Exception as e: + logger.error(f"转换消息为字典失败: {e}") + return {} + + def _update_stream_energy(self, stream_id: str): + """更新流能量""" + try: + # 获取所有消息 + all_messages = self.get_context_messages(stream_id, self.max_context_size) + unread_messages = self.get_unread_messages(stream_id) + combined_messages = all_messages + unread_messages + + # 获取用户ID + user_id = None + if combined_messages: + last_message = combined_messages[-1] + user_id = getattr(last_message, "user_id", None) + + # 计算能量 + energy = energy_manager.calculate_focus_energy( + stream_id=stream_id, + messages=combined_messages, + user_id=user_id + ) + + # 更新分发管理器 + distribution_manager.update_stream_energy(stream_id, energy) + + # 触发事件 + self._emit_event(ContextEventType.ENERGY_CHANGED, stream_id, { + "energy": energy, + "message_count": len(combined_messages), + }) + + except Exception as e: + logger.error(f"更新流能量失败 {stream_id}: {e}") + + def add_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool: + """添加事件监听器 + + Args: + event_type: 事件类型 + listener: 监听器函数 + + Returns: + bool: 是否成功添加 + """ + if not callable(listener): + logger.error(f"监听器必须是可调用对象: {type(listener)}") + return False + + if event_type not in self.event_listeners: + self.event_listeners[event_type] = [] + + if listener not in self.event_listeners[event_type]: + self.event_listeners[event_type].append(listener) + logger.debug(f"添加事件监听器: {event_type} -> {getattr(listener, '__name__', 'anonymous')}") + return True + return False + + def remove_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool: + """移除事件监听器 + + Args: + event_type: 事件类型 + listener: 监听器函数 + + Returns: + bool: 是否成功移除 + """ + if event_type in self.event_listeners: + try: + self.event_listeners[event_type].remove(listener) + logger.debug(f"移除事件监听器: {event_type}") + return True + except ValueError: + pass + return False + + def _emit_event(self, event_type: ContextEventType, stream_id: str, data: Optional[Dict] = None, priority: int = 0) -> None: + """触发事件 + + Args: + event_type: 事件类型 + stream_id: 流ID + data: 事件数据 + priority: 事件优先级 + """ + if data is None: + data = {} + + event = ContextEvent(event_type, stream_id, data, priority=priority) + + # 添加到事件历史 + self.event_history.append(event) + if len(self.event_history) > self.max_event_history: + self.event_history = self.event_history[-self.max_event_history:] + + # 更新事件统计 + event_stats = self.stats["event_stats"] + event_stats["total_events"] += 1 + event_stats["last_event_time"] = time.time() + event_type_str = str(event_type) + event_stats["events_by_type"][event_type_str] = event_stats["events_by_type"].get(event_type_str, 0) + 1 + + # 通知监听器 + if event_type in self.event_listeners: + for listener in self.event_listeners[event_type]: + try: + listener(event) + except Exception as e: + logger.error(f"事件监听器执行失败: {e}", exc_info=True) + + def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取流统计信息 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 统计信息 + """ + context = self.get_stream_context(stream_id, update_access=False) + if not context: + return None + + try: + metadata = self.context_metadata.get(stream_id, {}) + current_time = time.time() + created_time = metadata.get("created_time", current_time) + last_access_time = metadata.get("last_access_time", current_time) + access_count = metadata.get("access_count", 0) + + unread_messages = getattr(context, "unread_messages", []) + history_messages = getattr(context, "history_messages", []) + + return { + "stream_id": stream_id, + "context_type": type(context).__name__, + "total_messages": len(history_messages) + len(unread_messages), + "unread_messages": len(unread_messages), + "history_messages": len(history_messages), + "is_active": getattr(context, "is_active", True), + "last_check_time": getattr(context, "last_check_time", current_time), + "interruption_count": getattr(context, "interruption_count", 0), + "afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0), + "created_time": created_time, + "last_access_time": last_access_time, + "access_count": access_count, + "uptime_seconds": current_time - created_time, + "idle_seconds": current_time - last_access_time, + "validation_errors": metadata.get("validation_errors", 0), + } + except Exception as e: + logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True) + return None + + def get_manager_statistics(self) -> Dict[str, Any]: + """获取管理器统计信息 + + Returns: + Dict[str, Any]: 管理器统计信息 + """ + current_time = time.time() + uptime = current_time - self.stats.get("creation_time", current_time) + + # 计算验证统计 + validation_stats = self.stats["validation_stats"] + validation_success_rate = ( + (validation_stats.get("total_validations", 0) - validation_stats.get("validation_failures", 0)) / + max(1, validation_stats.get("total_validations", 1)) + ) + + # 计算事件统计 + event_stats = self.stats["event_stats"] + events_by_type = event_stats.get("events_by_type", {}) + + return { + **self.stats, + "uptime_hours": uptime / 3600, + "stream_count": len(self.stream_contexts), + "metadata_count": len(self.context_metadata), + "event_history_size": len(self.event_history), + "validators_count": len(self.validators), + "event_listeners": { + str(event_type): len(listeners) + for event_type, listeners in self.event_listeners.items() + }, + "validation_success_rate": validation_success_rate, + "event_distribution": events_by_type, + "max_event_history": self.max_event_history, + "auto_cleanup_enabled": self.auto_cleanup, + "cleanup_interval": self.cleanup_interval, + } + + def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int: + """清理不活跃的上下文 + + Args: + max_inactive_hours: 最大不活跃小时数 + + Returns: + int: 清理的上下文数量 + """ + current_time = time.time() + max_inactive_seconds = max_inactive_hours * 3600 + + inactive_streams = [] + for stream_id, context in self.stream_contexts.items(): + try: + # 获取最后活动时间 + metadata = self.context_metadata.get(stream_id, {}) + last_activity = metadata.get("last_access_time", metadata.get("created_time", 0)) + context_last_activity = getattr(context, "last_check_time", 0) + actual_last_activity = max(last_activity, context_last_activity) + + # 检查是否不活跃 + unread_count = len(getattr(context, "unread_messages", [])) + history_count = len(getattr(context, "history_messages", [])) + total_messages = unread_count + history_count + + if (current_time - actual_last_activity > max_inactive_seconds and + total_messages == 0): + inactive_streams.append(stream_id) + except Exception as e: + logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}") + continue + + # 清理不活跃上下文 + cleaned_count = 0 + for stream_id in inactive_streams: + if self.remove_stream_context(stream_id): + cleaned_count += 1 + + if cleaned_count > 0: + logger.info(f"清理了 {cleaned_count} 个不活跃上下文") + + return cleaned_count + + def validate_context_integrity(self, stream_id: str) -> bool: + """验证上下文完整性 + + Args: + stream_id: 流ID + + Returns: + bool: 是否完整 + """ + context = self.get_stream_context(stream_id) + if not context: + return False + + try: + # 检查基本属性 + required_attrs = ["stream_id", "unread_messages", "history_messages"] + for attr in required_attrs: + if not hasattr(context, attr): + logger.warning(f"上下文缺少必要属性: {attr}") + return False + + # 检查消息ID唯一性 + all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", []) + message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")] + if len(message_ids) != len(set(message_ids)): + logger.warning(f"上下文中存在重复消息ID: {stream_id}") + return False + + return True + + except Exception as e: + logger.error(f"验证上下文完整性失败 {stream_id}: {e}") + return False + + def _validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]: + """验证上下文完整性 + + Args: + stream_id: 流ID + context: 上下文对象 + + Returns: + Tuple[bool, Optional[str]]: (是否有效, 错误信息) + """ + validation_stats = self.stats["validation_stats"] + validation_stats["total_validations"] += 1 + validation_stats["last_validation_time"] = time.time() + + for validator in self.validators: + try: + is_valid, error_msg = validator.validate_context(stream_id, context) + if not is_valid: + validation_stats["validation_failures"] += 1 + return False, error_msg + except Exception as e: + validation_stats["validation_failures"] += 1 + return False, f"验证器执行失败: {e}" + return True, None + + async def start_auto_cleanup(self, interval: Optional[float] = None) -> None: + """启动自动清理 + + Args: + interval: 清理间隔(秒) + """ + if not self.auto_cleanup: + logger.info("自动清理已禁用") + return + + if self.is_running: + logger.warning("自动清理已在运行") + return + + self.is_running = True + cleanup_interval = interval or self.cleanup_interval + logger.info(f"启动自动清理(间隔: {cleanup_interval}s)") + + import asyncio + self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval)) + + async def stop_auto_cleanup(self) -> None: + """停止自动清理""" + self.is_running = False + if self.cleanup_task and not self.cleanup_task.done(): + self.cleanup_task.cancel() + try: + await self.cleanup_task + except Exception: + pass + logger.info("自动清理已停止") + + async def _cleanup_loop(self, interval: float) -> None: + """清理循环 + + Args: + interval: 清理间隔 + """ + while self.is_running: + try: + await asyncio.sleep(interval) + self.cleanup_inactive_contexts() + self._cleanup_event_history() + self._cleanup_expired_contexts() + logger.debug("自动清理完成") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"清理循环出错: {e}", exc_info=True) + await asyncio.sleep(interval) + + def _cleanup_event_history(self) -> None: + """清理事件历史""" + max_age = 24 * 3600 # 24小时 + + # 清理过期事件 + self.event_history = [ + event for event in self.event_history + if not event.is_expired(max_age) + ] + + # 保持历史大小限制 + if len(self.event_history) > self.max_event_history: + self.event_history = self.event_history[-self.max_event_history:] + + def _cleanup_expired_contexts(self) -> None: + """清理过期上下文""" + current_time = time.time() + expired_contexts = [] + + for stream_id, metadata in self.context_metadata.items(): + created_time = metadata.get("created_time", current_time) + if current_time - created_time > self.context_ttl: + expired_contexts.append(stream_id) + + for stream_id in expired_contexts: + self.remove_stream_context(stream_id) + + if expired_contexts: + logger.info(f"清理了 {len(expired_contexts)} 个过期上下文") + + def get_event_history(self, limit: int = 100, event_type: Optional[ContextEventType] = None) -> List[ContextEvent]: + """获取事件历史 + + Args: + limit: 返回数量限制 + event_type: 过滤事件类型 + + Returns: + List[ContextEvent]: 事件列表 + """ + events = self.event_history + if event_type: + events = [event for event in events if event.event_type == event_type] + return events[-limit:] + + def get_active_streams(self) -> List[str]: + """获取活跃流列表 + + Returns: + List[str]: 活跃流ID列表 + """ + return list(self.stream_contexts.keys()) + + def get_context_summary(self) -> Dict[str, Any]: + """获取上下文摘要 + + Returns: + Dict[str, Any]: 上下文摘要信息 + """ + current_time = time.time() + uptime = current_time - self.stats.get("creation_time", current_time) + + # 计算平均访问次数 + total_access = sum(meta.get("access_count", 0) for meta in self.context_metadata.values()) + avg_access = total_access / max(1, len(self.context_metadata)) + + # 计算验证成功率 + validation_stats = self.stats["validation_stats"] + total_validations = validation_stats.get("total_validations", 0) + validation_success_rate = ( + (total_validations - validation_stats.get("validation_failures", 0)) / + max(1, total_validations) + ) if total_validations > 0 else 1.0 + + return { + "total_streams": len(self.stream_contexts), + "active_streams": len(self.stream_contexts), + "total_messages": self.stats.get("total_messages", 0), + "uptime_hours": uptime / 3600, + "average_access_count": avg_access, + "validation_success_rate": validation_success_rate, + "event_history_size": len(self.event_history), + "validators_count": len(self.validators), + "auto_cleanup_enabled": self.auto_cleanup, + "cleanup_interval": self.cleanup_interval, + "last_activity": self.stats.get("last_activity", 0), + } + + def force_validation(self, stream_id: str) -> Tuple[bool, Optional[str]]: + """强制验证上下文 + + Args: + stream_id: 流ID + + Returns: + Tuple[bool, Optional[str]]: (是否有效, 错误信息) + """ + context = self.get_stream_context(stream_id) + if not context: + return False, "上下文不存在" + + return self._validate_context(stream_id, context) + + def reset_statistics(self) -> None: + """重置统计信息""" + # 重置基本统计 + self.stats.update({ + "total_messages": 0, + "total_streams": len(self.stream_contexts), + "active_streams": len(self.stream_contexts), + "inactive_streams": 0, + "last_activity": time.time(), + "creation_time": time.time(), + }) + + # 重置验证统计 + self.stats["validation_stats"].update({ + "total_validations": 0, + "validation_failures": 0, + "last_validation_time": 0.0, + }) + + # 重置事件统计 + self.stats["event_stats"].update({ + "total_events": 0, + "events_by_type": {}, + "last_event_time": 0.0, + }) + + logger.info("上下文管理器统计信息已重置") + + def export_context_data(self, stream_id: str) -> Optional[Dict[str, Any]]: + """导出上下文数据 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 导出的数据 + """ + context = self.get_stream_context(stream_id, update_access=False) + if not context: + return None + + try: + return { + "stream_id": stream_id, + "context_type": type(context).__name__, + "metadata": self.context_metadata.get(stream_id, {}), + "statistics": self.get_stream_statistics(stream_id), + "export_time": time.time(), + "unread_message_count": len(getattr(context, "unread_messages", [])), + "history_message_count": len(getattr(context, "history_messages", [])), + } + except Exception as e: + logger.error(f"导出上下文数据失败 {stream_id}: {e}") + return None + + +# 全局上下文管理器实例 +context_manager = StreamContextManager() \ No newline at end of file diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py new file mode 100644 index 000000000..ab3579589 --- /dev/null +++ b/src/chat/message_manager/distribution_manager.py @@ -0,0 +1,1004 @@ +""" +重构后的动态消息分发管理器 +提供高效、智能的消息分发调度功能 +""" + +import asyncio +import time +from typing import Dict, List, Optional, Set, Any, Callable +from dataclasses import dataclass, field +from enum import Enum +from heapq import heappush, heappop +from abc import ABC, abstractmethod + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.energy_system import energy_manager + +logger = get_logger("distribution_manager") + + +class DistributionPriority(Enum): + """分发优先级""" + CRITICAL = 0 # 关键(立即处理) + HIGH = 1 # 高优先级 + NORMAL = 2 # 正常优先级 + LOW = 3 # 低优先级 + BACKGROUND = 4 # 后台优先级 + + def __lt__(self, other: 'DistributionPriority') -> bool: + """用于优先级比较""" + return self.value < other.value + + +@dataclass +class DistributionTask: + """分发任务""" + stream_id: str + priority: DistributionPriority + energy: float + message_count: int + created_time: float = field(default_factory=time.time) + retry_count: int = 0 + max_retries: int = 3 + task_id: str = field(default_factory=lambda: f"task_{time.time()}_{id(object())}") + metadata: Dict[str, Any] = field(default_factory=dict) + + def __lt__(self, other: 'DistributionTask') -> bool: + """用于优先队列排序""" + # 首先按优先级排序 + if self.priority.value != other.priority.value: + return self.priority.value < other.priority.value + + # 相同优先级按能量排序(能量高的优先) + if abs(self.energy - other.energy) > 0.01: + return self.energy > other.energy + + # 最后按创建时间排序(先创建的优先) + return self.created_time < other.created_time + + def can_retry(self) -> bool: + """检查是否可以重试""" + return self.retry_count < self.max_retries + + def get_retry_delay(self, base_delay: float = 5.0) -> float: + """获取重试延迟""" + return base_delay * (2 ** min(self.retry_count, 3)) + + +@dataclass +class StreamDistributionState: + """流分发状态""" + stream_id: str + energy: float + last_distribution_time: float + next_distribution_time: float + message_count: int + consecutive_failures: int = 0 + is_active: bool = True + total_distributions: int = 0 + total_failures: int = 0 + average_distribution_time: float = 0.0 + metadata: Dict[str, Any] = field(default_factory=dict) + + def should_distribute(self, current_time: float) -> bool: + """检查是否应该分发""" + return (self.is_active and + current_time >= self.next_distribution_time and + self.message_count > 0) + + def update_distribution_stats(self, distribution_time: float, success: bool) -> None: + """更新分发统计""" + if success: + self.total_distributions += 1 + self.consecutive_failures = 0 + else: + self.total_failures += 1 + self.consecutive_failures += 1 + + # 更新平均分发时间 + total_attempts = self.total_distributions + self.total_failures + if total_attempts > 0: + self.average_distribution_time = ( + (self.average_distribution_time * (total_attempts - 1) + distribution_time) + / total_attempts + ) + + +class DistributionExecutor(ABC): + """分发执行器抽象基类""" + + @abstractmethod + async def execute(self, stream_id: str, context: Dict[str, Any]) -> bool: + """执行分发 + + Args: + stream_id: 流ID + context: 分发上下文 + + Returns: + bool: 是否执行成功 + """ + pass + + @abstractmethod + def get_priority(self, stream_id: str) -> DistributionPriority: + """获取流优先级 + + Args: + stream_id: 流ID + + Returns: + DistributionPriority: 优先级 + """ + pass + + +class DistributionManager: + """分发管理器 - 统一管理消息分发调度""" + + def __init__(self, max_concurrent_tasks: Optional[int] = None, retry_delay: Optional[float] = None): + # 流状态管理 + self.stream_states: Dict[str, StreamDistributionState] = {} + + # 任务队列 + self.task_queue: List[DistributionTask] = [] + self.processing_tasks: Set[str] = set() # 正在处理的stream_id + self.completed_tasks: List[DistributionTask] = [] + self.failed_tasks: List[DistributionTask] = [] + + # 统计信息 + self.stats: Dict[str, Any] = { + "total_distributed": 0, + "total_failed": 0, + "avg_distribution_time": 0.0, + "current_queue_size": 0, + "total_created_tasks": 0, + "total_completed_tasks": 0, + "total_failed_tasks": 0, + "total_retry_attempts": 0, + "peak_queue_size": 0, + "start_time": time.time(), + "last_activity_time": time.time(), + } + + # 配置参数 + self.max_concurrent_tasks = ( + max_concurrent_tasks or + getattr(global_config.chat, "max_concurrent_distributions", 3) + ) + self.retry_delay = ( + retry_delay or + getattr(global_config.chat, "distribution_retry_delay", 5.0) + ) + self.max_queue_size = getattr(global_config.chat, "max_distribution_queue_size", 1000) + self.max_history_size = getattr(global_config.chat, "max_task_history_size", 100) + + # 分发执行器 + self.executor: Optional[DistributionExecutor] = None + self.executor_callbacks: Dict[str, Callable] = {} + + # 事件循环 + self.is_running = False + self.distribution_task: Optional[asyncio.Task] = None + self.cleanup_task: Optional[asyncio.Task] = None + + # 性能监控 + self.performance_metrics: Dict[str, List[float]] = { + "distribution_times": [], + "queue_sizes": [], + "processing_counts": [], + } + self.max_metrics_size = 1000 + + logger.info(f"分发管理器初始化完成 (并发: {self.max_concurrent_tasks}, 重试延迟: {self.retry_delay}s)") + + async def start(self, cleanup_interval: float = 3600.0) -> None: + """启动分发管理器 + + Args: + cleanup_interval: 清理间隔(秒) + """ + if self.is_running: + logger.warning("分发管理器已经在运行") + return + + self.is_running = True + self.distribution_task = asyncio.create_task(self._distribution_loop()) + self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval)) + + logger.info("分发管理器已启动") + + async def stop(self) -> None: + """停止分发管理器""" + if not self.is_running: + return + + self.is_running = False + + # 取消分发任务 + if self.distribution_task and not self.distribution_task.done(): + self.distribution_task.cancel() + try: + await self.distribution_task + except asyncio.CancelledError: + pass + + # 取消清理任务 + if self.cleanup_task and not self.cleanup_task.done(): + self.cleanup_task.cancel() + try: + await self.cleanup_task + except asyncio.CancelledError: + pass + + # 取消所有处理中的任务 + for stream_id in list(self.processing_tasks): + self._cancel_stream_processing(stream_id) + + logger.info("分发管理器已停止") + + def add_stream_message(self, stream_id: str, message_count: int = 1, + priority: Optional[DistributionPriority] = None) -> bool: + """添加流消息 + + Args: + stream_id: 流ID + message_count: 消息数量 + priority: 指定优先级(可选) + + Returns: + bool: 是否成功添加 + """ + current_time = time.time() + self.stats["last_activity_time"] = current_time + + # 检查队列大小限制 + if len(self.task_queue) >= self.max_queue_size: + logger.warning(f"分发队列已满,拒绝添加: {stream_id}") + return False + + # 获取或创建流状态 + if stream_id not in self.stream_states: + self.stream_states[stream_id] = StreamDistributionState( + stream_id=stream_id, + energy=0.5, # 默认能量 + last_distribution_time=current_time, + next_distribution_time=current_time, + message_count=0, + ) + + # 更新流状态 + state = self.stream_states[stream_id] + state.message_count += message_count + + # 计算优先级 + if priority is None: + priority = self._calculate_priority(state) + + # 创建分发任务 + task = DistributionTask( + stream_id=stream_id, + priority=priority, + energy=state.energy, + message_count=state.message_count, + ) + + # 添加到任务队列 + heappush(self.task_queue, task) + self.stats["current_queue_size"] = len(self.task_queue) + self.stats["peak_queue_size"] = max(self.stats["peak_queue_size"], len(self.task_queue)) + self.stats["total_created_tasks"] += 1 + + # 记录性能指标 + self._record_performance_metric("queue_sizes", len(self.task_queue)) + + logger.debug(f"添加分发任务: {stream_id} (优先级: {priority.name}, 消息数: {message_count})") + return True + + def update_stream_energy(self, stream_id: str, energy: float) -> None: + """更新流能量 + + Args: + stream_id: 流ID + energy: 新的能量值 + """ + if stream_id in self.stream_states: + self.stream_states[stream_id].energy = max(0.0, min(1.0, energy)) + + # 失效能量管理器缓存 + energy_manager.invalidate_cache(stream_id) + + logger.debug(f"更新流能量: {stream_id} = {energy:.3f}") + + def _calculate_priority(self, state: StreamDistributionState) -> DistributionPriority: + """计算分发优先级 + + Args: + state: 流状态 + + Returns: + DistributionPriority: 优先级 + """ + energy = state.energy + message_count = state.message_count + consecutive_failures = state.consecutive_failures + total_distributions = state.total_distributions + + # 使用执行器获取优先级(如果设置) + if self.executor: + try: + return self.executor.get_priority(state.stream_id) + except Exception as e: + logger.warning(f"获取执行器优先级失败: {e}") + + # 失败次数过多,降低优先级 + if consecutive_failures >= 3: + return DistributionPriority.BACKGROUND + + # 高分发次数降低优先级 + if total_distributions > 50 and message_count < 2: + return DistributionPriority.LOW + + # 基于能量和消息数计算优先级 + if energy >= 0.8 and message_count >= 3: + return DistributionPriority.CRITICAL + elif energy >= 0.6 or message_count >= 5: + return DistributionPriority.HIGH + elif energy >= 0.3 or message_count >= 2: + return DistributionPriority.NORMAL + else: + return DistributionPriority.LOW + + async def _distribution_loop(self): + """分发主循环""" + while self.is_running: + try: + # 处理任务队列 + await self._process_task_queue() + + # 更新统计信息 + self._update_statistics() + + # 记录性能指标 + self._record_performance_metric("processing_counts", len(self.processing_tasks)) + + # 动态调整循环间隔 + queue_size = len(self.task_queue) + processing_count = len(self.processing_tasks) + sleep_time = 0.05 if queue_size > 10 or processing_count > 0 else 0.2 + + # 短暂休眠 + await asyncio.sleep(sleep_time) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"分发循环出错: {e}", exc_info=True) + await asyncio.sleep(1.0) + + async def _process_task_queue(self): + """处理任务队列""" + current_time = time.time() + + # 检查是否有可用的处理槽位 + available_slots = self.max_concurrent_tasks - len(self.processing_tasks) + if available_slots <= 0: + return + + # 处理队列中的任务 + processed_count = 0 + while (self.task_queue and + processed_count < available_slots and + len(self.processing_tasks) < self.max_concurrent_tasks): + + task = heappop(self.task_queue) + self.stats["current_queue_size"] = len(self.task_queue) + + # 检查任务是否仍然有效 + if not self._is_task_valid(task, current_time): + self._handle_invalid_task(task) + continue + + # 开始处理任务 + await self._start_task_processing(task) + processed_count += 1 + + # 记录处理统计 + if processed_count > 0: + logger.debug(f"处理了 {processed_count} 个分发任务") + + def _is_task_valid(self, task: DistributionTask, current_time: float) -> bool: + """检查任务是否有效 + + Args: + task: 分发任务 + current_time: 当前时间 + + Returns: + bool: 任务是否有效 + """ + state = self.stream_states.get(task.stream_id) + if not state or not state.is_active: + return False + + # 检查任务是否已过期 + if current_time - task.created_time > 3600: # 1小时 + return False + + # 检查是否达到了分发时间 + return state.should_distribute(current_time) + + def _handle_invalid_task(self, task: DistributionTask) -> None: + """处理无效任务 + + Args: + task: 无效的任务 + """ + logger.debug(f"任务无效,丢弃: {task.stream_id} (创建时间: {task.created_time})") + # 可以添加到历史记录中用于分析 + if len(self.failed_tasks) < self.max_history_size: + self.failed_tasks.append(task) + + async def _start_task_processing(self, task: DistributionTask) -> None: + """开始处理任务 + + Args: + task: 分发任务 + """ + stream_id = task.stream_id + state = self.stream_states[stream_id] + current_time = time.time() + + # 标记为处理中 + self.processing_tasks.add(stream_id) + state.last_distribution_time = current_time + + # 计算下次分发时间 + interval = energy_manager.get_distribution_interval(state.energy) + state.next_distribution_time = current_time + interval + + # 记录开始处理 + logger.info(f"开始处理分发任务: {stream_id} " + f"(能量: {state.energy:.3f}, " + f"消息数: {state.message_count}, " + f"周期: {interval:.1f}s, " + f"重试次数: {task.retry_count})") + + # 创建处理任务 + asyncio.create_task(self._process_distribution_task(task)) + + async def _process_distribution_task(self, task: DistributionTask) -> None: + """处理分发任务 + + Args: + task: 分发任务 + """ + stream_id = task.stream_id + start_time = time.time() + + try: + # 调用外部处理函数 + success = await self._execute_distribution(stream_id) + + if success: + # 处理成功 + self._handle_task_success(task, start_time) + else: + # 处理失败 + await self._handle_task_failure(task) + + except Exception as e: + logger.error(f"处理分发任务失败 {stream_id}: {e}", exc_info=True) + await self._handle_task_failure(task) + + finally: + # 清理处理状态 + self.processing_tasks.discard(stream_id) + self.stats["last_activity_time"] = time.time() + + async def _execute_distribution(self, stream_id: str) -> bool: + """执行分发(需要外部实现) + + Args: + stream_id: 流ID + + Returns: + bool: 是否执行成功 + """ + # 使用执行器处理分发 + if self.executor: + try: + state = self.stream_states.get(stream_id) + context = { + "stream_id": stream_id, + "energy": state.energy if state else 0.5, + "message_count": state.message_count if state else 0, + "task_metadata": {}, + } + return await self.executor.execute(stream_id, context) + except Exception as e: + logger.error(f"执行器分发失败 {stream_id}: {e}") + return False + + # 回退到回调函数 + callback = self.executor_callbacks.get(stream_id) + if callback: + try: + result = callback(stream_id) + if asyncio.iscoroutine(result): + return await result + return bool(result) + except Exception as e: + logger.error(f"回调分发失败 {stream_id}: {e}") + return False + + # 默认处理 + logger.debug(f"执行分发: {stream_id}") + return True + + def _handle_task_success(self, task: DistributionTask, start_time: float) -> None: + """处理任务成功 + + Args: + task: 成功的任务 + start_time: 开始时间 + """ + stream_id = task.stream_id + state = self.stream_states.get(stream_id) + distribution_time = time.time() - start_time + + if state: + # 更新流状态 + state.update_distribution_stats(distribution_time, True) + state.message_count = 0 # 清空消息计数 + + # 更新全局统计 + self.stats["total_distributed"] += 1 + self.stats["total_completed_tasks"] += 1 + + # 更新平均分发时间 + if self.stats["total_distributed"] > 0: + self.stats["avg_distribution_time"] = ( + (self.stats["avg_distribution_time"] * (self.stats["total_distributed"] - 1) + distribution_time) + / self.stats["total_distributed"] + ) + + # 记录性能指标 + self._record_performance_metric("distribution_times", distribution_time) + + # 添加到成功任务历史 + if len(self.completed_tasks) < self.max_history_size: + self.completed_tasks.append(task) + + logger.info(f"分发任务成功: {stream_id} (耗时: {distribution_time:.2f}s, 重试: {task.retry_count})") + + async def _handle_task_failure(self, task: DistributionTask) -> None: + """处理任务失败 + + Args: + task: 失败的任务 + """ + stream_id = task.stream_id + state = self.stream_states.get(stream_id) + distribution_time = time.time() - task.created_time + + if state: + # 更新流状态 + state.update_distribution_stats(distribution_time, False) + + # 增加失败计数 + state.consecutive_failures += 1 + + # 计算重试延迟 + retry_delay = task.get_retry_delay(self.retry_delay) + task.retry_count += 1 + self.stats["total_retry_attempts"] += 1 + + # 如果还有重试机会,重新添加到队列 + if task.can_retry(): + # 等待重试延迟 + await asyncio.sleep(retry_delay) + + # 重新计算优先级(失败后降低优先级) + task.priority = DistributionPriority.LOW + + # 重新添加到队列 + heappush(self.task_queue, task) + self.stats["current_queue_size"] = len(self.task_queue) + + logger.warning(f"分发任务失败,准备重试: {stream_id} " + f"(重试次数: {task.retry_count}/{task.max_retries}, " + f"延迟: {retry_delay:.1f}s)") + else: + # 超过重试次数,标记为不活跃 + state.is_active = False + self.stats["total_failed"] += 1 + self.stats["total_failed_tasks"] += 1 + + # 添加到失败任务历史 + if len(self.failed_tasks) < self.max_history_size: + self.failed_tasks.append(task) + + logger.error(f"分发任务最终失败: {stream_id} (重试次数: {task.retry_count})") + + def _cancel_stream_processing(self, stream_id: str) -> None: + """取消流处理 + + Args: + stream_id: 流ID + """ + # 从处理集合中移除 + self.processing_tasks.discard(stream_id) + + # 更新流状态 + if stream_id in self.stream_states: + self.stream_states[stream_id].is_active = False + + logger.info(f"取消流处理: {stream_id}") + + def _update_statistics(self) -> None: + """更新统计信息""" + # 更新当前队列大小 + self.stats["current_queue_size"] = len(self.task_queue) + + # 更新运行时间 + if self.is_running: + self.stats["uptime"] = time.time() - self.stats["start_time"] + + # 更新性能统计 + self.stats["avg_queue_size"] = ( + sum(self.performance_metrics["queue_sizes"]) / + max(1, len(self.performance_metrics["queue_sizes"])) + ) + + self.stats["avg_processing_count"] = ( + sum(self.performance_metrics["processing_counts"]) / + max(1, len(self.performance_metrics["processing_counts"])) + ) + + def _record_performance_metric(self, metric_name: str, value: float) -> None: + """记录性能指标 + + Args: + metric_name: 指标名称 + value: 指标值 + """ + if metric_name in self.performance_metrics: + metrics = self.performance_metrics[metric_name] + metrics.append(value) + # 保持大小限制 + if len(metrics) > self.max_metrics_size: + metrics.pop(0) + + async def _cleanup_loop(self, interval: float) -> None: + """清理循环 + + Args: + interval: 清理间隔 + """ + while self.is_running: + try: + await asyncio.sleep(interval) + self._cleanup_expired_data() + logger.debug(f"清理完成,保留 {len(self.completed_tasks)} 个成功任务,{len(self.failed_tasks)} 个失败任务") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"清理循环出错: {e}") + + def _cleanup_expired_data(self) -> None: + """清理过期数据""" + current_time = time.time() + max_age = 24 * 3600 # 24小时 + + # 清理过期的成功任务 + self.completed_tasks = [ + task for task in self.completed_tasks + if current_time - task.created_time < max_age + ] + + # 清理过期的失败任务 + self.failed_tasks = [ + task for task in self.failed_tasks + if current_time - task.created_time < max_age + ] + + # 清理性能指标 + for metric_name in self.performance_metrics: + if len(self.performance_metrics[metric_name]) > self.max_metrics_size: + self.performance_metrics[metric_name] = ( + self.performance_metrics[metric_name][-self.max_metrics_size:] + ) + + def get_stream_status(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取流状态 + + Args: + stream_id: 流ID + + Returns: + Optional[Dict[str, Any]]: 流状态信息 + """ + if stream_id not in self.stream_states: + return None + + state = self.stream_states[stream_id] + current_time = time.time() + time_until_next = max(0, state.next_distribution_time - current_time) + + return { + "stream_id": state.stream_id, + "energy": state.energy, + "message_count": state.message_count, + "last_distribution_time": state.last_distribution_time, + "next_distribution_time": state.next_distribution_time, + "time_until_next_distribution": time_until_next, + "consecutive_failures": state.consecutive_failures, + "total_distributions": state.total_distributions, + "total_failures": state.total_failures, + "average_distribution_time": state.average_distribution_time, + "is_active": state.is_active, + "is_processing": stream_id in self.processing_tasks, + "uptime": current_time - state.last_distribution_time, + } + + def get_queue_status(self) -> Dict[str, Any]: + """获取队列状态 + + Returns: + Dict[str, Any]: 队列状态信息 + """ + current_time = time.time() + uptime = current_time - self.stats["start_time"] if self.is_running else 0 + + # 分析任务优先级分布 + priority_counts = {} + for task in self.task_queue: + priority_name = task.priority.name + priority_counts[priority_name] = priority_counts.get(priority_name, 0) + 1 + + return { + "queue_size": len(self.task_queue), + "processing_count": len(self.processing_tasks), + "max_concurrent": self.max_concurrent_tasks, + "max_queue_size": self.max_queue_size, + "is_running": self.is_running, + "uptime": uptime, + "priority_distribution": priority_counts, + "stats": self.stats.copy(), + "performance_metrics": { + name: { + "count": len(metrics), + "avg": sum(metrics) / max(1, len(metrics)), + "min": min(metrics) if metrics else 0, + "max": max(metrics) if metrics else 0, + } + for name, metrics in self.performance_metrics.items() + }, + } + + def deactivate_stream(self, stream_id: str) -> bool: + """停用流 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功停用 + """ + if stream_id in self.stream_states: + self.stream_states[stream_id].is_active = False + # 取消正在处理的任务 + if stream_id in self.processing_tasks: + self._cancel_stream_processing(stream_id) + logger.info(f"停用流: {stream_id}") + return True + return False + + def activate_stream(self, stream_id: str) -> bool: + """激活流 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功激活 + """ + if stream_id in self.stream_states: + self.stream_states[stream_id].is_active = True + self.stream_states[stream_id].consecutive_failures = 0 + self.stream_states[stream_id].next_distribution_time = time.time() + logger.info(f"激活流: {stream_id}") + return True + return False + + def cleanup_inactive_streams(self, max_inactive_hours: int = 24) -> int: + """清理不活跃的流 + + Args: + max_inactive_hours: 最大不活跃小时数 + + Returns: + int: 清理的流数量 + """ + current_time = time.time() + max_inactive_seconds = max_inactive_hours * 3600 + + inactive_streams = [] + for stream_id, state in self.stream_states.items(): + if (not state.is_active and + current_time - state.last_distribution_time > max_inactive_seconds and + state.message_count == 0): + inactive_streams.append(stream_id) + + for stream_id in inactive_streams: + del self.stream_states[stream_id] + # 同时清理处理中的任务 + self.processing_tasks.discard(stream_id) + logger.debug(f"清理不活跃流: {stream_id}") + + if inactive_streams: + logger.info(f"清理了 {len(inactive_streams)} 个不活跃流") + + return len(inactive_streams) + + def set_executor(self, executor: DistributionExecutor) -> None: + """设置分发执行器 + + Args: + executor: 分发执行器实例 + """ + self.executor = executor + logger.info(f"设置分发执行器: {executor.__class__.__name__}") + + def register_callback(self, stream_id: str, callback: Callable) -> None: + """注册分发回调 + + Args: + stream_id: 流ID + callback: 回调函数 + """ + self.executor_callbacks[stream_id] = callback + logger.debug(f"注册分发回调: {stream_id}") + + def unregister_callback(self, stream_id: str) -> bool: + """注销分发回调 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功注销 + """ + if stream_id in self.executor_callbacks: + del self.executor_callbacks[stream_id] + logger.debug(f"注销分发回调: {stream_id}") + return True + return False + + def get_task_history(self, limit: int = 50) -> Dict[str, List[Dict[str, Any]]]: + """获取任务历史 + + Args: + limit: 返回数量限制 + + Returns: + Dict[str, List[Dict[str, Any]]]: 任务历史 + """ + def task_to_dict(task: DistributionTask) -> Dict[str, Any]: + return { + "task_id": task.task_id, + "stream_id": task.stream_id, + "priority": task.priority.name, + "energy": task.energy, + "message_count": task.message_count, + "created_time": task.created_time, + "retry_count": task.retry_count, + "max_retries": task.max_retries, + "metadata": task.metadata, + } + + return { + "completed_tasks": [task_to_dict(task) for task in self.completed_tasks[-limit:]], + "failed_tasks": [task_to_dict(task) for task in self.failed_tasks[-limit:]], + } + + def get_performance_summary(self) -> Dict[str, Any]: + """获取性能摘要 + + Returns: + Dict[str, Any]: 性能摘要 + """ + current_time = time.time() + uptime = current_time - self.stats["start_time"] + + # 计算成功率 + total_attempts = self.stats["total_completed_tasks"] + self.stats["total_failed_tasks"] + success_rate = ( + self.stats["total_completed_tasks"] / max(1, total_attempts) + ) if total_attempts > 0 else 0.0 + + # 计算吞吐量 + throughput = ( + self.stats["total_completed_tasks"] / max(1, uptime / 3600) + ) # 每小时完成任务数 + + return { + "uptime_hours": uptime / 3600, + "success_rate": success_rate, + "throughput_per_hour": throughput, + "avg_distribution_time": self.stats["avg_distribution_time"], + "total_retry_attempts": self.stats["total_retry_attempts"], + "peak_queue_size": self.stats["peak_queue_size"], + "active_streams": len(self.stream_states), + "processing_tasks": len(self.processing_tasks), + } + + def reset_statistics(self) -> None: + """重置统计信息""" + self.stats.update({ + "total_distributed": 0, + "total_failed": 0, + "avg_distribution_time": 0.0, + "current_queue_size": len(self.task_queue), + "total_created_tasks": 0, + "total_completed_tasks": 0, + "total_failed_tasks": 0, + "total_retry_attempts": 0, + "peak_queue_size": 0, + "start_time": time.time(), + "last_activity_time": time.time(), + }) + + # 清空性能指标 + for metrics in self.performance_metrics.values(): + metrics.clear() + + logger.info("分发管理器统计信息已重置") + + def get_all_stream_states(self) -> Dict[str, Dict[str, Any]]: + """获取所有流状态 + + Returns: + Dict[str, Dict[str, Any]]: 所有流状态 + """ + return { + stream_id: self.get_stream_status(stream_id) + for stream_id in self.stream_states.keys() + } + + def force_process_stream(self, stream_id: str) -> bool: + """强制处理指定流 + + Args: + stream_id: 流ID + + Returns: + bool: 是否成功触发处理 + """ + if stream_id not in self.stream_states: + return False + + state = self.stream_states[stream_id] + if not state.is_active: + return False + + # 创建高优先级任务 + task = DistributionTask( + stream_id=stream_id, + priority=DistributionPriority.CRITICAL, + energy=state.energy, + message_count=state.message_count, + ) + + # 添加到队列 + heappush(self.task_queue, task) + self.stats["current_queue_size"] = len(self.task_queue) + + logger.info(f"强制处理流: {stream_id}") + return True + + +# 全局分发管理器实例 +distribution_manager = DistributionManager() \ No newline at end of file diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 8df7f4bca..c2b519392 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -18,6 +18,7 @@ from src.plugin_system.base.component_types import ChatMode from .sleep_manager.sleep_manager import SleepManager from .sleep_manager.wakeup_manager import WakeUpManager from src.config.config import global_config +from . import context_manager if TYPE_CHECKING: from src.common.data_models.message_manager_data_model import StreamContext @@ -29,7 +30,6 @@ class MessageManager: """消息管理器""" def __init__(self, check_interval: float = 5.0): - self.stream_contexts: Dict[str, StreamContext] = {} self.check_interval = check_interval # 检查间隔(秒) self.is_running = False self.manager_task: Optional[asyncio.Task] = None @@ -45,6 +45,9 @@ class MessageManager: self.sleep_manager = SleepManager() self.wakeup_manager = WakeUpManager(self.sleep_manager) + # 初始化上下文管理器 + self.context_manager = context_manager.context_manager + async def start(self): """启动消息管理器""" if self.is_running: @@ -54,6 +57,7 @@ class MessageManager: self.is_running = True self.manager_task = asyncio.create_task(self._manager_loop()) await self.wakeup_manager.start() + await self.context_manager.start() logger.info("消息管理器已启动") async def stop(self): @@ -64,48 +68,44 @@ class MessageManager: self.is_running = False # 停止所有流处理任务 - for context in self.stream_contexts.values(): - if context.processing_task and not context.processing_task.done(): - context.processing_task.cancel() - - # 停止管理器任务 + # 注意:context_manager 会自己清理任务 if self.manager_task and not self.manager_task.done(): self.manager_task.cancel() await self.wakeup_manager.stop() + await self.context_manager.stop() logger.info("消息管理器已停止") def add_message(self, stream_id: str, message: DatabaseMessages): """添加消息到指定聊天流""" - # 获取或创建流上下文 - if stream_id not in self.stream_contexts: - self.stream_contexts[stream_id] = StreamContext(stream_id=stream_id) - self.stats.total_streams += 1 + # 使用 context_manager 添加消息 + success = self.context_manager.add_message_to_context(stream_id, message) - context = self.stream_contexts[stream_id] - context.set_chat_mode(ChatMode.FOCUS) - context.add_message(message) - - logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") + if success: + logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}") + else: + logger.warning(f"添加消息到聊天流 {stream_id} 失败") def update_message_and_refresh_energy( self, stream_id: str, message_id: str, - interest_degree: float = None, + interest_value: float = None, actions: list = None, should_reply: bool = None, ): """更新消息信息""" - if stream_id in self.stream_contexts: - context = self.stream_contexts[stream_id] - context.update_message_info(message_id, interest_degree, actions, should_reply) + # 使用 context_manager 更新消息信息 + context = self.context_manager.get_stream_context(stream_id) + if context: + context.update_message_info(message_id, interest_value, actions, should_reply) def add_action_and_refresh_energy(self, stream_id: str, message_id: str, action: str): """添加动作到消息""" - if stream_id in self.stream_contexts: - context = self.stream_contexts[stream_id] + # 使用 context_manager 添加动作到消息 + context = self.context_manager.get_stream_context(stream_id) + if context: context.add_action_to_message(message_id, action) async def _manager_loop(self): @@ -136,19 +136,23 @@ class MessageManager: active_streams = 0 total_unread = 0 - for stream_id, context in self.stream_contexts.items(): - if not context.is_active: + # 使用 context_manager 获取活跃的流 + active_stream_ids = self.context_manager.get_active_streams() + + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context: continue active_streams += 1 # 检查是否有未读消息 - unread_messages = context.get_unread_messages() + unread_messages = self.context_manager.get_unread_messages(stream_id) if unread_messages: total_unread += len(unread_messages) # 如果没有处理任务,创建一个 - if not context.processing_task or context.processing_task.done(): + if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) # 更新统计 @@ -157,14 +161,13 @@ class MessageManager: async def _process_stream_messages(self, stream_id: str): """处理指定聊天流的消息""" - if stream_id not in self.stream_contexts: + context = self.context_manager.get_stream_context(stream_id) + if not context: return - context = self.stream_contexts[stream_id] - try: # 获取未读消息 - unread_messages = context.get_unread_messages() + unread_messages = self.context_manager.get_unread_messages(stream_id) if not unread_messages: return @@ -205,7 +208,7 @@ class MessageManager: # 处理结果,标记消息为已读 if results.get("success", False): - self._clear_all_unread_messages(context) + self._clear_all_unread_messages(stream_id) logger.debug(f"聊天流 {stream_id} 处理成功,清除了 {len(unread_messages)} 条未读消息") else: logger.warning(f"聊天流 {stream_id} 处理失败: {results.get('error_message', '未知错误')}") @@ -213,7 +216,7 @@ class MessageManager: except Exception as e: logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}") # 出现异常时也清除未读消息,避免重复处理 - self._clear_all_unread_messages(context) + self._clear_all_unread_messages(stream_id) raise logger.debug(f"聊天流 {stream_id} 消息处理完成") @@ -226,35 +229,36 @@ class MessageManager: def deactivate_stream(self, stream_id: str): """停用聊天流""" - if stream_id in self.stream_contexts: - context = self.stream_contexts[stream_id] + context = self.context_manager.get_stream_context(stream_id) + if context: context.is_active = False # 取消处理任务 - if context.processing_task and not context.processing_task.done(): + if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done(): context.processing_task.cancel() logger.info(f"停用聊天流: {stream_id}") def activate_stream(self, stream_id: str): """激活聊天流""" - if stream_id in self.stream_contexts: - self.stream_contexts[stream_id].is_active = True + context = self.context_manager.get_stream_context(stream_id) + if context: + context.is_active = True logger.info(f"激活聊天流: {stream_id}") def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]: """获取聊天流统计""" - if stream_id not in self.stream_contexts: + context = self.context_manager.get_stream_context(stream_id) + if not context: return None - context = self.stream_contexts[stream_id] return StreamStats( stream_id=stream_id, is_active=context.is_active, - unread_count=len(context.get_unread_messages()), + unread_count=len(self.context_manager.get_unread_messages(stream_id)), history_count=len(context.history_messages), last_check_time=context.last_check_time, - has_active_task=bool(context.processing_task and not context.processing_task.done()), + has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()), ) def get_manager_stats(self) -> Dict[str, Any]: @@ -270,18 +274,9 @@ class MessageManager: def cleanup_inactive_streams(self, max_inactive_hours: int = 24): """清理不活跃的聊天流""" - current_time = time.time() - max_inactive_seconds = max_inactive_hours * 3600 - - inactive_streams = [] - for stream_id, context in self.stream_contexts.items(): - if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages(): - inactive_streams.append(stream_id) - - for stream_id in inactive_streams: - self.deactivate_stream(stream_id) - del self.stream_contexts[stream_id] - logger.info(f"清理不活跃聊天流: {stream_id}") + # 使用 context_manager 的自动清理功能 + self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600) + logger.info("已启动不活跃聊天流清理") async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str): """检查并处理消息打断""" @@ -330,90 +325,29 @@ class MessageManager: logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}") def _calculate_stream_distribution_interval(self, context: StreamContext) -> float: - """计算单个聊天流的分发周期 - 基于阈值感知的focus_energy""" + """计算单个聊天流的分发周期 - 使用重构后的能量管理器""" if not global_config.chat.dynamic_distribution_enabled: return self.check_interval # 使用固定间隔 - from src.plugin_system.apis.chat_api import get_chat_manager + try: + from src.chat.energy_system import energy_manager + from src.plugin_system.apis.chat_api import get_chat_manager - chat_stream = get_chat_manager().get_stream(context.stream_id) - # 获取该流的focus_energy(新的阈值感知版本) - focus_energy = 0.5 # 默认值 - avg_message_interest = 0.5 # 默认平均兴趣度 + # 获取聊天流和能量 + chat_stream = get_chat_manager().get_stream(context.stream_id) + if chat_stream: + focus_energy = chat_stream.focus_energy + # 使用能量管理器获取分发周期 + interval = energy_manager.get_distribution_interval(focus_energy) + logger.debug(f"流 {context.stream_id} 分发周期: {interval:.2f}s (能量: {focus_energy:.3f})") + return interval + else: + # 默认间隔 + return self.check_interval - if chat_stream: - focus_energy = chat_stream.focus_energy - # 获取平均消息兴趣度用于更精确的计算 - 从StreamContext获取 - history_messages = context.get_history_messages(limit=100) - unread_messages = context.get_unread_messages() - all_messages = history_messages + unread_messages - - if all_messages: - message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")] - avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.5 - - # 获取AFC阈值用于参考,添加None值检查 - reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) - non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) - high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) - - # 使用配置参数 - base_interval = global_config.chat.dynamic_distribution_base_interval - min_interval = global_config.chat.dynamic_distribution_min_interval - max_interval = global_config.chat.dynamic_distribution_max_interval - jitter_factor = global_config.chat.dynamic_distribution_jitter_factor - - # 基于阈值感知的智能分发周期计算 - if avg_message_interest >= high_match_threshold: - # 超高兴趣度:极快响应 (1-2秒) - interval_multiplier = 0.3 + (focus_energy - 0.7) * 2.0 - elif avg_message_interest >= reply_threshold: - # 高兴趣度:快速响应 (2-6秒) - gap_from_reply = (avg_message_interest - reply_threshold) / (high_match_threshold - reply_threshold) - interval_multiplier = 0.6 + gap_from_reply * 0.4 - elif avg_message_interest >= non_reply_threshold: - # 中等兴趣度:正常响应 (6-15秒) - gap_from_non_reply = (avg_message_interest - non_reply_threshold) / (reply_threshold - non_reply_threshold) - interval_multiplier = 1.2 + gap_from_non_reply * 1.8 - else: - # 低兴趣度:缓慢响应 (15-30秒) - gap_ratio = max(0, avg_message_interest / non_reply_threshold) - interval_multiplier = 3.0 + (1.0 - gap_ratio) * 3.0 - - # 应用focus_energy微调 - energy_adjustment = 1.0 + (focus_energy - 0.5) * 0.5 - interval = base_interval * interval_multiplier * energy_adjustment - - # 添加随机扰动避免同步 - import random - - jitter = random.uniform(1.0 - jitter_factor, 1.0 + jitter_factor) - final_interval = interval * jitter - - # 限制在合理范围内 - final_interval = max(min_interval, min(max_interval, final_interval)) - - # 根据兴趣度级别调整日志级别 - if avg_message_interest >= high_match_threshold: - log_level = "info" - elif avg_message_interest >= reply_threshold: - log_level = "info" - else: - log_level = "debug" - - log_msg = ( - f"流 {context.stream_id} 分发周期: {final_interval:.2f}s | " - f"focus_energy: {focus_energy:.3f} | " - f"avg_interest: {avg_message_interest:.3f} | " - f"阈值参考: {non_reply_threshold:.2f}/{reply_threshold:.2f}/{high_match_threshold:.2f}" - ) - - if log_level == "info": - logger.info(log_msg) - else: - logger.debug(log_msg) - - return final_interval + except Exception as e: + logger.error(f"计算分发周期失败: {e}") + return self.check_interval def _calculate_next_manager_delay(self) -> float: """计算管理器下次检查的延迟时间""" @@ -421,8 +355,10 @@ class MessageManager: min_delay = float("inf") # 找到最近需要检查的流 - for context in self.stream_contexts.values(): - if not context.is_active: + active_stream_ids = self.context_manager.get_active_streams() + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context or not context.is_active: continue time_until_check = context.next_check_time - current_time @@ -444,8 +380,12 @@ class MessageManager: current_time = time.time() processed_streams = 0 - for stream_id, context in self.stream_contexts.items(): - if not context.is_active: + # 使用 context_manager 获取活跃的流 + active_stream_ids = self.context_manager.get_active_streams() + + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context or not context.is_active: continue # 检查是否达到检查时间 @@ -463,7 +403,7 @@ class MessageManager: context.next_check_time = current_time + context.distribution_interval # 检查未读消息 - unread_messages = context.get_unread_messages() + unread_messages = self.context_manager.get_unread_messages(stream_id) if unread_messages: processed_streams += 1 self.stats.total_unread_messages = len(unread_messages) @@ -493,7 +433,7 @@ class MessageManager: context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) # 更新活跃流计数 - active_count = sum(1 for ctx in self.stream_contexts.values() if ctx.is_active) + active_count = len(self.context_manager.get_active_streams()) self.stats.active_streams = active_count if processed_streams > 0: @@ -501,13 +441,16 @@ class MessageManager: async def _check_all_streams_with_priority(self): """按优先级检查所有聊天流,高focus_energy的流优先处理""" - if not self.stream_contexts: + if not self.context_manager.get_active_streams(): return # 获取活跃的聊天流并按focus_energy排序 active_streams = [] - for stream_id, context in self.stream_contexts.items(): - if not context.is_active: + active_stream_ids = self.context_manager.get_active_streams() + + for stream_id in active_stream_ids: + context = self.context_manager.get_stream_context(stream_id) + if not context or not context.is_active: continue # 获取focus_energy,如果不存在则使用默认值 @@ -533,12 +476,12 @@ class MessageManager: active_stream_count += 1 # 检查是否有未读消息 - unread_messages = context.get_unread_messages() + unread_messages = self.context_manager.get_unread_messages(stream_id) if unread_messages: total_unread += len(unread_messages) # 如果没有处理任务,创建一个 - if not context.processing_task or context.processing_task.done(): + if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done(): context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) # 高优先级流的额外日志 @@ -554,63 +497,40 @@ class MessageManager: self.stats.total_unread_messages = total_unread def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float: - """计算聊天流的优先级分数""" - from src.plugin_system.apis.chat_api import get_chat_manager - - chat_stream = get_chat_manager().get_stream(context.stream_id) - # 基础优先级:focus_energy + """计算聊天流的优先级分数 - 简化版本,主要使用focus_energy""" + # 使用重构后的能量管理器,主要依赖focus_energy base_priority = focus_energy - # 未读消息数量加权 + # 简单的未读消息加权 unread_count = len(context.get_unread_messages()) - message_count_bonus = min(unread_count * 0.1, 0.3) # 最多30%加成 + message_bonus = min(unread_count * 0.05, 0.2) # 最多20%加成 - # 时间加权:最近活跃的流优先级更高 + # 简单的时间加权 current_time = time.time() time_since_active = current_time - context.last_check_time - time_penalty = max(0, 1.0 - time_since_active / 3600.0) # 1小时内无惩罚 - - # 连续无回复惩罚 - 从StreamContext历史消息计算 - if chat_stream: - # 计算连续无回复次数 - consecutive_no_reply = 0 - all_messages = context.get_history_messages(limit=50) + context.get_unread_messages() - for msg in reversed(all_messages): - if hasattr(msg, "should_reply") and msg.should_reply: - if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])): - consecutive_no_reply += 1 - else: - break - no_reply_penalty = max(0, 1.0 - consecutive_no_reply * 0.05) # 每次无回复降低5% - else: - no_reply_penalty = 1.0 - - # 综合优先级计算 - final_priority = ( - base_priority * 0.6 # 基础兴趣度权重60% - + message_count_bonus * 0.2 # 消息数量权重20% - + time_penalty * 0.1 # 时间权重10% - + no_reply_penalty * 0.1 # 回复状态权重10% - ) + time_bonus = max(0, 1.0 - time_since_active / 7200.0) * 0.1 # 2小时内衰减 + final_priority = base_priority + message_bonus + time_bonus return max(0.0, min(1.0, final_priority)) - def _clear_all_unread_messages(self, context: StreamContext): + def _clear_all_unread_messages(self, stream_id: str): """清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读""" - unread_messages = context.get_unread_messages() + unread_messages = self.context_manager.get_unread_messages(stream_id) if not unread_messages: return logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") - # 将所有未读消息标记为已读并移动到历史记录 - for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表 - try: - context.mark_message_as_read(msg.message_id) - self.stats.total_processed_messages += 1 - logger.debug(f"强制清除消息 {msg.message_id},标记为已读") - except Exception as e: - logger.error(f"清除消息 {msg.message_id} 时出错: {e}") + # 将所有未读消息标记为已读 + context = self.context_manager.get_stream_context(stream_id) + if context: + for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表 + try: + context.mark_message_as_read(msg.message_id) + self.stats.total_processed_messages += 1 + logger.debug(f"强制清除消息 {msg.message_id},标记为已读") + except Exception as e: + logger.error(f"清除消息 {msg.message_id} 时出错: {e}") # 创建全局消息管理器实例 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 084fc0292..bfd170e6c 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -120,186 +120,209 @@ class ChatStream: """设置聊天消息上下文""" # 将MessageRecv转换为DatabaseMessages并设置到stream_context from src.common.data_models.database_data_model import DatabaseMessages + import json - # 简化转换,实际可能需要更完整的转换逻辑 + # 安全获取message_info中的数据 + message_info = getattr(message, "message_info", {}) + user_info = getattr(message_info, "user_info", {}) + group_info = getattr(message_info, "group_info", {}) + + # 提取reply_to信息(从message_segment中查找reply类型的段) + reply_to = None + if hasattr(message, "message_segment") and message.message_segment: + reply_to = self._extract_reply_from_segment(message.message_segment) + + # 完整的数据转移逻辑 db_message = DatabaseMessages( + # 基础消息信息 message_id=getattr(message, "message_id", ""), time=getattr(message, "time", time.time()), - chat_id=getattr(message, "chat_id", ""), - user_id=str(getattr(message.message_info, "user_info", {}).user_id) - if hasattr(message, "message_info") and hasattr(message.message_info, "user_info") - else "", - user_nickname=getattr(message.message_info, "user_info", {}).user_nickname - if hasattr(message, "message_info") and hasattr(message.message_info, "user_info") - else "", - user_platform=getattr(message.message_info, "user_info", {}).platform - if hasattr(message, "message_info") and hasattr(message.message_info, "user_info") - else "", - priority_mode=getattr(message, "priority_mode", None), - priority_info=str(getattr(message, "priority_info", None)) - if hasattr(message, "priority_info") and message.priority_info + chat_id=self._generate_chat_id(message_info), + reply_to=reply_to, + # 兴趣度相关 + interest_value=getattr(message, "interest_value", 0.0), + # 关键词 + key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False) + if getattr(message, "key_words", None) else None, - additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None), + key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False) + if getattr(message, "key_words_lite", None) + else None, + # 消息状态标记 + is_mentioned=getattr(message, "is_mentioned", None), + is_at=getattr(message, "is_at", False), + is_emoji=getattr(message, "is_emoji", False), + is_picid=getattr(message, "is_picid", False), + is_voice=getattr(message, "is_voice", False), + is_video=getattr(message, "is_video", False), + is_command=getattr(message, "is_command", False), + is_notify=getattr(message, "is_notify", False), + # 消息内容 + processed_plain_text=getattr(message, "processed_plain_text", ""), + display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text + # 优先级信息 + priority_mode=getattr(message, "priority_mode", None), + priority_info=json.dumps(getattr(message, "priority_info", None)) + if getattr(message, "priority_info", None) + else None, + # 额外配置 + additional_config=getattr(message_info, "additional_config", None), + # 用户信息 + user_id=str(getattr(user_info, "user_id", "")), + user_nickname=getattr(user_info, "user_nickname", ""), + user_cardname=getattr(user_info, "user_cardname", None), + user_platform=getattr(user_info, "platform", ""), + # 群组信息 + chat_info_group_id=getattr(group_info, "group_id", None), + chat_info_group_name=getattr(group_info, "group_name", None), + chat_info_group_platform=getattr(group_info, "platform", None), + # 聊天流信息 + chat_info_user_id=str(getattr(user_info, "user_id", "")), + chat_info_user_nickname=getattr(user_info, "user_nickname", ""), + chat_info_user_cardname=getattr(user_info, "user_cardname", None), + chat_info_user_platform=getattr(user_info, "platform", ""), + chat_info_stream_id=self.stream_id, + chat_info_platform=self.platform, + chat_info_create_time=self.create_time, + chat_info_last_active_time=self.last_active_time, + # 新增兴趣度系统字段 - 添加安全处理 + actions=self._safe_get_actions(message), + should_reply=getattr(message, "should_reply", False), ) self.stream_context.set_current_message(db_message) self.stream_context.priority_mode = getattr(message, "priority_mode", None) self.stream_context.priority_info = getattr(message, "priority_info", None) + # 调试日志:记录数据转移情况 + logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, " + f"chat_id: {db_message.chat_id}, " + f"is_mentioned: {db_message.is_mentioned}, " + f"is_emoji: {db_message.is_emoji}, " + f"is_picid: {db_message.is_picid}, " + f"interest_value: {db_message.interest_value}") + + def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]: + """安全获取消息的actions字段""" + try: + actions = getattr(message, "actions", None) + if actions is None: + return None + + # 如果是字符串,尝试解析为JSON + if isinstance(actions, str): + try: + import json + actions = json.loads(actions) + except json.JSONDecodeError: + logger.warning(f"无法解析actions JSON字符串: {actions}") + return None + + # 确保返回列表类型 + if isinstance(actions, list): + # 过滤掉空值和非字符串元素 + filtered_actions = [action for action in actions if action is not None and isinstance(action, str)] + return filtered_actions if filtered_actions else None + else: + logger.warning(f"actions字段类型不支持: {type(actions)}") + return None + + except Exception as e: + logger.warning(f"获取actions字段失败: {e}") + return None + + def _extract_reply_from_segment(self, segment) -> Optional[str]: + """从消息段中提取reply_to信息""" + try: + if hasattr(segment, "type") and segment.type == "seglist": + # 递归搜索seglist中的reply段 + if hasattr(segment, "data") and segment.data: + for seg in segment.data: + reply_id = self._extract_reply_from_segment(seg) + if reply_id: + return reply_id + elif hasattr(segment, "type") and segment.type == "reply": + # 找到reply段,返回message_id + return str(segment.data) if segment.data else None + except Exception as e: + logger.warning(f"提取reply_to信息失败: {e}") + return None + + def _generate_chat_id(self, message_info) -> str: + """生成chat_id,基于群组或用户信息""" + try: + group_info = getattr(message_info, "group_info", None) + user_info = getattr(message_info, "user_info", None) + + if group_info and hasattr(group_info, "group_id") and group_info.group_id: + # 群聊:使用群组ID + return f"{self.platform}_{group_info.group_id}" + elif user_info and hasattr(user_info, "user_id") and user_info.user_id: + # 私聊:使用用户ID + return f"{self.platform}_{user_info.user_id}_private" + else: + # 默认:使用stream_id + return self.stream_id + except Exception as e: + logger.warning(f"生成chat_id失败: {e}") + return self.stream_id + @property def focus_energy(self) -> float: - """动态计算的聊天流总体兴趣度,访问时自动更新""" - self._focus_energy = self._calculate_dynamic_focus_energy() - return self._focus_energy + """使用重构后的能量管理器计算focus_energy""" + try: + from src.chat.energy_system import energy_manager + + # 获取所有消息 + history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size) + unread_messages = self.stream_context.get_unread_messages() + all_messages = history_messages + unread_messages + + # 获取用户ID + user_id = None + if self.user_info and hasattr(self.user_info, "user_id"): + user_id = str(self.user_info.user_id) + + # 使用能量管理器计算 + energy = energy_manager.calculate_focus_energy( + stream_id=self.stream_id, + messages=all_messages, + user_id=user_id + ) + + # 更新内部存储 + self._focus_energy = energy + + logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}") + return energy + + except Exception as e: + logger.error(f"获取focus_energy失败: {e}", exc_info=True) + # 返回缓存的值或默认值 + if hasattr(self, '_focus_energy'): + return self._focus_energy + else: + return 0.5 @focus_energy.setter def focus_energy(self, value: float): """设置focus_energy值(主要用于初始化或特殊场景)""" self._focus_energy = max(0.0, min(1.0, value)) - def _calculate_dynamic_focus_energy(self) -> float: - """动态计算聊天流的总体兴趣度,使用StreamContext历史消息""" - try: - # 从StreamContext获取历史消息计算统计数据 - history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size) - unread_messages = self.stream_context.get_unread_messages() - all_messages = history_messages + unread_messages - - # 计算基于历史消息的统计数据 - if all_messages: - # 基础分:平均消息兴趣度 - message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")] - avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3 - - # 动作参与度:有动作的消息比例 - messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions] - action_rate = len(messages_with_actions) / len(all_messages) - - # 回复活跃度:应该回复且已回复的消息比例 - should_reply_messages = [ - msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply - ] - replied_messages = [ - msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or []) - ] - reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0 - - # 获取最后交互时间 - if all_messages: - self.last_interaction_time = max(msg.time for msg in all_messages) - - # 连续无回复计算:从最近的未回复消息计数 - consecutive_no_reply = 0 - for msg in reversed(all_messages): - if hasattr(msg, "should_reply") and msg.should_reply: - if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])): - consecutive_no_reply += 1 - else: - break - else: - # 没有历史消息时的默认值 - avg_message_interest = 0.3 - action_rate = 0.0 - reply_rate = 0.0 - consecutive_no_reply = 0 - self.last_interaction_time = time.time() - - # 获取用户关系分(对于私聊,群聊无效) - relationship_factor = self._get_user_relationship_score() - - # 时间衰减因子:最近活跃度 - current_time = time.time() - if not hasattr(self, "last_interaction_time") or not self.last_interaction_time: - self.last_interaction_time = current_time - time_since_interaction = current_time - self.last_interaction_time - time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减 - - # 连续无回复惩罚 - no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1) - - # 获取AFC系统阈值,添加None值检查 - reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4) - non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2) - high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8) - - # 计算与不同阈值的差距比例 - reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold))) - non_reply_gap_ratio = max( - 0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold)) - ) - high_match_gap_ratio = max( - 0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold)) - ) - - # 基于阈值差距比例的基础分计算 - threshold_based_score = ( - reply_gap_ratio * 0.6 # 回复阈值差距权重60% - + non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20% - + high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20% - ) - - # 动态权重调整:根据平均兴趣度水平调整权重分配 - if avg_message_interest >= high_match_threshold: - # 高兴趣度:更注重阈值差距 - threshold_weight = 0.7 - activity_weight = 0.2 - relationship_weight = 0.1 - elif avg_message_interest >= reply_threshold: - # 中等兴趣度:平衡权重 - threshold_weight = 0.5 - activity_weight = 0.3 - relationship_weight = 0.2 - else: - # 低兴趣度:更注重活跃度提升 - threshold_weight = 0.3 - activity_weight = 0.5 - relationship_weight = 0.2 - - # 计算活跃度得分 - activity_score = action_rate * 0.6 + reply_rate * 0.4 - - # 综合计算:基于阈值的动态加权 - focus_energy = ( - ( - threshold_based_score * threshold_weight # 阈值差距基础分 - + activity_score * activity_weight # 活跃度得分 - + relationship_factor * relationship_weight # 关系得分 - + self.base_interest_energy * 0.05 # 基础兴趣微调 - ) - * time_decay - * no_reply_penalty - ) - - # 确保在合理范围内 - focus_energy = max(0.1, min(1.0, focus_energy)) - - # 应用非线性变换增强区分度 - if focus_energy >= 0.7: - # 高兴趣度区域:指数增强,更敏感 - focus_energy = 0.7 + (focus_energy - 0.7) ** 0.8 - elif focus_energy >= 0.4: - # 中等兴趣度区域:线性保持 - pass - else: - # 低兴趣度区域:对数压缩,减少区分度 - focus_energy = 0.4 * (focus_energy / 0.4) ** 1.2 - - return max(0.1, min(1.0, focus_energy)) - - except Exception as e: - logger.error(f"计算动态focus_energy失败: {e}") - return self.base_interest_energy - def _get_user_relationship_score(self) -> float: - """从外部系统获取用户关系分""" + """从新的兴趣度管理系统获取用户关系分""" try: - # 尝试从兴趣评分系统获取用户关系分 - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( - chatter_interest_scoring_system, - ) + # 使用新的兴趣度管理系统 + from src.chat.interest_system import interest_manager if self.user_info and hasattr(self.user_info, "user_id"): - return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id)) + user_id = str(self.user_info.user_id) + # 获取用户交互历史作为关系分的基础 + interaction_calc = interest_manager.calculators.get( + interest_manager.InterestSourceType.USER_INTERACTION + ) + if interaction_calc: + return interaction_calc.calculate({"user_id": user_id}) except Exception: pass @@ -378,12 +401,13 @@ class ChatStream: chat_info_platform=db_msg.chat_info_platform, chat_info_create_time=db_msg.chat_info_create_time, chat_info_last_active_time=db_msg.chat_info_last_active_time, - # 新增的兴趣度系统字段 - interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0, actions=actions, should_reply=getattr(db_msg, "should_reply", False) or False, ) + # 添加调试日志:检查从数据库加载的interest_value + logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}") + # 标记为已读并添加到历史消息 db_message.is_read = True self.stream_context.history_messages.append(db_message) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 35c395ed4..b37301f47 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -218,4 +218,93 @@ class MessageStorage: except Exception: return match.group(0) - return re.sub(r"\[图片:([^\]]+)\]", replace_match, text) + @staticmethod + def update_message_interest_value(message_id: str, interest_value: float) -> None: + """ + 更新数据库中消息的interest_value字段 + + Args: + message_id: 消息ID + interest_value: 兴趣度值 + """ + try: + with get_db_session() as session: + # 更新消息的interest_value字段 + stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value) + result = session.execute(stmt) + session.commit() + + if result.rowcount > 0: + logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}") + else: + logger.warning(f"未找到消息 {message_id},无法更新interest_value") + + except Exception as e: + logger.error(f"更新消息 {message_id} 的interest_value失败: {e}") + raise + + @staticmethod + def fix_zero_interest_values(chat_id: str, since_time: float) -> int: + """ + 修复指定聊天中interest_value为0或null的历史消息记录 + + Args: + chat_id: 聊天ID + since_time: 从指定时间开始修复(时间戳) + + Returns: + 修复的记录数量 + """ + try: + with get_db_session() as session: + from sqlalchemy import select, update + from src.common.database.sqlalchemy_models import Messages + + # 查找需要修复的记录:interest_value为0、null或很小的值 + query = select(Messages).where( + (Messages.chat_id == chat_id) & + (Messages.time >= since_time) & + ( + (Messages.interest_value == 0) | + (Messages.interest_value.is_(None)) | + (Messages.interest_value < 0.1) + ) + ).limit(50) # 限制每次修复的数量,避免性能问题 + + messages_to_fix = session.execute(query).scalars().all() + fixed_count = 0 + + for msg in messages_to_fix: + # 为这些消息设置一个合理的默认兴趣度 + # 可以基于消息长度、内容或其他因素计算 + default_interest = 0.3 # 默认中等兴趣度 + + # 如果消息内容较长,可能是重要消息,兴趣度稍高 + if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text: + text_length = len(msg.processed_plain_text) + if text_length > 50: # 长消息 + default_interest = 0.4 + elif text_length > 20: # 中等长度消息 + default_interest = 0.35 + + # 如果是被@的消息,兴趣度更高 + if getattr(msg, 'is_mentioned', False): + default_interest = min(default_interest + 0.2, 0.8) + + # 执行更新 + update_stmt = update(Messages).where( + Messages.message_id == msg.message_id + ).values(interest_value=default_interest) + + result = session.execute(update_stmt) + if result.rowcount > 0: + fixed_count += 1 + logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}") + + session.commit() + logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值") + return fixed_count + + except Exception as e: + logger.error(f"修复历史消息interest_value失败: {e}") + return 0 diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 11468a814..4578d1481 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -96,7 +96,6 @@ class DatabaseMessages(BaseDataModel): chat_info_create_time: float = 0.0, chat_info_last_active_time: float = 0.0, # 新增字段 - interest_degree: float = 0.0, actions: Optional[list] = None, should_reply: bool = False, **kwargs: Any, @@ -108,7 +107,6 @@ class DatabaseMessages(BaseDataModel): self.interest_value = interest_value # 新增字段 - self.interest_degree = interest_degree self.actions = actions self.should_reply = should_reply @@ -201,7 +199,6 @@ class DatabaseMessages(BaseDataModel): "selected_expressions": self.selected_expressions, "is_read": self.is_read, # 新增字段 - "interest_degree": self.interest_degree, "actions": self.actions, "should_reply": self.should_reply, "user_id": self.user_info.user_id, @@ -221,17 +218,17 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_cardname": self.chat_info.user_info.user_cardname, } - def update_message_info(self, interest_degree: float = None, actions: list = None, should_reply: bool = None): + def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None): """ 更新消息信息 Args: - interest_degree: 兴趣度值 + interest_value: 兴趣度值 actions: 执行的动作列表 should_reply: 是否应该回复 """ - if interest_degree is not None: - self.interest_degree = interest_degree + if interest_value is not None: + self.interest_value = interest_value if actions is not None: self.actions = actions if should_reply is not None: @@ -268,7 +265,7 @@ class DatabaseMessages(BaseDataModel): return { "message_id": self.message_id, "time": self.time, - "interest_degree": self.interest_degree, + "interest_value": self.interest_value, "actions": self.actions, "should_reply": self.should_reply, "user_nickname": self.user_info.user_nickname, diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index 268328c77..f35c53573 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -61,27 +61,27 @@ class StreamContext(BaseDataModel): self._detect_chat_type(message) def update_message_info( - self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None + self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None ): """ 更新消息信息 Args: message_id: 消息ID - interest_degree: 兴趣度值 + interest_value: 兴趣度值 actions: 执行的动作列表 should_reply: 是否应该回复 """ # 在未读消息中查找并更新 for message in self.unread_messages: if message.message_id == message_id: - message.update_message_info(interest_degree, actions, should_reply) + message.update_message_info(interest_value, actions, should_reply) break # 在历史消息中查找并更新 for message in self.history_messages: if message.message_id == message_id: - message.update_message_info(interest_degree, actions, should_reply) + message.update_message_info(interest_value, actions, should_reply) break def add_action_to_message(self, message_id: str, action: str): diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 84ad10ea9..5d57bb73d 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -174,7 +174,6 @@ class Messages(Base): is_notify = Column(Boolean, nullable=False, default=False) # 兴趣度系统字段 - interest_degree = Column(Float, nullable=True, default=0.0) actions = Column(Text, nullable=True) # JSON格式存储动作列表 should_reply = Column(Boolean, nullable=True, default=False) @@ -183,7 +182,6 @@ class Messages(Base): Index("idx_messages_chat_id", "chat_id"), Index("idx_messages_time", "time"), Index("idx_messages_user_id", "user_id"), - Index("idx_messages_interest_degree", "interest_degree"), Index("idx_messages_should_reply", "should_reply"), ) diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 4793e2835..f6a3c3653 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -368,41 +368,30 @@ class ChatterPlanFilter: interest_scores = {} try: - from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ( - chatter_interest_scoring_system as interest_scoring_system, - ) - from src.common.data_models.database_data_model import DatabaseMessages + from src.chat.interest_system import interest_manager - # 转换消息格式 - db_messages = [] + # 使用新的兴趣度管理系统计算评分 for msg_dict in messages: try: - db_msg = DatabaseMessages( - message_id=msg_dict.get("message_id", ""), - time=msg_dict.get("time", time.time()), - chat_id=msg_dict.get("chat_id", ""), - processed_plain_text=msg_dict.get("processed_plain_text", ""), - user_id=msg_dict.get("user_id", ""), - user_nickname=msg_dict.get("user_nickname", ""), - user_platform=msg_dict.get("platform", "qq"), - chat_info_group_id=msg_dict.get("group_id", ""), - chat_info_group_name=msg_dict.get("group_name", ""), - chat_info_group_platform=msg_dict.get("platform", "qq"), + # 构建计算上下文 + calc_context = { + "stream_id": msg_dict.get("chat_id", ""), + "user_id": msg_dict.get("user_id"), + } + + # 计算消息兴趣度 + interest_score = interest_manager.calculate_message_interest( + message=msg_dict, + context=calc_context ) - db_messages.append(db_msg) + + # 构建兴趣度字典 + interest_scores[msg_dict.get("message_id", "")] = interest_score + except Exception as e: - logger.warning(f"转换消息格式失败: {e}") + logger.warning(f"计算消息兴趣度失败: {e}") continue - # 计算兴趣度评分 - if db_messages: - bot_nickname = global_config.bot.nickname or "麦麦" - scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname) - - # 构建兴趣度字典 - for score in scores: - interest_scores[score.message_id] = score.total_score - except Exception as e: logger.warning(f"获取兴趣度评分失败: {e}") diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index 36d3d300f..f0d09a5e6 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -4,13 +4,13 @@ """ from dataclasses import asdict +import time from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator -from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ChatterInterestScoringSystem -from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker +from src.chat.interest_system import interest_manager from src.mood.mood_manager import mood_manager @@ -52,14 +52,7 @@ class ChatterActionPlanner: self.generator = ChatterPlanGenerator(chat_id) self.executor = ChatterPlanExecutor(action_manager) - # 初始化兴趣度评分系统 - self.interest_scoring = ChatterInterestScoringSystem() - - # 创建新的关系追踪器 - self.relationship_tracker = ChatterRelationshipTracker(self.interest_scoring) - - # 设置执行器的关系追踪器 - self.executor.set_relationship_tracker(self.relationship_tracker) + # 使用新的统一兴趣度管理系统 # 规划器统计 self.planner_stats = { @@ -107,43 +100,39 @@ class ChatterActionPlanner: initial_plan.available_actions = self.action_manager.get_using_actions() unread_messages = context.get_unread_messages() if context else [] - # 2. 兴趣度评分 - 只对未读消息进行评分 + # 2. 使用新的兴趣度管理系统进行评分 + score = 0.0 + should_reply = False + reply_not_available = False + if unread_messages: - bot_nickname = global_config.bot.nickname - interest_scores = await self.interest_scoring.calculate_interest_scores(unread_messages, bot_nickname) + # 获取用户ID + user_id = None + if unread_messages[0].user_id: + user_id = unread_messages[0].user_id - # 3. 根据兴趣度调整可用动作 - if interest_scores: - latest_score = max(interest_scores, key=lambda s: s.total_score) - latest_message = next( - (msg for msg in unread_messages if msg.message_id == latest_score.message_id), None - ) - should_reply, score = self.interest_scoring.should_reply(latest_score, latest_message) + # 构建计算上下文 + calc_context = { + "stream_id": self.chat_id, + "user_id": user_id, + } - reply_not_available = False - if not should_reply and "reply" in initial_plan.available_actions: - logger.info(f"兴趣度不足 ({latest_score.total_score:.2f}),移除回复") - reply_not_available = True - - # 更新ChatStream的兴趣度数据 - from src.plugin_system.apis.chat_api import get_chat_manager - chat_stream = get_chat_manager().get_stream(self.chat_id) - logger.debug(f"已更新聊天 {self.chat_id} 的ChatStream兴趣度,分数: {score:.3f}") - - # 更新情绪状态和ChatStream兴趣度数据 - if latest_message and score > 0: - chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id) - await chat_mood.update_mood_by_message(latest_message, score) - logger.debug(f"已更新聊天 {self.chat_id} 的情绪状态,兴趣度: {score:.3f}") - - # 为所有未读消息记录兴趣度信息 + # 为每条消息计算兴趣度 for message in unread_messages: - # 查找对应的兴趣度评分 - message_score = next((s for s in interest_scores if s.message_id == message.message_id), None) - if message_score: - message.interest_degree = message_score.total_score - message.should_reply = self.interest_scoring.should_reply(message_score, message)[0] - logger.debug(f"已记录消息 {message.message_id} - 兴趣度: {message_score.total_score:.3f}, 应回复: {message.should_reply}") + try: + # 使用新的兴趣度管理器计算 + message_interest = interest_manager.calculate_message_interest( + message=message.__dict__, + context=calc_context + ) + + # 更新消息的兴趣度 + message.interest_value = message_interest + + # 简单的回复决策逻辑:兴趣度超过阈值则回复 + message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold + + logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}") # 更新StreamContext中的消息信息并刷新focus_energy if context: @@ -151,25 +140,35 @@ class ChatterActionPlanner: message_manager.update_message_and_refresh_energy( stream_id=self.chat_id, message_id=message.message_id, - interest_degree=message_score.total_score, + interest_value=message_interest, should_reply=message.should_reply ) - else: - # 如果没有找到评分,设置默认值 - message.interest_degree = 0.0 + + # 更新数据库中的消息记录 + try: + from src.chat.message_receive.storage import MessageStorage + MessageStorage.update_message_interest_value(message.message_id, message_interest) + logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}") + except Exception as e: + logger.warning(f"更新数据库消息兴趣度失败: {e}") + + # 更新话题兴趣度 + interest_manager.update_topic_interest(message.__dict__, message_interest) + + # 记录最高分 + if message_interest > score: + score = message_interest + if message.should_reply: + should_reply = True + else: + reply_not_available = True + + except Exception as e: + logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}") + # 设置默认值 + message.interest_value = 0.0 message.should_reply = False - # 更新StreamContext中的消息信息并刷新focus_energy - if context: - from src.chat.message_manager.message_manager import message_manager - message_manager.update_message_and_refresh_energy( - stream_id=self.chat_id, - message_id=message.message_id, - interest_degree=0.0, - should_reply=False - ) - - # base_threshold = self.interest_scoring.reply_threshold # 检查兴趣度是否达到非回复动作阈值 non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold if score < non_reply_action_interest_threshold: @@ -191,26 +190,16 @@ class ChatterActionPlanner: plan_filter = ChatterPlanFilter(self.chat_id, available_actions) filtered_plan = await plan_filter.filter(reply_not_available, initial_plan) - # 检查filtered_plan是否有reply动作,以便记录reply action - has_reply_action = False - for decision in filtered_plan.decided_actions: - if decision.action_type == "reply": - has_reply_action = True - self.interest_scoring.record_reply_action(has_reply_action) + # 检查filtered_plan是否有reply动作,用于统计 + has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions) # 5. 使用 PlanExecutor 执行 Plan execution_result = await self.executor.execute(filtered_plan) - # 6. 动作记录现在由ChatterActionManager统一处理 - # 动作记录逻辑已移至ChatterActionManager.execute_action方法中 - - # 7. 根据执行结果更新统计信息 + # 6. 根据执行结果更新统计信息 self._update_stats_from_execution_result(execution_result) - # 8. 检查关系更新 - await self.relationship_tracker.check_and_update_relationships() - - # 8. 返回结果 + # 7. 返回结果 return self._build_return_result(filtered_plan) except Exception as e: @@ -259,37 +248,10 @@ class ChatterActionPlanner: return final_actions_dict, final_target_message_dict - def get_user_relationship(self, user_id: str) -> float: - """获取用户关系分""" - return self.interest_scoring.get_user_relationship(user_id) - - def update_interest_keywords(self, new_keywords: Dict[str, List[str]]): - """更新兴趣关键词(已弃用,仅保留用于兼容性)""" - logger.info("传统关键词匹配已移除,此方法仅保留用于兼容性") - # 此方法已弃用,因为现在完全使用embedding匹配 - def get_planner_stats(self) -> Dict[str, any]: """获取规划器统计""" return self.planner_stats.copy() - def get_interest_scoring_stats(self) -> Dict[str, any]: - """获取兴趣度评分统计""" - return { - "no_reply_count": self.interest_scoring.no_reply_count, - "max_no_reply_count": self.interest_scoring.max_no_reply_count, - "reply_threshold": self.interest_scoring.reply_threshold, - "mention_threshold": self.interest_scoring.mention_threshold, - "user_relationships": len(self.interest_scoring.user_relationships), - } - - def get_relationship_stats(self) -> Dict[str, any]: - """获取用户关系统计""" - return { - "tracking_users": len(self.relationship_tracker.tracking_users), - "relationship_history": len(self.relationship_tracker.relationship_history), - "max_tracking_users": self.relationship_tracker.max_tracking_users, - } - def get_current_mood_state(self) -> str: """获取当前聊天的情绪状态""" chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)