refactor(chat): 重构消息管理器以使用集中式上下文管理和能量系统

- 将流上下文管理从MessageManager迁移到专门的ContextManager
- 使用统一的能量系统计算focus_energy和分发间隔
- 重构ChatStream的消息数据转换逻辑,支持更完整的数据字段
- 更新数据库模型,移除interest_degree字段,统一使用interest_value
- 集成新的兴趣度管理系统替代原有的评分系统
- 添加消息存储的interest_value修复功能
This commit is contained in:
Windpicker-owo
2025-09-27 14:23:48 +08:00
parent 5962b44294
commit 1177b81285
15 changed files with 3520 additions and 506 deletions

View File

@@ -0,0 +1,28 @@
"""
能量系统模块
提供稳定、高效的聊天流能量计算和管理功能
"""
from .energy_manager import (
EnergyManager,
EnergyLevel,
EnergyComponent,
EnergyCalculator,
InterestEnergyCalculator,
ActivityEnergyCalculator,
RecencyEnergyCalculator,
RelationshipEnergyCalculator,
energy_manager
)
# Public API of the package: these are the names imported from
# ``.energy_manager`` above, re-exported for ``from <package> import *``.
__all__ = [
"EnergyManager",
"EnergyLevel",
"EnergyComponent",
"EnergyCalculator",
"InterestEnergyCalculator",
"ActivityEnergyCalculator",
"RecencyEnergyCalculator",
"RelationshipEnergyCalculator",
"energy_manager"
]

View File

@@ -0,0 +1,480 @@
"""
重构后的 focus_energy 管理系统
提供稳定、高效的聊天流能量计算和管理功能
"""
import time
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from src.common.logger import get_logger
from src.config.config import global_config
# Module-level logger obtained from get_logger under the name "energy_system".
logger = get_logger("energy_system")
class EnergyLevel(Enum):
    """Named energy tiers, listed from lowest to highest.

    Each member maps to the float that represents its tier.
    """

    VERY_LOW = 0.1   # very low
    LOW = 0.3        # low
    NORMAL = 0.5     # normal
    HIGH = 0.7       # high
    VERY_HIGH = 0.9  # very high
@dataclass
class EnergyComponent:
    """A named energy contribution whose value decays linearly with age.

    Attributes:
        name: Identifier of the component.
        value: Stored score; ``update_value`` clamps it to [0.0, 1.0].
        weight: Relative weight of the component (not read inside this class).
        decay_rate: Fraction of the value lost per day of age.
        last_updated: ``time.time()`` timestamp of the last update.
    """

    name: str
    value: float
    weight: float = 1.0
    decay_rate: float = 0.05  # decay per day
    last_updated: float = field(default_factory=time.time)

    def get_current_value(self) -> float:
        """Return the stored value scaled by the time-decay factor.

        The factor falls linearly by ``decay_rate`` per day since
        ``last_updated`` and is clamped to [0.1, 1.0]. The upper clamp keeps a
        ``last_updated`` in the future (clock skew) or a negative
        ``decay_rate`` from amplifying the value beyond what was stored.
        """
        seconds_per_day = 24 * 3600
        age = time.time() - self.last_updated
        decay_factor = 1.0 - (age * self.decay_rate / seconds_per_day)
        decay_factor = max(0.1, min(1.0, decay_factor))
        return self.value * decay_factor

    def update_value(self, new_value: float) -> None:
        """Store ``new_value`` clamped to [0.0, 1.0] and reset the decay clock."""
        self.value = max(0.0, min(1.0, new_value))
        self.last_updated = time.time()
class EnergyContext(TypedDict):
    """Input handed to every energy calculator for one computation."""

    # Identifier of the chat stream being scored.
    stream_id: str
    # Messages considered for the computation.
    messages: List[Any]
    # User the computation relates to, if any.
    user_id: Optional[str]
class EnergyResult(TypedDict):
    """Shape of a complete energy computation result."""

    # Final energy value.
    energy: float
    # Tier the energy value falls into.
    level: EnergyLevel
    # Distribution interval associated with the energy value.
    distribution_interval: float
    # Per-component scores, keyed by component name.
    component_scores: Dict[str, float]
    # Whether the result came from a cache.
    cached: bool
class EnergyCalculator(ABC):
    """Abstract interface implemented by every energy calculator."""

    @abstractmethod
    def calculate(self, context: Dict[str, Any]) -> float:
        """Return this calculator's energy score for ``context``."""

    @abstractmethod
    def get_weight(self) -> float:
        """Return the weight of this calculator in the combined energy."""
class InterestEnergyCalculator(EnergyCalculator):
    """Energy calculator driven by the messages' ``interest_value`` attribute."""

    def calculate(self, context: Dict[str, Any]) -> float:
        """Return the mean interest value of the messages in ``context``.

        Only messages whose ``interest_value`` converts to a float within
        [0.0, 1.0] are counted. Returns 0.3 when there are no messages or no
        message carries a usable value.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.3

        total_interest = 0.0
        valid_messages = 0
        for message in messages:
            raw = getattr(message, "interest_value", None)
            if raw is None:
                continue
            try:
                candidate = float(raw)
            except (ValueError, TypeError):
                # Value is not convertible to a number; ignore this message.
                continue
            if not (0.0 <= candidate <= 1.0):
                # Out-of-range values do not take part in the average.
                continue
            total_interest += candidate
            valid_messages += 1

        if valid_messages == 0:
            return 0.3

        avg_interest = total_interest / valid_messages
        logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)")
        return avg_interest

    def get_weight(self) -> float:
        """Return the fixed weight of the interest component (0.5)."""
        return 0.5
class ActivityEnergyCalculator(EnergyCalculator):
    """Energy calculator based on the actions attached to the messages."""

    def __init__(self):
        # Score contributed by each kind of action; unknown actions fall back
        # to the "other" weight.
        self.action_weights = {
            "reply": 0.4,
            "react": 0.3,
            "mention": 0.2,
            "other": 0.1
        }

    def calculate(self, context: Dict[str, Any]) -> float:
        """Return the activity score of the messages in ``context``.

        Sums the weight of every action found in each message's ``actions``
        list and divides by ``len(messages)`` times the largest configured
        action weight, capping the result at 1.0. Returns 0.2 when there are
        no messages or the ceiling is not positive.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.2

        fallback_weight = self.action_weights.get("other", 0.0)
        # Derive the per-message ceiling from the weight table instead of
        # repeating the literal 0.4, so the two cannot drift apart.
        heaviest_weight = max(self.action_weights.values(), default=0.0)
        max_possible_score = len(messages) * heaviest_weight

        total_score = 0.0
        for msg in messages:
            actions = getattr(msg, "actions", [])
            if not isinstance(actions, list):
                continue
            for action in actions:
                try:
                    weight = self.action_weights.get(action, fallback_weight)
                except TypeError:
                    # Unhashable action object: count it as an unknown action
                    # instead of failing the whole component.
                    weight = fallback_weight
                total_score += weight

        if max_possible_score <= 0:
            return 0.2

        activity_score = min(1.0, total_score / max_possible_score)
        logger.debug(f"活跃度分数: {activity_score:.3f}")
        return activity_score

    def get_weight(self) -> float:
        """Return the fixed weight of the activity component (0.3)."""
        return 0.3
class RecencyEnergyCalculator(EnergyCalculator):
    """Energy calculator based on how recent the newest message is."""

    def calculate(self, context: Dict[str, Any]) -> float:
        """Return a recency score derived from the newest message ``time``.

        Score by age of the newest message:
            under 1 hour   -> 1.0
            under 6 hours  -> 0.8
            under 24 hours -> 0.5
            under 7 days   -> 0.3
            otherwise      -> 0.1

        Returns 0.1 when there are no messages or no message has a usable,
        positive ``time``.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.1

        # Find the newest timestamp. Values that cannot be converted to a
        # number are skipped so one malformed message does not fail the whole
        # component.
        latest_time = 0.0
        for msg in messages:
            raw_time = getattr(msg, "time", None)
            if raw_time is None:
                continue
            try:
                msg_time = float(raw_time)
            except (TypeError, ValueError):
                continue
            if msg_time > latest_time:
                latest_time = msg_time

        if latest_time == 0.0:
            return 0.1

        age = time.time() - latest_time

        # (upper age bound in seconds, score); the first matching tier wins.
        tiers = (
            (3600, 1.0),
            (6 * 3600, 0.8),
            (24 * 3600, 0.5),
            (7 * 24 * 3600, 0.3),
        )
        recency_score = next((score for limit, score in tiers if age < limit), 0.1)

        logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
        return recency_score

    def get_weight(self) -> float:
        """Return the fixed weight of the recency component (0.2)."""
        return 0.2
class RelationshipEnergyCalculator(EnergyCalculator):
    """Energy calculator based on the relationship with the user."""

    def calculate(self, context: Dict[str, Any]) -> float:
        """Return the relationship score for ``context["user_id"]``.

        Delegates to the USER_INTERACTION calculator of the interest system
        and clamps its result to [0.0, 1.0]. Returns the base score 0.3 when
        no user id is given, when that calculator is not registered, or when
        the lookup fails (the failure is logged rather than swallowed).
        """
        user_id = context.get("user_id")
        if not user_id:
            return 0.3

        try:
            # Imported here rather than at module top, as in the original code.
            from src.chat.interest_system import interest_manager

            # NOTE(review): InterestSourceType is looked up on
            # ``interest_manager`` itself; confirm it is exposed there.
            interaction_calc = interest_manager.calculators.get(
                interest_manager.InterestSourceType.USER_INTERACTION
            )
            if not interaction_calc:
                # Default base score.
                return 0.3

            relationship_score = interaction_calc.calculate({"user_id": user_id})
            logger.debug(f"用户关系分数: {relationship_score:.3f}")
            return max(0.0, min(1.0, relationship_score))
        except Exception as e:
            # Best effort: fall back to the base score, but leave a trace so a
            # broken interest system does not go unnoticed.
            logger.warning(f"计算用户关系分数失败,使用默认值: {e}")
            return 0.3

    def get_weight(self) -> float:
        """Return the fixed weight of the relationship component (0.1)."""
        return 0.1
class EnergyManager:
    """Central manager that combines all energy calculators into one score.

    Holds the registered calculators, a per-stream result cache with a TTL,
    the AFC thresholds used to reshape the raw score, and simple statistics.
    """

    def __init__(self) -> None:
        """Register the default calculators and load thresholds from config."""
        self.calculators: List["EnergyCalculator"] = [
            InterestEnergyCalculator(),
            ActivityEnergyCalculator(),
            RecencyEnergyCalculator(),
            RelationshipEnergyCalculator(),
        ]

        # Energy cache: stream_id -> (energy, timestamp).
        self.energy_cache: Dict[str, Tuple[float, float]] = {}
        self.cache_ttl: int = 60  # seconds

        # AFC thresholds (defaults, possibly overridden from config below).
        self.thresholds: Dict[str, float] = {
            "high_match": 0.8,
            "reply": 0.4,
            "non_reply": 0.2
        }

        # Statistics.
        self.stats: Dict[str, Union[int, float, str]] = {
            "total_calculations": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "average_calculation_time": 0.0,
            "last_threshold_update": time.time(),
        }

        self._load_thresholds_from_config()

        logger.info("能量管理器初始化完成")

    def _normalize_thresholds(self) -> None:
        """Enforce ``non_reply + 0.1 <= reply`` and ``reply + 0.1 <= high_match``.

        ``reply`` is raised first so that ``high_match`` is derived from its
        final value; doing it the other way round can leave ``high_match``
        below the bumped ``reply``.
        """
        thresholds = self.thresholds
        thresholds["reply"] = max(thresholds["reply"], thresholds["non_reply"] + 0.1)
        thresholds["high_match"] = max(thresholds["high_match"], thresholds["reply"] + 0.1)

    def _load_thresholds_from_config(self) -> None:
        """Load the AFC thresholds from ``global_config.affinity_flow``.

        Values are read and converted before being committed, so a malformed
        config leaves the current thresholds untouched. Failures are logged
        and the existing values are kept.
        """
        try:
            affinity_flow = getattr(global_config, "affinity_flow", None)
            if affinity_flow is not None:
                loaded = {
                    "high_match": float(getattr(affinity_flow, "high_match_interest_threshold", 0.8)),
                    "reply": float(getattr(affinity_flow, "reply_action_interest_threshold", 0.4)),
                    "non_reply": float(getattr(affinity_flow, "non_reply_action_interest_threshold", 0.2)),
                }
                self.thresholds.update(loaded)

            self._normalize_thresholds()

            self.stats["last_threshold_update"] = time.time()
            logger.info(f"加载AFC阈值: {self.thresholds}")
        except Exception as e:
            logger.warning(f"加载AFC阈值失败使用默认值: {e}")

    def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
        """Compute (or fetch from cache) the focus energy of a chat stream.

        Args:
            stream_id: Cache key identifying the chat stream.
            messages: Messages passed to every calculator.
            user_id: Optional user id passed to every calculator.

        Returns:
            The adjusted energy, clamped to [0.1, 1.0]. A cached value younger
            than ``cache_ttl`` seconds is returned as is, regardless of the
            ``messages`` / ``user_id`` arguments.
        """
        start_time = time.time()
        self.stats["total_calculations"] += 1

        cached = self.energy_cache.get(stream_id)
        if cached is not None:
            cached_energy, cached_time = cached
            if time.time() - cached_time < self.cache_ttl:
                self.stats["cache_hits"] += 1
                logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
                return cached_energy
        # Every request not served from the cache is a miss, whether the entry
        # was expired or never existed.
        self.stats["cache_misses"] += 1

        context: EnergyContext = {
            "stream_id": stream_id,
            "messages": messages,
            "user_id": user_id,
        }

        # Collect (score, weight) per calculator instance. Pairing them here,
        # rather than keying scores by class name, keeps two calculators of
        # the same class from overwriting each other.
        scored: List[Tuple[float, float]] = []
        for calculator in self.calculators:
            calculator_name = calculator.__class__.__name__
            try:
                score = float(calculator.calculate(context))
                weight = float(calculator.get_weight())
                logger.debug(f"{calculator_name} 能量: {score:.3f} (权重: {weight:.3f})")
            except Exception as e:
                # A failing calculator is skipped entirely.
                logger.warning(f"计算 {calculator_name} 能量失败: {e}")
                continue
            scored.append((score, weight))

        # Weighted average over the calculators that succeeded.
        total_weight = sum(weight for _, weight in scored)
        if total_weight > 0:
            total_energy = sum(score * (weight / total_weight) for score, weight in scored)
        else:
            total_energy = 0.5

        final_energy = self._apply_threshold_adjustment(total_energy)

        self.energy_cache[stream_id] = (final_energy, time.time())
        self._cleanup_cache()

        # Running average over real computations only; cache hits return
        # before this point, so the miss counter is the sample count.
        calculation_time = time.time() - start_time
        computed = self.stats["cache_misses"]
        self.stats["average_calculation_time"] = (
            (self.stats["average_calculation_time"] * (computed - 1) + calculation_time)
            / computed
        )

        logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
        return final_energy

    def _apply_threshold_adjustment(self, energy: float) -> float:
        """Reshape the raw energy around the AFC thresholds.

        * ``energy >= high_match``: boosted with a power curve anchored at 0.7.
        * ``reply <= energy < high_match``: kept unchanged.
        * ``energy < reply``: scaled with a power curve anchored at 0.4.

        The anchors 0.7 and 0.4 are fixed and do not follow the configured
        thresholds. The result is clamped to [0.1, 1.0].
        """
        high_threshold = self.thresholds["high_match"]
        reply_threshold = self.thresholds["reply"]

        if energy >= high_threshold:
            # A negative base with a fractional exponent yields a complex
            # number in Python, so floor the base at 0 for the case where
            # high_match is configured below the 0.7 anchor.
            adjusted = 0.7 + max(0.0, energy - 0.7) ** 0.8
        elif energy >= reply_threshold:
            adjusted = energy
        else:
            # Same guard for negative energies.
            adjusted = 0.4 * (max(0.0, energy) / 0.4) ** 1.2

        return max(0.1, min(1.0, adjusted))

    def get_energy_level(self, energy: float) -> "EnergyLevel":
        """Return the highest tier whose value ``energy`` reaches.

        Falls back to ``EnergyLevel.VERY_LOW`` when ``energy`` is below
        ``EnergyLevel.LOW``.
        """
        descending_levels = (
            EnergyLevel.VERY_HIGH,
            EnergyLevel.HIGH,
            EnergyLevel.NORMAL,
            EnergyLevel.LOW,
        )
        for level in descending_levels:
            if energy >= level.value:
                return level
        return EnergyLevel.VERY_LOW

    def get_distribution_interval(self, energy: float) -> float:
        """Return the distribution interval in seconds for ``energy``.

        The base interval depends on the energy tier, is multiplied by a
        random factor in [0.8, 1.2] so streams do not synchronise, and is then
        clamped to the configured minimum / maximum interval.
        """
        import random

        base_intervals = {
            EnergyLevel.VERY_HIGH: 1.0,
            EnergyLevel.HIGH: 3.0,
            EnergyLevel.NORMAL: 8.0,
            EnergyLevel.LOW: 15.0,
        }
        base_interval = base_intervals.get(self.get_energy_level(energy), 30.0)

        jitter = random.uniform(0.8, 1.2)
        final_interval = base_interval * jitter

        # getattr on None yields the defaults, so a config without a ``chat``
        # section does not raise.
        chat_config = getattr(global_config, "chat", None)
        min_interval = getattr(chat_config, "dynamic_distribution_min_interval", 1.0)
        max_interval = getattr(chat_config, "dynamic_distribution_max_interval", 60.0)

        return max(min_interval, min(max_interval, final_interval))

    def invalidate_cache(self, stream_id: str) -> None:
        """Drop the cached energy of ``stream_id`` if there is one."""
        if stream_id in self.energy_cache:
            del self.energy_cache[stream_id]
            logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")

    def _cleanup_cache(self) -> None:
        """Remove every cache entry older than ``cache_ttl`` seconds."""
        current_time = time.time()
        expired_keys = [
            stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
            if current_time - timestamp > self.cache_ttl
        ]

        for key in expired_keys:
            del self.energy_cache[key]

        if expired_keys:
            logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")

    def get_statistics(self) -> Dict[str, Any]:
        """Return a snapshot of cache size, calculators, thresholds and stats.

        ``thresholds`` and ``performance_stats`` are copies, so mutating the
        returned dict does not affect the manager.
        """
        return {
            "cache_size": len(self.energy_cache),
            "calculators": [calc.__class__.__name__ for calc in self.calculators],
            "thresholds": dict(self.thresholds),
            "performance_stats": self.stats.copy(),
        }

    def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
        """Merge ``new_thresholds`` into the current ones and re-normalise."""
        self.thresholds.update(new_thresholds)
        self._normalize_thresholds()

        self.stats["last_threshold_update"] = time.time()
        logger.info(f"更新AFC阈值: {self.thresholds}")

    def add_calculator(self, calculator: "EnergyCalculator") -> None:
        """Register an additional calculator."""
        self.calculators.append(calculator)
        logger.info(f"添加能量计算器: {calculator.__class__.__name__}")

    def remove_calculator(self, calculator: "EnergyCalculator") -> None:
        """Unregister ``calculator`` if it is currently registered."""
        if calculator in self.calculators:
            self.calculators.remove(calculator)
            logger.info(f"移除能量计算器: {calculator.__class__.__name__}")

    def clear_cache(self) -> None:
        """Drop every cached energy value."""
        self.energy_cache.clear()
        logger.info("清空能量缓存")

    def get_cache_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any request."""
        total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
        if total_requests == 0:
            return 0.0
        return self.stats["cache_hits"] / total_requests
# Module-level shared EnergyManager instance, created when this module is imported.
energy_manager = EnergyManager()