将原有的兴趣度评分系统重构为使用统一的评分API,移除了插件内部实现并更新了所有相关引用。主要变更包括: - 替换 chatter_interest_scoring_system 为 scoring_api - 移除 interest_scoring.py 文件,统一使用 src/plugin_system/apis/scoring_api - 更新关系追踪器以使用统一API,保持向后兼容性 - 简化了多个模块中的关系分获取逻辑
489 lines
16 KiB
Python
489 lines
16 KiB
Python
"""
|
||
重构后的 focus_energy 管理系统
|
||
提供稳定、高效的聊天流能量计算和管理功能
|
||
"""
|
||
|
||
import time
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from typing import Any, TypedDict
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config
|
||
|
||
# Module-level logger for this file, obtained from the project's get_logger under the name "energy_system".
logger = get_logger("energy_system")
|
||
|
||
|
||
class EnergyLevel(Enum):
    """Discrete bands used to classify a focus-energy score.

    For every member except ``VERY_LOW`` the value is the inclusive lower
    bound that ``EnergyManager.get_energy_level`` compares against; any score
    below ``LOW`` is classified as ``VERY_LOW``.
    """

    VERY_LOW = 0.1  # bottom band
    LOW = 0.3  # low band
    NORMAL = 0.5  # middle band
    HIGH = 0.7  # high band
    VERY_HIGH = 0.9  # top band
|
||
|
||
|
||
@dataclass
class EnergyComponent:
    """A named energy contribution whose effective value fades with age."""

    name: str
    value: float
    weight: float = 1.0
    decay_rate: float = 0.05  # share of the value lost per day since last_updated
    last_updated: float = field(default_factory=time.time)

    def get_current_value(self) -> float:
        """Return ``value`` scaled by a linear, per-day decay factor.

        The factor starts at 1.0 and drops by ``decay_rate`` for each day
        elapsed since ``last_updated``; it never goes below 0.1.
        """
        seconds_per_day = 24 * 3600
        elapsed = time.time() - self.last_updated
        remaining = max(0.1, 1.0 - (elapsed * self.decay_rate / seconds_per_day))
        return self.value * remaining

    def update_value(self, new_value: float) -> None:
        """Store ``new_value`` clamped to [0.0, 1.0] and restart the decay clock."""
        capped = min(1.0, new_value)
        self.value = max(0.0, capped)
        self.last_updated = time.time()
|
||
|
||
|
||
class EnergyContext(TypedDict):
|
||
"""能量计算上下文"""
|
||
|
||
stream_id: str
|
||
messages: list[Any]
|
||
user_id: str | None
|
||
|
||
|
||
class EnergyResult(TypedDict):
    """Shape of a complete energy calculation result."""

    # Final energy score
    energy: float
    # Band the score falls into
    level: EnergyLevel
    # Distribution interval derived from the energy
    distribution_interval: float
    # Per-component scores, keyed by component name
    component_scores: dict[str, float]
    # Whether the result came from the cache
    cached: bool
|
||
|
||
|
||
class EnergyCalculator(ABC):
    """Interface implemented by every energy component calculator.

    ``EnergyManager`` calls ``calculate`` for a score and ``get_weight`` for
    the share that score has in the weighted total.
    """

    @abstractmethod
    def calculate(self, context: dict[str, Any]) -> float:
        """Return this component's energy score for ``context``.

        Implementations may be declared ``async``; ``EnergyManager`` detects
        coroutine functions and awaits them.
        """

    @abstractmethod
    def get_weight(self) -> float:
        """Return the weight of this component in the weighted total."""
|
||
|
||
|
||
class InterestEnergyCalculator(EnergyCalculator):
    """Energy component derived from the messages' ``interest_value`` attribute."""

    def calculate(self, context: dict[str, Any]) -> float:
        """Return the mean ``interest_value`` of the usable messages.

        A message is usable when its ``interest_value`` is an int or float
        within [0.0, 1.0]. Returns 0.3 when there are no messages or none of
        them is usable.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.3

        running_sum = 0.0
        usable_count = 0
        for message in messages:
            candidate = getattr(message, "interest_value", None)
            if not isinstance(candidate, int | float):
                continue
            if not (0.0 <= candidate <= 1.0):
                continue
            running_sum += candidate
            usable_count += 1

        if usable_count <= 0:
            return 0.3

        mean_interest = running_sum / usable_count
        logger.debug(f"平均消息兴趣度: {mean_interest:.3f} (基于 {usable_count} 条消息)")
        return mean_interest

    def get_weight(self) -> float:
        """Return this component's weight in the total (0.5)."""
        return 0.5
|
||
|
||
|
||
class ActivityEnergyCalculator(EnergyCalculator):
    """Energy component derived from the actions attached to each message."""

    def __init__(self):
        # Contribution of each action name; names not listed count as "other"
        self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}

    def calculate(self, context: dict[str, Any]) -> float:
        """Return the summed action weights normalised to [.., 1.0].

        Each message's ``actions`` attribute is read; only ``list`` values are
        considered. The sum is divided by ``len(messages)`` times the heaviest
        configured action weight and capped at 1.0.

        Returns 0.2 when there are no messages or when no positive action
        weight is configured.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.2

        fallback_weight = self.action_weights.get("other", 0.1)
        # Derive the per-message ceiling from the table itself so the
        # normalisation stays correct if action_weights is customised
        heaviest = max(self.action_weights.values(), default=0.0)
        max_possible_score = len(messages) * heaviest
        if max_possible_score <= 0:
            return 0.2

        total_score = 0.0
        for msg in messages:
            actions = getattr(msg, "actions", [])
            if not isinstance(actions, list):
                continue
            for action in actions:
                try:
                    weight = self.action_weights.get(action, fallback_weight)
                except TypeError:
                    # Unhashable action objects cannot be dict keys; count them as "other"
                    weight = fallback_weight
                total_score += weight

        activity_score = min(1.0, total_score / max_possible_score)
        logger.debug(f"活跃度分数: {activity_score:.3f}")
        return activity_score

    def get_weight(self) -> float:
        """Return this component's weight in the total (0.3)."""
        return 0.3
|
||
|
||
|
||
class RecencyEnergyCalculator(EnergyCalculator):
    """Energy component derived from how recent the newest message is."""

    def calculate(self, context: dict[str, Any]) -> float:
        """Return a stepped score based on the age of the newest message.

        The newest timestamp is the largest truthy ``time`` attribute among
        the messages. Returns 0.1 when there are no messages or no such
        timestamp.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.1

        # Find the newest timestamp; falsy/missing times are ignored
        newest = 0.0
        for message in messages:
            stamp = getattr(message, "time", None)
            if not stamp:
                continue
            if stamp > newest:
                newest = stamp

        if newest == 0.0:
            return 0.1

        age = time.time() - newest

        # (exclusive upper age bound in seconds, score); first match wins
        hour = 3600
        tiers = (
            (hour, 1.0),  # under 1 hour
            (6 * hour, 0.8),  # under 6 hours
            (24 * hour, 0.5),  # under 24 hours
            (7 * 24 * hour, 0.3),  # under 7 days
        )
        recency_score = 0.1  # 7 days or older
        for upper_bound, tier_score in tiers:
            if age < upper_bound:
                recency_score = tier_score
                break

        logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)")
        return recency_score

    def get_weight(self) -> float:
        """Return this component's weight in the total (0.2)."""
        return 0.2
|
||
|
||
|
||
class RelationshipEnergyCalculator(EnergyCalculator):
    """Energy component derived from the relationship score of a user."""

    async def calculate(self, context: dict[str, Any]) -> float:
        """Return the relationship score for ``context["user_id"]``.

        The score is fetched from the unified scoring API. Returns 0.3 when
        no user id is given or when the lookup fails for any reason.
        """
        default_score = 0.3

        target_user = context.get("user_id")
        if not target_user:
            return default_score

        try:
            # Imported at call time; an import failure is handled like any other lookup failure
            from src.plugin_system.apis.scoring_api import scoring_api

            fetched = await scoring_api.get_user_relationship_score(target_user)
            logger.debug(f"使用统一评分API计算关系分: {fetched:.3f}")
            return fetched
        except Exception as err:
            logger.warning(f"关系分计算失败,使用默认值: {err}")
            return default_score

    def get_weight(self) -> float:
        """Return this component's weight in the total (0.1)."""
        return 0.1
|
||
|
||
|
||
class EnergyManager:
    """Combine all registered calculators into one cached focus-energy score per chat stream."""

    def __init__(self) -> None:
        """Register the default calculators, set up cache and statistics, then load thresholds."""
        self.calculators: list[EnergyCalculator] = [
            InterestEnergyCalculator(),
            ActivityEnergyCalculator(),
            RecencyEnergyCalculator(),
            RelationshipEnergyCalculator(),
        ]

        # stream_id -> (energy, time.time() at which it was computed)
        self.energy_cache: dict[str, tuple[float, float]] = {}
        self.cache_ttl: int = 60  # seconds an entry stays valid

        # AFC thresholds; these defaults may be overridden from global_config.affinity_flow
        self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}

        self.stats: dict[str, int | float | str] = {
            "total_calculations": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "average_calculation_time": 0.0,
            "last_threshold_update": time.time(),
        }

        # Computations that actually ran to the end (cache hits excluded);
        # denominator of the running average in stats["average_calculation_time"]
        self._completed_calculations: int = 0

        self._load_thresholds_from_config()

        logger.info("能量管理器初始化完成")

    @staticmethod
    def _normalize_thresholds(thresholds: dict[str, float]) -> dict[str, float]:
        """Return a copy of ``thresholds`` with non_reply < reply < high_match enforced.

        Each level is raised, if necessary, to sit at least 0.1 above the one
        below it. The input dict is not modified.
        """
        fixed = dict(thresholds)
        # Work bottom-up: "reply" must be final before "high_match" is derived from it
        fixed["reply"] = max(fixed["reply"], fixed["non_reply"] + 0.1)
        fixed["high_match"] = max(fixed["high_match"], fixed["reply"] + 0.1)
        return fixed

    def _load_thresholds_from_config(self) -> None:
        """Load the AFC thresholds from ``global_config.affinity_flow`` when it is present.

        Any failure is logged and leaves the current thresholds untouched.
        """
        try:
            affinity_flow = getattr(global_config, "affinity_flow", None)
            if affinity_flow is not None:
                candidate = {
                    "high_match": getattr(affinity_flow, "high_match_interest_threshold", 0.8),
                    "reply": getattr(affinity_flow, "reply_action_interest_threshold", 0.4),
                    "non_reply": getattr(affinity_flow, "non_reply_action_interest_threshold", 0.2),
                }
                # Validate first, assign afterwards, so a bad value cannot leave a half-updated table
                self.thresholds.update(self._normalize_thresholds(candidate))

                self.stats["last_threshold_update"] = time.time()
                logger.info(f"加载AFC阈值: {self.thresholds}")
        except Exception as e:
            logger.warning(f"加载AFC阈值失败,使用默认值: {e}")

    async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float:
        """Return the focus energy of a chat stream.

        A cached value younger than ``cache_ttl`` seconds is returned as is;
        the cache is keyed by ``stream_id`` only, so ``messages`` and
        ``user_id`` are not consulted on a hit. Otherwise every registered
        calculator is run, the scores are combined as a weighted average,
        passed through ``_apply_threshold_adjustment`` and cached.

        A calculator that raises, or returns a non-numeric or non-finite
        score, is logged and left out of the average. When no calculator
        contributes, the raw energy defaults to 0.5.
        """
        # Function-local imports, matching how this module pulls in optional stdlib helpers
        import inspect
        import math

        start_time = time.time()

        self.stats["total_calculations"] += 1

        cached_entry = self.energy_cache.get(stream_id)
        if cached_entry is not None:
            cached_energy, cached_time = cached_entry
            if time.time() - cached_time < self.cache_ttl:
                self.stats["cache_hits"] += 1
                logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
                return cached_energy
        # Both an absent and an expired entry are misses
        self.stats["cache_misses"] += 1

        context: EnergyContext = {
            "stream_id": stream_id,
            "messages": messages,
            "user_id": user_id,
        }

        # (score, weight) per contributing calculator; a list rather than a
        # dict keyed by class name, so two calculators of one class don't collide
        contributions: list[tuple[float, float]] = []
        total_weight = 0.0

        for calculator in self.calculators:
            calculator_name = calculator.__class__.__name__
            try:
                score = calculator.calculate(context)
                # Calculators may be sync or async; await whatever awaitable comes back
                if inspect.isawaitable(score):
                    score = await score

                weight = float(calculator.get_weight())

                if not isinstance(score, int | float):
                    logger.warning(f"计算器 {calculator_name} 返回了非数值类型: {type(score)},跳过此组件")
                    continue
                if not math.isfinite(score):
                    # NaN/inf would otherwise poison the weighted average
                    logger.warning(f"计算器 {calculator_name} 返回了非有限数值: {score},跳过此组件")
                    continue

                contributions.append((float(score), weight))
                total_weight += weight

                logger.debug(f"{calculator_name} 能量: {score:.3f} (权重: {weight:.3f})")

            except Exception as e:
                logger.warning(f"计算 {calculator_name} 能量失败: {e}")

        # Weighted average over the calculators that contributed
        if total_weight > 0:
            total_energy = 0.0
            for score, weight in contributions:
                total_energy += score * (weight / total_weight)
        else:
            total_energy = 0.5

        final_energy = self._apply_threshold_adjustment(total_energy)

        self.energy_cache[stream_id] = (final_energy, time.time())
        self._cleanup_cache()

        # Running mean over completed computations only; cache hits return
        # early and must not dilute the average
        calculation_time = time.time() - start_time
        self._completed_calculations += 1
        completed = self._completed_calculations
        self.stats["average_calculation_time"] = (
            self.stats["average_calculation_time"] * (completed - 1) + calculation_time
        ) / completed

        logger.debug(
            f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)"
        )
        return final_energy

    def _apply_threshold_adjustment(self, energy: float) -> float:
        """Reshape a raw energy value around the thresholds and clamp it to [0.1, 1.0].

        At or above ``high_match`` the part above 0.7 is boosted by a 0.8
        power; between ``reply`` and ``high_match`` the value is kept; below
        ``reply`` it is compressed by a 1.2 power relative to 0.4.
        """
        # A negative base raised to a fractional power is a complex number in
        # Python, which min()/max() cannot order; treat negatives as zero
        energy = max(0.0, energy)

        high_threshold = self.thresholds["high_match"]
        reply_threshold = self.thresholds["reply"]

        if energy >= high_threshold:
            adjusted = 0.7 + max(0, energy - 0.7) ** 0.8
        elif energy >= reply_threshold:
            adjusted = energy
        else:
            adjusted = 0.4 * (energy / 0.4) ** 1.2

        return max(0.1, min(1.0, adjusted))

    def get_energy_level(self, energy: float) -> EnergyLevel:
        """Return the highest ``EnergyLevel`` whose value ``energy`` reaches, else ``VERY_LOW``."""
        descending = (
            EnergyLevel.VERY_HIGH,
            EnergyLevel.HIGH,
            EnergyLevel.NORMAL,
            EnergyLevel.LOW,
        )
        for level in descending:
            if energy >= level.value:
                return level
        return EnergyLevel.VERY_LOW

    def get_distribution_interval(self, energy: float) -> float:
        """Return the distribution interval in seconds for ``energy``.

        The base interval depends on the energy level (1, 3, 8, 15 or 30
        seconds from highest to lowest level), is multiplied by a random
        factor in [0.8, 1.2] so streams do not fire in lockstep, and is then
        clamped to ``dynamic_distribution_min_interval`` /
        ``dynamic_distribution_max_interval`` from ``global_config.chat``
        (defaults 1.0 and 60.0).
        """
        import random

        base_by_level = {
            EnergyLevel.VERY_HIGH: 1.0,
            EnergyLevel.HIGH: 3.0,
            EnergyLevel.NORMAL: 8.0,
            EnergyLevel.LOW: 15.0,
        }
        base_interval = base_by_level.get(self.get_energy_level(energy), 30.0)

        jitter = random.uniform(0.8, 1.2)
        final_interval = base_interval * jitter

        min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
        max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)

        return max(min_interval, min(max_interval, final_interval))

    def invalidate_cache(self, stream_id: str) -> None:
        """Drop the cached energy of ``stream_id``, if there is one."""
        if stream_id in self.energy_cache:
            del self.energy_cache[stream_id]
            logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")

    def _cleanup_cache(self) -> None:
        """Remove every cache entry older than ``cache_ttl`` seconds."""
        current_time = time.time()
        # Collect first; the dict must not change size while it is iterated
        expired_keys = [
            stream_id
            for stream_id, (_, timestamp) in self.energy_cache.items()
            if current_time - timestamp > self.cache_ttl
        ]

        for key in expired_keys:
            del self.energy_cache[key]

        if expired_keys:
            logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")

    def get_statistics(self) -> dict[str, Any]:
        """Return a snapshot of cache size, calculator names, thresholds and counters."""
        return {
            "cache_size": len(self.energy_cache),
            "calculators": [calc.__class__.__name__ for calc in self.calculators],
            # Copies, so callers cannot mutate the manager's state through the snapshot
            "thresholds": dict(self.thresholds),
            "performance_stats": self.stats.copy(),
        }

    def update_thresholds(self, new_thresholds: dict[str, float]) -> None:
        """Merge ``new_thresholds`` into the current ones and re-enforce their ordering.

        The merged table is validated before anything is stored, so a failing
        update leaves the current thresholds unchanged.
        """
        merged = {**self.thresholds, **new_thresholds}
        self.thresholds.update(self._normalize_thresholds(merged))

        self.stats["last_threshold_update"] = time.time()
        logger.info(f"更新AFC阈值: {self.thresholds}")

    def add_calculator(self, calculator: EnergyCalculator) -> None:
        """Append ``calculator`` to the set used by ``calculate_focus_energy``."""
        self.calculators.append(calculator)
        logger.info(f"添加能量计算器: {calculator.__class__.__name__}")

    def remove_calculator(self, calculator: EnergyCalculator) -> None:
        """Remove ``calculator`` if it is registered; otherwise do nothing."""
        if calculator in self.calculators:
            self.calculators.remove(calculator)
            logger.info(f"移除能量计算器: {calculator.__class__.__name__}")

    def clear_cache(self) -> None:
        """Drop every cached energy value."""
        self.energy_cache.clear()
        logger.info("清空能量缓存")

    def get_cache_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any lookup was counted."""
        hits = self.stats.get("cache_hits", 0)
        total_requests = hits + self.stats.get("cache_misses", 0)
        if total_requests == 0:
            return 0.0
        return hits / total_requests
|
||
|
||
|
||
# 全局能量管理器实例
|
||
# Shared module-level EnergyManager instance, constructed when this module is imported
# (its __init__ also loads the thresholds from global_config).
energy_manager = EnergyManager()
|