Files
Mofox-Core/src/chat/energy_system/energy_manager.py
Windpicker-owo 1381e7f700 refactor(scoring): 重构评分系统使用统一API替代插件内部实现
将原有的兴趣度评分系统重构为使用统一的评分API,移除了插件内部实现并更新了所有相关引用。主要变更包括:
- 替换 chatter_interest_scoring_system 为 scoring_api
- 移除 interest_scoring.py 文件,统一使用 src/plugin_system/apis/scoring_api
- 更新关系追踪器以使用统一API,保持向后兼容性
- 简化了多个模块中的关系分获取逻辑
2025-10-07 14:41:49 +08:00

489 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
重构后的 focus_energy 管理系统
提供稳定、高效的聊天流能量计算和管理功能
"""
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypedDict
from src.common.logger import get_logger
from src.config.config import global_config
# Module-wide logger; every class in this file logs under the "energy_system" name.
logger = get_logger("energy_system")
class EnergyLevel(Enum):
    """Discrete energy tiers.

    Each member's value is the numeric threshold that
    ``EnergyManager.get_energy_level`` compares an energy score against.
    """

    VERY_LOW = 0.1  # very low
    LOW = 0.3  # low
    NORMAL = 0.5  # normal
    HIGH = 0.7  # high
    VERY_HIGH = 0.9  # very high
@dataclass
class EnergyComponent:
    """A named, weighted energy value that fades linearly as it ages."""

    name: str
    value: float
    weight: float = 1.0
    decay_rate: float = 0.05  # fraction of the value lost per 24 hours
    last_updated: float = field(default_factory=time.time)  # epoch seconds of the last update

    def get_current_value(self) -> float:
        """Return ``value`` scaled by the time-decay factor.

        The factor starts at 1.0, drops by ``decay_rate`` for every 24 hours
        elapsed since ``last_updated`` and never falls below 0.1.
        """
        # Clamp the age at zero: a ``last_updated`` that lies in the future
        # (clock adjustment, restored data) would otherwise push the factor
        # above 1.0 and amplify the value instead of decaying it.
        age = max(0.0, time.time() - self.last_updated)
        decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600)))
        return self.value * decay_factor

    def update_value(self, new_value: float) -> None:
        """Store ``new_value`` clamped to [0.0, 1.0] and reset the decay clock."""
        self.value = max(0.0, min(1.0, new_value))
        self.last_updated = time.time()
class EnergyContext(TypedDict):
"""能量计算上下文"""
stream_id: str
messages: list[Any]
user_id: str | None
class EnergyResult(TypedDict):
    """Shape of a full energy calculation result."""

    # Final energy score.
    energy: float
    # Tier the score falls into.
    level: EnergyLevel
    # Distribution interval associated with the score.
    distribution_interval: float
    # Per-component scores keyed by name.
    component_scores: dict[str, float]
    # Whether the result came from a cache.
    cached: bool
class EnergyCalculator(ABC):
    """Abstract interface implemented by every energy calculator.

    ``EnergyManager`` combines the results as a weighted average, using
    ``get_weight`` as the relative weight of ``calculate``'s score.
    """

    @abstractmethod
    def calculate(self, context: dict[str, Any]) -> float:
        """Return this calculator's energy score for ``context``.

        An implementation may also be a coroutine function;
        ``EnergyManager.calculate_focus_energy`` awaits it in that case.
        """

    @abstractmethod
    def get_weight(self) -> float:
        """Return the relative weight of this calculator's score."""
class InterestEnergyCalculator(EnergyCalculator):
    """Scores a stream by the mean ``interest_value`` of its messages."""

    def calculate(self, context: dict[str, Any]) -> float:
        """Return the average interest of the messages in ``context``.

        Only messages whose ``interest_value`` attribute is a number inside
        [0.0, 1.0] are counted. When there are no messages, or none of them
        qualifies, the fallback score 0.3 is returned.
        """
        fallback = 0.3
        messages = context.get("messages", [])
        if not messages:
            return fallback

        interest_sum = 0.0
        counted = 0
        for message in messages:
            candidate = getattr(message, "interest_value", None)
            # Skip anything that is not a number ...
            if not isinstance(candidate, int | float):
                continue
            # ... or that lies outside the valid range.
            if not (0.0 <= candidate <= 1.0):
                continue
            interest_sum += candidate
            counted += 1

        if counted == 0:
            return fallback

        avg_interest = interest_sum / counted
        logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {counted} 条消息)")
        return avg_interest

    def get_weight(self) -> float:
        """Return the fixed weight of the interest component (0.5)."""
        return 0.5
class ActivityEnergyCalculator(EnergyCalculator):
    """Scores a stream by the actions attached to its messages."""

    def __init__(self, action_weights: dict[str, float] | None = None) -> None:
        """Create the calculator.

        Args:
            action_weights: Optional mapping of action name to weight. The
                ``"other"`` entry is the fallback for unknown actions. When
                omitted, the built-in table is used.
        """
        if action_weights is None:
            self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
        else:
            # Copy so later changes to the caller's dict do not leak in.
            self.action_weights = dict(action_weights)
            self.action_weights.setdefault("other", 0.1)

    def calculate(self, context: dict[str, Any]) -> float:
        """Return the activity score of the messages in ``context``.

        The weights of every action found in each message's ``actions`` list
        are summed, then divided by ``len(messages)`` times the largest
        configured weight. The result is clamped to [0.0, 1.0]. Returns 0.2
        when there are no messages or the normalisation ceiling is not
        positive.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.2

        fallback_weight = self.action_weights.get("other", 0.1)
        total_score = 0.0
        for msg in messages:
            actions = getattr(msg, "actions", [])
            if isinstance(actions, list):
                for action in actions:
                    total_score += self.action_weights.get(action, fallback_weight)

        # Normalisation ceiling: one top-weighted action per message. It is
        # derived from the table instead of a literal 0.4 so it stays correct
        # when the weights are customised.
        max_possible_score = len(messages) * max(self.action_weights.values())
        if max_possible_score <= 0:
            return 0.2

        activity_score = max(0.0, min(1.0, total_score / max_possible_score))
        logger.debug(f"活跃度分数: {activity_score:.3f}")
        return activity_score

    def get_weight(self) -> float:
        """Return the fixed weight of the activity component (0.3)."""
        return 0.3
class RecencyEnergyCalculator(EnergyCalculator):
    """Scores a stream by how recently its newest message arrived."""

    def calculate(self, context: dict[str, Any]) -> float:
        """Return a stepped recency score for the messages in ``context``.

        The newest ``time`` attribute among the messages is compared with the
        current time:

        * under 1 hour   -> 1.0
        * under 6 hours  -> 0.8
        * under 24 hours -> 0.5
        * under 7 days   -> 0.3
        * otherwise      -> 0.1

        Returns 0.1 when there are no messages or no positive timestamp.
        """
        messages = context.get("messages", [])
        if not messages:
            return 0.1

        # Seed with 0.0 so missing, falsy or non-positive timestamps leave the
        # result at the "nothing found" sentinel.
        stamps = (getattr(message, "time", None) for message in messages)
        latest_time = max([0.0, *(stamp for stamp in stamps if stamp)])
        if latest_time == 0.0:
            return 0.1

        age = time.time() - latest_time

        # (exclusive upper age bound in seconds, score), checked in order.
        hour = 3600
        tiers = (
            (hour, 1.0),
            (6 * hour, 0.8),
            (24 * hour, 0.5),
            (7 * 24 * hour, 0.3),
        )
        recency_score = next((score for limit, score in tiers if age < limit), 0.1)

        logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)")
        return recency_score

    def get_weight(self) -> float:
        """Return the fixed weight of the recency component (0.2)."""
        return 0.2
class RelationshipEnergyCalculator(EnergyCalculator):
    """Scores a stream by the relationship score of the given user."""

    async def calculate(self, context: dict[str, Any]) -> float:
        """Return the relationship score for ``context["user_id"]``.

        The score is fetched from the unified scoring API. The default score
        0.3 is returned when no user id is supplied or when the lookup (or
        logging of its result) raises.
        """
        default_score = 0.3

        user_id = context.get("user_id")
        if not user_id:
            return default_score

        try:
            # Imported at call time, inside the try, so an import failure
            # also falls back to the default score.
            from src.plugin_system.apis.scoring_api import scoring_api

            score = await scoring_api.get_user_relationship_score(user_id)
            logger.debug(f"使用统一评分API计算关系分: {score:.3f}")
        except Exception as exc:
            logger.warning(f"关系分计算失败,使用默认值: {exc}")
            return default_score
        return score

    def get_weight(self) -> float:
        """Return the fixed weight of the relationship component (0.1)."""
        return 0.1
class EnergyManager:
    """Central entry point that runs every energy calculator for a chat stream.

    Calculator scores are combined into a weighted average, reshaped by
    ``_apply_threshold_adjustment`` and cached per stream for ``cache_ttl``
    seconds.
    """

    def __init__(self) -> None:
        """Set up the default calculators, the cache, thresholds and statistics."""
        self.calculators: list[EnergyCalculator] = [
            InterestEnergyCalculator(),
            ActivityEnergyCalculator(),
            RecencyEnergyCalculator(),
            RelationshipEnergyCalculator(),
        ]

        # stream_id -> (energy, epoch seconds at which it was computed)
        self.energy_cache: dict[str, tuple[float, float]] = {}
        self.cache_ttl: int = 60  # seconds

        # AFC thresholds; replaced by configured values below when available.
        self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}

        self.stats: dict[str, int | float | str] = {
            "total_calculations": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "average_calculation_time": 0.0,
            "last_threshold_update": time.time(),
        }

        self._load_thresholds_from_config()
        logger.info("能量管理器初始化完成")

    def _normalize_thresholds(self) -> None:
        """Enforce ``non_reply < reply < high_match`` with a margin of 0.1."""
        # "reply" must be corrected first: "high_match" is derived from it, so
        # the opposite order can leave high_match below the corrected reply.
        self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
        self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)

    def _load_thresholds_from_config(self) -> None:
        """Read the AFC thresholds from ``global_config.affinity_flow``.

        Missing attributes fall back to the built-in defaults. Any failure is
        logged and leaves the current thresholds in place.
        """
        try:
            affinity_flow = getattr(global_config, "affinity_flow", None)
            if affinity_flow is not None:
                self.thresholds["high_match"] = getattr(affinity_flow, "high_match_interest_threshold", 0.8)
                self.thresholds["reply"] = getattr(affinity_flow, "reply_action_interest_threshold", 0.4)
                self.thresholds["non_reply"] = getattr(affinity_flow, "non_reply_action_interest_threshold", 0.2)

            self._normalize_thresholds()
            self.stats["last_threshold_update"] = time.time()
            logger.info(f"加载AFC阈值: {self.thresholds}")
        except Exception as e:
            logger.warning(f"加载AFC阈值失败使用默认值: {e}")

    async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float:
        """Compute (or fetch from cache) the focus energy of a chat stream.

        Args:
            stream_id: Identifier of the chat stream; also the cache key.
            messages: Messages passed to every calculator.
            user_id: Optional user passed to the calculators.

        Returns:
            The adjusted energy, within [0.1, 1.0] for freshly computed values.
        """
        import inspect

        start_time = time.time()
        self.stats["total_calculations"] += 1

        cached = self.energy_cache.get(stream_id)
        if cached is not None:
            cached_energy, cached_time = cached
            if start_time - cached_time < self.cache_ttl:
                self.stats["cache_hits"] += 1
                logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
                return cached_energy
        # Every request not served from the cache is a miss, whether the entry
        # was absent or merely expired.
        self.stats["cache_misses"] += 1

        context: EnergyContext = {
            "stream_id": stream_id,
            "messages": messages,
            "user_id": user_id,
        }

        # (score, weight) of every calculator that produced a usable result.
        weighted_scores: list[tuple[float, float]] = []
        for calculator in self.calculators:
            calculator_name = calculator.__class__.__name__
            try:
                # Calculators may be synchronous or asynchronous; await
                # whatever turns out to be awaitable.
                score = calculator.calculate(context)
                if inspect.isawaitable(score):
                    score = await score

                weight = float(calculator.get_weight())

                if not isinstance(score, int | float):
                    logger.warning(f"计算器 {calculator_name} 返回了非数值类型: {type(score)},跳过此组件")
                    continue

                logger.debug(f"{calculator_name} 能量: {score:.3f} (权重: {weight:.3f})")
                weighted_scores.append((float(score), weight))
            except Exception as e:
                # A failing calculator is dropped from the average rather than
                # aborting the whole computation.
                logger.warning(f"计算 {calculator_name} 能量失败: {e}")

        # Weighted average over the calculators that succeeded.
        total_weight = sum(weight for _, weight in weighted_scores)
        if total_weight > 0:
            total_energy = sum(score * (weight / total_weight) for score, weight in weighted_scores)
        else:
            total_energy = 0.5

        final_energy = self._apply_threshold_adjustment(total_energy)

        self.energy_cache[stream_id] = (final_energy, time.time())
        self._cleanup_cache()

        # Running mean over real computations only. "cache_misses" counts
        # exactly those, whereas "total_calculations" also counts cache hits.
        calculation_time = time.time() - start_time
        computed_count = max(1, self.stats["cache_misses"])
        previous_average = self.stats["average_calculation_time"]
        self.stats["average_calculation_time"] = (
            previous_average + (calculation_time - previous_average) / computed_count
        )

        logger.debug(
            f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)"
        )
        return final_energy

    def _apply_threshold_adjustment(self, energy: float) -> float:
        """Reshape a raw energy value around the configured thresholds.

        * ``energy >= high_match``: boosted with a power curve.
        * ``reply <= energy < high_match``: returned unchanged.
        * ``energy < reply``: compressed with a power curve.

        The result is clamped to [0.1, 1.0].
        """
        high_threshold = self.thresholds["high_match"]
        reply_threshold = self.thresholds["reply"]

        # A negative base raised to a fractional exponent yields a complex
        # number in Python, which would make the final min()/max() raise
        # TypeError. Treat negative input as zero energy.
        energy = max(0.0, energy)

        # NOTE(review): the curve constants 0.7 and 0.4 are fixed and do not
        # follow the configured thresholds.
        if energy >= high_threshold:
            adjusted = 0.7 + max(0, energy - 0.7) ** 0.8
        elif energy >= reply_threshold:
            adjusted = energy
        else:
            adjusted = 0.4 * (energy / 0.4) ** 1.2

        return max(0.1, min(1.0, adjusted))

    def get_energy_level(self, energy: float) -> EnergyLevel:
        """Return the highest ``EnergyLevel`` whose value ``energy`` reaches."""
        descending = (EnergyLevel.VERY_HIGH, EnergyLevel.HIGH, EnergyLevel.NORMAL, EnergyLevel.LOW)
        for level in descending:
            if energy >= level.value:
                return level
        return EnergyLevel.VERY_LOW

    def get_distribution_interval(self, energy: float) -> float:
        """Return the distribution interval (seconds) for ``energy``.

        A base interval chosen by energy level is multiplied by a random
        jitter in [0.8, 1.2], then clamped to the range configured in
        ``global_config.chat`` (defaults 1.0 to 60.0).
        """
        import random

        base_intervals = {
            EnergyLevel.VERY_HIGH: 1.0,
            EnergyLevel.HIGH: 3.0,
            EnergyLevel.NORMAL: 8.0,
            EnergyLevel.LOW: 15.0,
        }
        base_interval = base_intervals.get(self.get_energy_level(energy), 30.0)

        # Random jitter keeps streams from all firing in lockstep.
        final_interval = base_interval * random.uniform(0.8, 1.2)

        min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
        max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
        return max(min_interval, min(max_interval, final_interval))

    def invalidate_cache(self, stream_id: str) -> None:
        """Drop the cached energy of ``stream_id`` if there is one."""
        if self.energy_cache.pop(stream_id, None) is not None:
            logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")

    def _cleanup_cache(self) -> None:
        """Remove every cache entry older than ``cache_ttl`` seconds."""
        now = time.time()
        expired_keys = [
            stream_id
            for stream_id, (_, timestamp) in self.energy_cache.items()
            if now - timestamp > self.cache_ttl
        ]
        for key in expired_keys:
            del self.energy_cache[key]

        if expired_keys:
            logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")

    def get_statistics(self) -> dict[str, Any]:
        """Return a snapshot of cache size, calculators, thresholds and stats."""
        return {
            "cache_size": len(self.energy_cache),
            "calculators": [calc.__class__.__name__ for calc in self.calculators],
            # Copies, so callers cannot mutate the manager's state by accident.
            "thresholds": self.thresholds.copy(),
            "performance_stats": self.stats.copy(),
        }

    def update_thresholds(self, new_thresholds: dict[str, float]) -> None:
        """Merge ``new_thresholds`` into the current ones and re-normalise."""
        self.thresholds.update(new_thresholds)
        self._normalize_thresholds()
        self.stats["last_threshold_update"] = time.time()
        logger.info(f"更新AFC阈值: {self.thresholds}")

    def add_calculator(self, calculator: EnergyCalculator) -> None:
        """Register an additional calculator."""
        self.calculators.append(calculator)
        logger.info(f"添加能量计算器: {calculator.__class__.__name__}")

    def remove_calculator(self, calculator: EnergyCalculator) -> None:
        """Unregister ``calculator``; does nothing when it is not registered."""
        if calculator in self.calculators:
            self.calculators.remove(calculator)
            logger.info(f"移除能量计算器: {calculator.__class__.__name__}")

    def clear_cache(self) -> None:
        """Drop every cached energy value."""
        self.energy_cache.clear()
        logger.info("清空能量缓存")

    def get_cache_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any request."""
        hits = self.stats.get("cache_hits", 0)
        total_requests = hits + self.stats.get("cache_misses", 0)
        if total_requests == 0:
            return 0.0
        return hits / total_requests
# Global energy manager instance, created once when this module is imported.
energy_manager = EnergyManager()