style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
e7aaafde2f
commit
00ba07e0e1
@@ -12,7 +12,7 @@ from .energy_manager import (
|
||||
ActivityEnergyCalculator,
|
||||
RecencyEnergyCalculator,
|
||||
RelationshipEnergyCalculator,
|
||||
energy_manager
|
||||
energy_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -24,5 +24,5 @@ __all__ = [
|
||||
"ActivityEnergyCalculator",
|
||||
"RecencyEnergyCalculator",
|
||||
"RelationshipEnergyCalculator",
|
||||
"energy_manager"
|
||||
]
|
||||
"energy_manager",
|
||||
]
|
||||
|
||||
@@ -17,16 +17,18 @@ logger = get_logger("energy_system")
|
||||
|
||||
class EnergyLevel(Enum):
|
||||
"""能量等级"""
|
||||
|
||||
VERY_LOW = 0.1 # 非常低
|
||||
LOW = 0.3 # 低
|
||||
NORMAL = 0.5 # 正常
|
||||
HIGH = 0.7 # 高
|
||||
VERY_HIGH = 0.9 # 非常高
|
||||
LOW = 0.3 # 低
|
||||
NORMAL = 0.5 # 正常
|
||||
HIGH = 0.7 # 高
|
||||
VERY_HIGH = 0.9 # 非常高
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnergyComponent:
|
||||
"""能量组件"""
|
||||
|
||||
name: str
|
||||
value: float
|
||||
weight: float = 1.0
|
||||
@@ -47,6 +49,7 @@ class EnergyComponent:
|
||||
|
||||
class EnergyContext(TypedDict):
|
||||
"""能量计算上下文"""
|
||||
|
||||
stream_id: str
|
||||
messages: List[Any]
|
||||
user_id: Optional[str]
|
||||
@@ -54,6 +57,7 @@ class EnergyContext(TypedDict):
|
||||
|
||||
class EnergyResult(TypedDict):
|
||||
"""能量计算结果"""
|
||||
|
||||
energy: float
|
||||
level: EnergyLevel
|
||||
distribution_interval: float
|
||||
@@ -114,12 +118,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
||||
"""活跃度能量计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.action_weights = {
|
||||
"reply": 0.4,
|
||||
"react": 0.3,
|
||||
"mention": 0.2,
|
||||
"other": 0.1
|
||||
}
|
||||
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于活跃度计算能量"""
|
||||
@@ -188,7 +187,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
|
||||
else:
|
||||
recency_score = 0.1
|
||||
|
||||
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
|
||||
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age / 3600:.1f}小时)")
|
||||
return recency_score
|
||||
|
||||
def get_weight(self) -> float:
|
||||
@@ -236,11 +235,7 @@ class EnergyManager:
|
||||
self.cache_ttl: int = 60 # 1分钟缓存
|
||||
|
||||
# AFC阈值配置
|
||||
self.thresholds: Dict[str, float] = {
|
||||
"high_match": 0.8,
|
||||
"reply": 0.4,
|
||||
"non_reply": 0.2
|
||||
}
|
||||
self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str]] = {
|
||||
@@ -260,9 +255,13 @@ class EnergyManager:
|
||||
"""从配置加载AFC阈值"""
|
||||
try:
|
||||
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
|
||||
self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
self.thresholds["high_match"] = getattr(
|
||||
global_config.affinity_flow, "high_match_interest_threshold", 0.8
|
||||
)
|
||||
self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
self.thresholds["non_reply"] = getattr(
|
||||
global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2
|
||||
)
|
||||
|
||||
# 确保阈值关系合理
|
||||
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||
@@ -306,6 +305,7 @@ class EnergyManager:
|
||||
# 支持同步和异步计算器
|
||||
if callable(calculator.calculate):
|
||||
import inspect
|
||||
|
||||
if inspect.iscoroutinefunction(calculator.calculate):
|
||||
score = await calculator.calculate(context)
|
||||
else:
|
||||
@@ -347,11 +347,12 @@ class EnergyManager:
|
||||
calculation_time = time.time() - start_time
|
||||
total_calculations = self.stats["total_calculations"]
|
||||
self.stats["average_calculation_time"] = (
|
||||
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
||||
/ total_calculations
|
||||
)
|
||||
self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time
|
||||
) / total_calculations
|
||||
|
||||
logger.debug(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
|
||||
logger.debug(
|
||||
f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)"
|
||||
)
|
||||
return final_energy
|
||||
|
||||
def _apply_threshold_adjustment(self, energy: float) -> float:
|
||||
@@ -405,6 +406,7 @@ class EnergyManager:
|
||||
|
||||
# 添加随机扰动避免同步
|
||||
import random
|
||||
|
||||
jitter = random.uniform(0.8, 1.2)
|
||||
final_interval = base_interval * jitter
|
||||
|
||||
@@ -424,7 +426,8 @@ class EnergyManager:
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
|
||||
stream_id
|
||||
for stream_id, (_, timestamp) in self.energy_cache.items()
|
||||
if current_time - timestamp > self.cache_ttl
|
||||
]
|
||||
|
||||
@@ -479,4 +482,4 @@ class EnergyManager:
|
||||
|
||||
|
||||
# 全局能量管理器实例
|
||||
energy_manager = EnergyManager()
|
||||
energy_manager = EnergyManager()
|
||||
|
||||
Reference in New Issue
Block a user