refactor(chat): 重构消息管理器以使用集中式上下文管理和能量系统
- 将流上下文管理从MessageManager迁移到专门的ContextManager - 使用统一的能量系统计算focus_energy和分发间隔 - 重构ChatStream的消息数据转换逻辑,支持更完整的数据字段 - 更新数据库模型,移除interest_degree字段,统一使用interest_value - 集成新的兴趣度管理系统替代原有的评分系统 - 添加消息存储的interest_value修复功能
This commit is contained in:
28
src/chat/energy_system/__init__.py
Normal file
28
src/chat/energy_system/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
能量系统模块
|
||||
提供稳定、高效的聊天流能量计算和管理功能
|
||||
"""
|
||||
|
||||
from .energy_manager import (
|
||||
EnergyManager,
|
||||
EnergyLevel,
|
||||
EnergyComponent,
|
||||
EnergyCalculator,
|
||||
InterestEnergyCalculator,
|
||||
ActivityEnergyCalculator,
|
||||
RecencyEnergyCalculator,
|
||||
RelationshipEnergyCalculator,
|
||||
energy_manager
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EnergyManager",
|
||||
"EnergyLevel",
|
||||
"EnergyComponent",
|
||||
"EnergyCalculator",
|
||||
"InterestEnergyCalculator",
|
||||
"ActivityEnergyCalculator",
|
||||
"RecencyEnergyCalculator",
|
||||
"RelationshipEnergyCalculator",
|
||||
"energy_manager"
|
||||
]
|
||||
480
src/chat/energy_system/energy_manager.py
Normal file
480
src/chat/energy_system/energy_manager.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
重构后的 focus_energy 管理系统
|
||||
提供稳定、高效的聊天流能量计算和管理功能
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("energy_system")
|
||||
|
||||
|
||||
class EnergyLevel(Enum):
|
||||
"""能量等级"""
|
||||
VERY_LOW = 0.1 # 非常低
|
||||
LOW = 0.3 # 低
|
||||
NORMAL = 0.5 # 正常
|
||||
HIGH = 0.7 # 高
|
||||
VERY_HIGH = 0.9 # 非常高
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnergyComponent:
|
||||
"""能量组件"""
|
||||
name: str
|
||||
value: float
|
||||
weight: float = 1.0
|
||||
decay_rate: float = 0.05 # 衰减率
|
||||
last_updated: float = field(default_factory=time.time)
|
||||
|
||||
def get_current_value(self) -> float:
|
||||
"""获取当前值(考虑时间衰减)"""
|
||||
age = time.time() - self.last_updated
|
||||
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
|
||||
return self.value * decay_factor
|
||||
|
||||
def update_value(self, new_value: float) -> None:
|
||||
"""更新值"""
|
||||
self.value = max(0.0, min(1.0, new_value))
|
||||
self.last_updated = time.time()
|
||||
|
||||
|
||||
class EnergyContext(TypedDict):
|
||||
"""能量计算上下文"""
|
||||
stream_id: str
|
||||
messages: List[Any]
|
||||
user_id: Optional[str]
|
||||
|
||||
|
||||
class EnergyResult(TypedDict):
|
||||
"""能量计算结果"""
|
||||
energy: float
|
||||
level: EnergyLevel
|
||||
distribution_interval: float
|
||||
component_scores: Dict[str, float]
|
||||
cached: bool
|
||||
|
||||
|
||||
class EnergyCalculator(ABC):
|
||||
"""能量计算器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""计算能量值"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_weight(self) -> float:
|
||||
"""获取权重"""
|
||||
pass
|
||||
|
||||
|
||||
class InterestEnergyCalculator(EnergyCalculator):
|
||||
"""兴趣度能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于消息兴趣度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
return 0.3
|
||||
|
||||
# 计算平均兴趣度
|
||||
total_interest = 0.0
|
||||
valid_messages = 0
|
||||
|
||||
for msg in messages:
|
||||
interest_value = getattr(msg, "interest_value", None)
|
||||
if interest_value is not None:
|
||||
try:
|
||||
interest_float = float(interest_value)
|
||||
if 0.0 <= interest_float <= 1.0:
|
||||
total_interest += interest_float
|
||||
valid_messages += 1
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
if valid_messages > 0:
|
||||
avg_interest = total_interest / valid_messages
|
||||
logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)")
|
||||
return avg_interest
|
||||
else:
|
||||
return 0.3
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.5
|
||||
|
||||
|
||||
class ActivityEnergyCalculator(EnergyCalculator):
|
||||
"""活跃度能量计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.action_weights = {
|
||||
"reply": 0.4,
|
||||
"react": 0.3,
|
||||
"mention": 0.2,
|
||||
"other": 0.1
|
||||
}
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于活跃度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
return 0.2
|
||||
|
||||
total_score = 0.0
|
||||
max_possible_score = len(messages) * 0.4 # 最高可能分数
|
||||
|
||||
for msg in messages:
|
||||
actions = getattr(msg, "actions", [])
|
||||
if isinstance(actions, list) and actions:
|
||||
for action in actions:
|
||||
weight = self.action_weights.get(action, self.action_weights["other"])
|
||||
total_score += weight
|
||||
|
||||
if max_possible_score > 0:
|
||||
activity_score = min(1.0, total_score / max_possible_score)
|
||||
logger.debug(f"活跃度分数: {activity_score:.3f}")
|
||||
return activity_score
|
||||
else:
|
||||
return 0.2
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.3
|
||||
|
||||
|
||||
class RecencyEnergyCalculator(EnergyCalculator):
|
||||
"""最近性能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于最近性计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
return 0.1
|
||||
|
||||
# 获取最新消息时间
|
||||
latest_time = 0.0
|
||||
for msg in messages:
|
||||
msg_time = getattr(msg, "time", None)
|
||||
if msg_time and msg_time > latest_time:
|
||||
latest_time = msg_time
|
||||
|
||||
if latest_time == 0.0:
|
||||
return 0.1
|
||||
|
||||
# 计算时间衰减
|
||||
current_time = time.time()
|
||||
age = current_time - latest_time
|
||||
|
||||
# 时间衰减策略:
|
||||
# 1小时内:1.0
|
||||
# 1-6小时:0.8
|
||||
# 6-24小时:0.5
|
||||
# 1-7天:0.3
|
||||
# 7天以上:0.1
|
||||
if age < 3600: # 1小时内
|
||||
recency_score = 1.0
|
||||
elif age < 6 * 3600: # 6小时内
|
||||
recency_score = 0.8
|
||||
elif age < 24 * 3600: # 24小时内
|
||||
recency_score = 0.5
|
||||
elif age < 7 * 24 * 3600: # 7天内
|
||||
recency_score = 0.3
|
||||
else:
|
||||
recency_score = 0.1
|
||||
|
||||
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
|
||||
return recency_score
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.2
|
||||
|
||||
|
||||
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
"""关系能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于关系计算能量"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
return 0.3
|
||||
|
||||
try:
|
||||
# 使用新的兴趣度管理系统获取用户关系分
|
||||
from src.chat.interest_system import interest_manager
|
||||
|
||||
# 获取用户交互历史作为关系分的基础
|
||||
interaction_calc = interest_manager.calculators.get(
|
||||
interest_manager.InterestSourceType.USER_INTERACTION
|
||||
)
|
||||
if interaction_calc:
|
||||
relationship_score = interaction_calc.calculate({"user_id": user_id})
|
||||
logger.debug(f"用户关系分数: {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
else:
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
except Exception:
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.1
|
||||
|
||||
|
||||
class EnergyManager:
|
||||
"""能量管理器 - 统一管理所有能量计算"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calculators: List[EnergyCalculator] = [
|
||||
InterestEnergyCalculator(),
|
||||
ActivityEnergyCalculator(),
|
||||
RecencyEnergyCalculator(),
|
||||
RelationshipEnergyCalculator(),
|
||||
]
|
||||
|
||||
# 能量缓存
|
||||
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||
self.cache_ttl: int = 60 # 1分钟缓存
|
||||
|
||||
# AFC阈值配置
|
||||
self.thresholds: Dict[str, float] = {
|
||||
"high_match": 0.8,
|
||||
"reply": 0.4,
|
||||
"non_reply": 0.2
|
||||
}
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str]] = {
|
||||
"total_calculations": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"average_calculation_time": 0.0,
|
||||
"last_threshold_update": time.time(),
|
||||
}
|
||||
|
||||
# 从配置加载阈值
|
||||
self._load_thresholds_from_config()
|
||||
|
||||
logger.info("能量管理器初始化完成")
|
||||
|
||||
def _load_thresholds_from_config(self) -> None:
|
||||
"""从配置加载AFC阈值"""
|
||||
try:
|
||||
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
|
||||
self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
|
||||
# 确保阈值关系合理
|
||||
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
|
||||
|
||||
self.stats["last_threshold_update"] = time.time()
|
||||
logger.info(f"加载AFC阈值: {self.thresholds}")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载AFC阈值失败,使用默认值: {e}")
|
||||
|
||||
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||
"""计算聊天流的focus_energy"""
|
||||
start_time = time.time()
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_calculations"] += 1
|
||||
|
||||
# 检查缓存
|
||||
if stream_id in self.energy_cache:
|
||||
cached_energy, cached_time = self.energy_cache[stream_id]
|
||||
if time.time() - cached_time < self.cache_ttl:
|
||||
self.stats["cache_hits"] += 1
|
||||
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
|
||||
return cached_energy
|
||||
else:
|
||||
self.stats["cache_misses"] += 1
|
||||
|
||||
# 构建计算上下文
|
||||
context: EnergyContext = {
|
||||
"stream_id": stream_id,
|
||||
"messages": messages,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# 计算各组件能量
|
||||
component_scores: Dict[str, float] = {}
|
||||
total_weight = 0.0
|
||||
|
||||
for calculator in self.calculators:
|
||||
try:
|
||||
score = calculator.calculate(context)
|
||||
weight = calculator.get_weight()
|
||||
|
||||
component_scores[calculator.__class__.__name__] = score
|
||||
total_weight += weight
|
||||
|
||||
logger.debug(f"{calculator.__class__.__name__} 能量: {score:.3f} (权重: {weight:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算 {calculator.__class__.__name__} 能量失败: {e}")
|
||||
|
||||
# 加权计算总能量
|
||||
if total_weight > 0:
|
||||
total_energy = 0.0
|
||||
for calculator in self.calculators:
|
||||
if calculator.__class__.__name__ in component_scores:
|
||||
score = component_scores[calculator.__class__.__name__]
|
||||
weight = calculator.get_weight()
|
||||
total_energy += score * (weight / total_weight)
|
||||
else:
|
||||
total_energy = 0.5
|
||||
|
||||
# 应用阈值调整和变换
|
||||
final_energy = self._apply_threshold_adjustment(total_energy)
|
||||
|
||||
# 缓存结果
|
||||
self.energy_cache[stream_id] = (final_energy, time.time())
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_cache()
|
||||
|
||||
# 更新平均计算时间
|
||||
calculation_time = time.time() - start_time
|
||||
total_calculations = self.stats["total_calculations"]
|
||||
self.stats["average_calculation_time"] = (
|
||||
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
||||
/ total_calculations
|
||||
)
|
||||
|
||||
logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
|
||||
return final_energy
|
||||
|
||||
def _apply_threshold_adjustment(self, energy: float) -> float:
|
||||
"""应用阈值调整和变换"""
|
||||
# 获取参考阈值
|
||||
high_threshold = self.thresholds["high_match"]
|
||||
reply_threshold = self.thresholds["reply"]
|
||||
|
||||
# 计算与阈值的相对位置
|
||||
if energy >= high_threshold:
|
||||
# 高能量区域:指数增强
|
||||
adjusted = 0.7 + (energy - 0.7) ** 0.8
|
||||
elif energy >= reply_threshold:
|
||||
# 中等能量区域:线性保持
|
||||
adjusted = energy
|
||||
else:
|
||||
# 低能量区域:对数压缩
|
||||
adjusted = 0.4 * (energy / 0.4) ** 1.2
|
||||
|
||||
# 确保在合理范围内
|
||||
return max(0.1, min(1.0, adjusted))
|
||||
|
||||
def get_energy_level(self, energy: float) -> EnergyLevel:
|
||||
"""获取能量等级"""
|
||||
if energy >= EnergyLevel.VERY_HIGH.value:
|
||||
return EnergyLevel.VERY_HIGH
|
||||
elif energy >= EnergyLevel.HIGH.value:
|
||||
return EnergyLevel.HIGH
|
||||
elif energy >= EnergyLevel.NORMAL.value:
|
||||
return EnergyLevel.NORMAL
|
||||
elif energy >= EnergyLevel.LOW.value:
|
||||
return EnergyLevel.LOW
|
||||
else:
|
||||
return EnergyLevel.VERY_LOW
|
||||
|
||||
def get_distribution_interval(self, energy: float) -> float:
|
||||
"""基于能量等级获取分发周期"""
|
||||
energy_level = self.get_energy_level(energy)
|
||||
|
||||
# 根据能量等级确定基础分发周期
|
||||
if energy_level == EnergyLevel.VERY_HIGH:
|
||||
base_interval = 1.0 # 1秒
|
||||
elif energy_level == EnergyLevel.HIGH:
|
||||
base_interval = 3.0 # 3秒
|
||||
elif energy_level == EnergyLevel.NORMAL:
|
||||
base_interval = 8.0 # 8秒
|
||||
elif energy_level == EnergyLevel.LOW:
|
||||
base_interval = 15.0 # 15秒
|
||||
else:
|
||||
base_interval = 30.0 # 30秒
|
||||
|
||||
# 添加随机扰动避免同步
|
||||
import random
|
||||
jitter = random.uniform(0.8, 1.2)
|
||||
final_interval = base_interval * jitter
|
||||
|
||||
# 确保在配置范围内
|
||||
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
|
||||
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
|
||||
|
||||
return max(min_interval, min(max_interval, final_interval))
|
||||
|
||||
def invalidate_cache(self, stream_id: str) -> None:
|
||||
"""失效指定流的缓存"""
|
||||
if stream_id in self.energy_cache:
|
||||
del self.energy_cache[stream_id]
|
||||
logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")
|
||||
|
||||
def _cleanup_cache(self) -> None:
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
|
||||
if current_time - timestamp > self.cache_ttl
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.energy_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cache_size": len(self.energy_cache),
|
||||
"calculators": [calc.__class__.__name__ for calc in self.calculators],
|
||||
"thresholds": self.thresholds,
|
||||
"performance_stats": self.stats.copy(),
|
||||
}
|
||||
|
||||
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
|
||||
"""更新阈值"""
|
||||
self.thresholds.update(new_thresholds)
|
||||
|
||||
# 确保阈值关系合理
|
||||
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
|
||||
|
||||
self.stats["last_threshold_update"] = time.time()
|
||||
logger.info(f"更新AFC阈值: {self.thresholds}")
|
||||
|
||||
def add_calculator(self, calculator: EnergyCalculator) -> None:
|
||||
"""添加计算器"""
|
||||
self.calculators.append(calculator)
|
||||
logger.info(f"添加能量计算器: {calculator.__class__.__name__}")
|
||||
|
||||
def remove_calculator(self, calculator: EnergyCalculator) -> None:
|
||||
"""移除计算器"""
|
||||
if calculator in self.calculators:
|
||||
self.calculators.remove(calculator)
|
||||
logger.info(f"移除能量计算器: {calculator.__class__.__name__}")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
self.energy_cache.clear()
|
||||
logger.info("清空能量缓存")
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""获取缓存命中率"""
|
||||
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
||||
if total_requests == 0:
|
||||
return 0.0
|
||||
return self.stats["cache_hits"] / total_requests
|
||||
|
||||
|
||||
# 全局能量管理器实例
|
||||
energy_manager = EnergyManager()
|
||||
@@ -1,12 +1,30 @@
|
||||
"""
|
||||
机器人兴趣标签系统
|
||||
基于人设生成兴趣标签,使用embedding计算匹配度
|
||||
兴趣度系统模块
|
||||
提供统一、稳定的消息兴趣度计算和管理功能
|
||||
"""
|
||||
|
||||
from .interest_manager import (
|
||||
InterestManager,
|
||||
InterestSourceType,
|
||||
InterestFactor,
|
||||
InterestCalculator,
|
||||
MessageContentInterestCalculator,
|
||||
TopicInterestCalculator,
|
||||
UserInteractionInterestCalculator,
|
||||
interest_manager
|
||||
)
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
__all__ = [
|
||||
"InterestManager",
|
||||
"InterestSourceType",
|
||||
"InterestFactor",
|
||||
"InterestCalculator",
|
||||
"MessageContentInterestCalculator",
|
||||
"TopicInterestCalculator",
|
||||
"UserInteractionInterestCalculator",
|
||||
"interest_manager",
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
|
||||
430
src/chat/interest_system/interest_manager.py
Normal file
430
src/chat/interest_system/interest_manager.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
重构后的消息兴趣值计算系统
|
||||
提供稳定、可靠的消息兴趣度计算和管理功能
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("interest_system")
|
||||
|
||||
|
||||
class InterestSourceType(Enum):
|
||||
"""兴趣度来源类型"""
|
||||
MESSAGE_CONTENT = "message_content" # 消息内容
|
||||
USER_INTERACTION = "user_interaction" # 用户交互
|
||||
TOPIC_RELEVANCE = "topic_relevance" # 话题相关性
|
||||
RELATIONSHIP_SCORE = "relationship_score" # 关系分数
|
||||
HISTORICAL_PATTERN = "historical_pattern" # 历史模式
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestFactor:
|
||||
"""兴趣度因子"""
|
||||
source_type: InterestSourceType
|
||||
value: float
|
||||
weight: float = 1.0
|
||||
decay_rate: float = 0.1 # 衰减率
|
||||
last_updated: float = field(default_factory=time.time)
|
||||
|
||||
def get_current_value(self) -> float:
|
||||
"""获取当前值(考虑时间衰减)"""
|
||||
age = time.time() - self.last_updated
|
||||
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
|
||||
return self.value * decay_factor
|
||||
|
||||
def update_value(self, new_value: float) -> None:
|
||||
"""更新值"""
|
||||
self.value = max(0.0, min(1.0, new_value))
|
||||
self.last_updated = time.time()
|
||||
|
||||
|
||||
class InterestCalculator(ABC):
|
||||
"""兴趣度计算器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""计算兴趣度"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_confidence(self) -> float:
|
||||
"""获取计算置信度"""
|
||||
pass
|
||||
|
||||
|
||||
class MessageData(TypedDict):
|
||||
"""消息数据类型定义"""
|
||||
message_id: str
|
||||
processed_plain_text: str
|
||||
is_emoji: bool
|
||||
is_picid: bool
|
||||
is_mentioned: bool
|
||||
is_command: bool
|
||||
key_words: str
|
||||
user_id: str
|
||||
time: float
|
||||
|
||||
|
||||
class InterestContext(TypedDict):
|
||||
"""兴趣度计算上下文"""
|
||||
stream_id: str
|
||||
user_id: Optional[str]
|
||||
message: MessageData
|
||||
|
||||
|
||||
class InterestResult(TypedDict):
|
||||
"""兴趣度计算结果"""
|
||||
value: float
|
||||
confidence: float
|
||||
source_scores: Dict[InterestSourceType, float]
|
||||
cached: bool
|
||||
|
||||
|
||||
class MessageContentInterestCalculator(InterestCalculator):
|
||||
"""消息内容兴趣度计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于消息内容计算兴趣度"""
|
||||
message = context.get("message", {})
|
||||
if not message:
|
||||
return 0.3 # 默认值
|
||||
|
||||
# 提取消息特征
|
||||
text_length = len(message.get("processed_plain_text", ""))
|
||||
has_emoji = message.get("is_emoji", False)
|
||||
has_image = message.get("is_picid", False)
|
||||
is_mentioned = message.get("is_mentioned", False)
|
||||
is_command = message.get("is_command", False)
|
||||
|
||||
# 基础分数
|
||||
base_score = 0.3
|
||||
|
||||
# 文本长度加权
|
||||
if text_length > 0:
|
||||
text_score = min(0.3, text_length / 200) # 200字符为满分
|
||||
base_score += text_score * 0.3
|
||||
|
||||
# 多媒体内容加权
|
||||
if has_emoji:
|
||||
base_score += 0.1
|
||||
if has_image:
|
||||
base_score += 0.2
|
||||
|
||||
# 交互特征加权
|
||||
if is_mentioned:
|
||||
base_score += 0.2
|
||||
if is_command:
|
||||
base_score += 0.1
|
||||
|
||||
return min(1.0, base_score)
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
return 0.8
|
||||
|
||||
|
||||
class TopicInterestCalculator(InterestCalculator):
|
||||
"""话题兴趣度计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.topic_interests: Dict[str, float] = {}
|
||||
self.topic_decay_rate = 0.05 # 话题兴趣度衰减率
|
||||
|
||||
def update_topic_interest(self, topic: str, interest_value: float):
|
||||
"""更新话题兴趣度"""
|
||||
current_interest = self.topic_interests.get(topic, 0.3)
|
||||
# 平滑更新
|
||||
new_interest = current_interest * 0.7 + interest_value * 0.3
|
||||
self.topic_interests[topic] = max(0.0, min(1.0, new_interest))
|
||||
|
||||
logger.debug(f"更新话题 '{topic}' 兴趣度: {current_interest:.3f} -> {new_interest:.3f}")
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于话题相关性计算兴趣度"""
|
||||
message = context.get("message", {})
|
||||
keywords = message.get("key_words", "[]")
|
||||
|
||||
try:
|
||||
import json
|
||||
keyword_list = json.loads(keywords) if keywords else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
keyword_list = []
|
||||
|
||||
if not keyword_list:
|
||||
return 0.4 # 无关键词时的默认值
|
||||
|
||||
# 计算相关话题的平均兴趣度
|
||||
total_interest = 0.0
|
||||
relevant_topics = 0
|
||||
|
||||
for keyword in keyword_list[:5]: # 最多取前5个关键词
|
||||
# 查找相关话题
|
||||
for topic, interest in self.topic_interests.items():
|
||||
if keyword.lower() in topic.lower() or topic.lower() in keyword.lower():
|
||||
total_interest += interest
|
||||
relevant_topics += 1
|
||||
break
|
||||
|
||||
if relevant_topics > 0:
|
||||
return min(1.0, total_interest / relevant_topics)
|
||||
else:
|
||||
# 新话题,给予基础兴趣度
|
||||
for keyword in keyword_list[:3]:
|
||||
self.topic_interests[keyword] = 0.5
|
||||
return 0.5
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
return 0.7
|
||||
|
||||
|
||||
class UserInteractionInterestCalculator(InterestCalculator):
|
||||
"""用户交互兴趣度计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.interaction_history: List[Dict] = []
|
||||
self.max_history_size = 100
|
||||
|
||||
def add_interaction(self, user_id: str, interaction_type: str, value: float):
|
||||
"""添加交互记录"""
|
||||
self.interaction_history.append({
|
||||
"user_id": user_id,
|
||||
"type": interaction_type,
|
||||
"value": value,
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
# 保持历史记录大小
|
||||
if len(self.interaction_history) > self.max_history_size:
|
||||
self.interaction_history = self.interaction_history[-self.max_history_size:]
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于用户交互历史计算兴趣度"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
return 0.3
|
||||
|
||||
# 获取该用户的最近交互记录
|
||||
user_interactions = [
|
||||
interaction for interaction in self.interaction_history
|
||||
if interaction["user_id"] == user_id
|
||||
]
|
||||
|
||||
if not user_interactions:
|
||||
return 0.3
|
||||
|
||||
# 计算加权平均(最近的交互权重更高)
|
||||
total_weight = 0.0
|
||||
weighted_sum = 0.0
|
||||
|
||||
for interaction in user_interactions[-20:]: # 最近20次交互
|
||||
age = time.time() - interaction["timestamp"]
|
||||
weight = max(0.1, 1.0 - age / (7 * 24 * 3600)) # 7天内衰减
|
||||
|
||||
weighted_sum += interaction["value"] * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
return min(1.0, weighted_sum / total_weight)
|
||||
else:
|
||||
return 0.3
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
return 0.6
|
||||
|
||||
|
||||
class InterestManager:
|
||||
"""兴趣度管理器 - 统一管理所有兴趣度计算"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calculators: Dict[InterestSourceType, InterestCalculator] = {
|
||||
InterestSourceType.MESSAGE_CONTENT: MessageContentInterestCalculator(),
|
||||
InterestSourceType.TOPIC_RELEVANCE: TopicInterestCalculator(),
|
||||
InterestSourceType.USER_INTERACTION: UserInteractionInterestCalculator(),
|
||||
}
|
||||
|
||||
# 权重配置
|
||||
self.source_weights: Dict[InterestSourceType, float] = {
|
||||
InterestSourceType.MESSAGE_CONTENT: 0.4,
|
||||
InterestSourceType.TOPIC_RELEVANCE: 0.3,
|
||||
InterestSourceType.USER_INTERACTION: 0.3,
|
||||
}
|
||||
|
||||
# 兴趣度缓存
|
||||
self.interest_cache: Dict[str, Tuple[float, float]] = {} # message_id -> (value, timestamp)
|
||||
self.cache_ttl: int = 300 # 5分钟缓存
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, List[str]]] = {
|
||||
"total_calculations": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"average_calculation_time": 0.0,
|
||||
"calculator_usage": {calc_type.value: 0 for calc_type in InterestSourceType}
|
||||
}
|
||||
|
||||
logger.info("兴趣度管理器初始化完成")
|
||||
|
||||
def calculate_message_interest(self, message: Dict[str, Any], context: Dict[str, Any]) -> float:
|
||||
"""计算消息兴趣度"""
|
||||
start_time = time.time()
|
||||
message_id = message.get("message_id", "")
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_calculations"] += 1
|
||||
|
||||
# 检查缓存
|
||||
if message_id in self.interest_cache:
|
||||
cached_value, cached_time = self.interest_cache[message_id]
|
||||
if time.time() - cached_time < self.cache_ttl:
|
||||
self.stats["cache_hits"] += 1
|
||||
logger.debug(f"使用缓存兴趣度: {message_id} = {cached_value:.3f}")
|
||||
return cached_value
|
||||
else:
|
||||
self.stats["cache_misses"] += 1
|
||||
|
||||
# 构建计算上下文
|
||||
calc_context: Dict[str, Any] = {
|
||||
"message": message,
|
||||
"user_id": message.get("user_id"),
|
||||
**context
|
||||
}
|
||||
|
||||
# 计算各来源的兴趣度
|
||||
source_scores: Dict[InterestSourceType, float] = {}
|
||||
total_confidence = 0.0
|
||||
|
||||
for source_type, calculator in self.calculators.items():
|
||||
try:
|
||||
score = calculator.calculate(calc_context)
|
||||
confidence = calculator.get_confidence()
|
||||
|
||||
source_scores[source_type] = score
|
||||
total_confidence += confidence
|
||||
|
||||
# 更新计算器使用统计
|
||||
self.stats["calculator_usage"][source_type.value] += 1
|
||||
|
||||
logger.debug(f"{source_type.value} 兴趣度: {score:.3f} (置信度: {confidence:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算 {source_type.value} 兴趣度失败: {e}")
|
||||
source_scores[source_type] = 0.3
|
||||
|
||||
# 加权计算最终兴趣度
|
||||
final_interest = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for source_type, score in source_scores.items():
|
||||
weight = self.source_weights.get(source_type, 0.0)
|
||||
final_interest += score * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
final_interest /= total_weight
|
||||
|
||||
# 确保在合理范围内
|
||||
final_interest = max(0.0, min(1.0, final_interest))
|
||||
|
||||
# 缓存结果
|
||||
self.interest_cache[message_id] = (final_interest, time.time())
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_cache()
|
||||
|
||||
# 更新平均计算时间
|
||||
calculation_time = time.time() - start_time
|
||||
total_calculations = self.stats["total_calculations"]
|
||||
self.stats["average_calculation_time"] = (
|
||||
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
||||
/ total_calculations
|
||||
)
|
||||
|
||||
logger.info(f"消息 {message_id} 最终兴趣度: {final_interest:.3f} (耗时: {calculation_time:.3f}s)")
|
||||
return final_interest
|
||||
|
||||
def update_topic_interest(self, message: Dict[str, Any], interest_value: float) -> None:
|
||||
"""更新话题兴趣度"""
|
||||
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
|
||||
if isinstance(topic_calc, TopicInterestCalculator):
|
||||
# 提取关键词作为话题
|
||||
keywords = message.get("key_words", "[]")
|
||||
try:
|
||||
import json
|
||||
keyword_list: List[str] = json.loads(keywords) if keywords else []
|
||||
for keyword in keyword_list[:3]: # 更新前3个关键词
|
||||
topic_calc.update_topic_interest(keyword, interest_value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
def add_user_interaction(self, user_id: str, interaction_type: str, value: float) -> None:
|
||||
"""添加用户交互记录"""
|
||||
interaction_calc = self.calculators.get(InterestSourceType.USER_INTERACTION)
|
||||
if isinstance(interaction_calc, UserInteractionInterestCalculator):
|
||||
interaction_calc.add_interaction(user_id, interaction_type, value)
|
||||
|
||||
def get_topic_interests(self) -> Dict[str, float]:
|
||||
"""获取所有话题兴趣度"""
|
||||
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
|
||||
if isinstance(topic_calc, TopicInterestCalculator):
|
||||
return topic_calc.topic_interests.copy()
|
||||
return {}
|
||||
|
||||
def _cleanup_cache(self) -> None:
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
message_id for message_id, (_, timestamp) in self.interest_cache.items()
|
||||
if current_time - timestamp > self.cache_ttl
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.interest_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期兴趣度缓存")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cache_size": len(self.interest_cache),
|
||||
"topic_count": len(self.get_topic_interests()),
|
||||
"calculators": list(self.calculators.keys()),
|
||||
"performance_stats": self.stats.copy(),
|
||||
}
|
||||
|
||||
def add_calculator(self, source_type: InterestSourceType, calculator: InterestCalculator) -> None:
|
||||
"""添加自定义计算器"""
|
||||
self.calculators[source_type] = calculator
|
||||
logger.info(f"添加计算器: {source_type.value}")
|
||||
|
||||
def remove_calculator(self, source_type: InterestSourceType) -> None:
|
||||
"""移除计算器"""
|
||||
if source_type in self.calculators:
|
||||
del self.calculators[source_type]
|
||||
logger.info(f"移除计算器: {source_type.value}")
|
||||
|
||||
def set_source_weight(self, source_type: InterestSourceType, weight: float) -> None:
|
||||
"""设置来源权重"""
|
||||
self.source_weights[source_type] = max(0.0, min(1.0, weight))
|
||||
logger.info(f"设置 {source_type.value} 权重: {weight}")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
self.interest_cache.clear()
|
||||
logger.info("清空兴趣度缓存")
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""获取缓存命中率"""
|
||||
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
||||
if total_requests == 0:
|
||||
return 0.0
|
||||
return self.stats["cache_hits"] / total_requests
|
||||
|
||||
|
||||
# 全局兴趣度管理器实例
|
||||
interest_manager = InterestManager()
|
||||
@@ -1,14 +1,26 @@
|
||||
"""
|
||||
消息管理模块
|
||||
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
|
||||
消息管理器模块
|
||||
提供统一的消息管理、上下文管理和分发调度功能
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from src.common.data_models.message_manager_data_model import (
|
||||
StreamContext,
|
||||
MessageStatus,
|
||||
MessageManagerStats,
|
||||
StreamStats,
|
||||
from .context_manager import StreamContextManager, context_manager
|
||||
from .distribution_manager import (
|
||||
DistributionManager,
|
||||
DistributionPriority,
|
||||
DistributionTask,
|
||||
StreamDistributionState,
|
||||
distribution_manager
|
||||
)
|
||||
|
||||
__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"StreamContextManager",
|
||||
"context_manager",
|
||||
"DistributionManager",
|
||||
"DistributionPriority",
|
||||
"DistributionTask",
|
||||
"StreamDistributionState",
|
||||
"distribution_manager"
|
||||
]
|
||||
1072
src/chat/message_manager/context_manager.py
Normal file
1072
src/chat/message_manager/context_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
1004
src/chat/message_manager/distribution_manager.py
Normal file
1004
src/chat/message_manager/distribution_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,7 @@ from src.plugin_system.base.component_types import ChatMode
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.config.config import global_config
|
||||
from . import context_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
@@ -29,7 +30,6 @@ class MessageManager:
|
||||
"""消息管理器"""
|
||||
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
self.stream_contexts: Dict[str, StreamContext] = {}
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self.is_running = False
|
||||
self.manager_task: Optional[asyncio.Task] = None
|
||||
@@ -45,6 +45,9 @@ class MessageManager:
|
||||
self.sleep_manager = SleepManager()
|
||||
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
||||
|
||||
# 初始化上下文管理器
|
||||
self.context_manager = context_manager.context_manager
|
||||
|
||||
async def start(self):
|
||||
"""启动消息管理器"""
|
||||
if self.is_running:
|
||||
@@ -54,6 +57,7 @@ class MessageManager:
|
||||
self.is_running = True
|
||||
self.manager_task = asyncio.create_task(self._manager_loop())
|
||||
await self.wakeup_manager.start()
|
||||
await self.context_manager.start()
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
@@ -64,48 +68,44 @@ class MessageManager:
|
||||
self.is_running = False
|
||||
|
||||
# 停止所有流处理任务
|
||||
for context in self.stream_contexts.values():
|
||||
if context.processing_task and not context.processing_task.done():
|
||||
context.processing_task.cancel()
|
||||
|
||||
# 停止管理器任务
|
||||
# 注意:context_manager 会自己清理任务
|
||||
if self.manager_task and not self.manager_task.done():
|
||||
self.manager_task.cancel()
|
||||
|
||||
await self.wakeup_manager.stop()
|
||||
await self.context_manager.stop()
|
||||
|
||||
logger.info("消息管理器已停止")
|
||||
|
||||
def add_message(self, stream_id: str, message: DatabaseMessages):
|
||||
"""添加消息到指定聊天流"""
|
||||
# 获取或创建流上下文
|
||||
if stream_id not in self.stream_contexts:
|
||||
self.stream_contexts[stream_id] = StreamContext(stream_id=stream_id)
|
||||
self.stats.total_streams += 1
|
||||
# 使用 context_manager 添加消息
|
||||
success = self.context_manager.add_message_to_context(stream_id, message)
|
||||
|
||||
context = self.stream_contexts[stream_id]
|
||||
context.set_chat_mode(ChatMode.FOCUS)
|
||||
context.add_message(message)
|
||||
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
if success:
|
||||
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
|
||||
def update_message_and_refresh_energy(
|
||||
self,
|
||||
stream_id: str,
|
||||
message_id: str,
|
||||
interest_degree: float = None,
|
||||
interest_value: float = None,
|
||||
actions: list = None,
|
||||
should_reply: bool = None,
|
||||
):
|
||||
"""更新消息信息"""
|
||||
if stream_id in self.stream_contexts:
|
||||
context = self.stream_contexts[stream_id]
|
||||
context.update_message_info(message_id, interest_degree, actions, should_reply)
|
||||
# 使用 context_manager 更新消息信息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.update_message_info(message_id, interest_value, actions, should_reply)
|
||||
|
||||
def add_action_and_refresh_energy(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
if stream_id in self.stream_contexts:
|
||||
context = self.stream_contexts[stream_id]
|
||||
# 使用 context_manager 添加动作到消息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.add_action_to_message(message_id, action)
|
||||
|
||||
async def _manager_loop(self):
|
||||
@@ -136,19 +136,23 @@ class MessageManager:
|
||||
active_streams = 0
|
||||
total_unread = 0
|
||||
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
if not context.is_active:
|
||||
# 使用 context_manager 获取活跃的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
continue
|
||||
|
||||
active_streams += 1
|
||||
|
||||
# 检查是否有未读消息
|
||||
unread_messages = context.get_unread_messages()
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 更新统计
|
||||
@@ -157,14 +161,13 @@ class MessageManager:
|
||||
|
||||
async def _process_stream_messages(self, stream_id: str):
|
||||
"""处理指定聊天流的消息"""
|
||||
if stream_id not in self.stream_contexts:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return
|
||||
|
||||
context = self.stream_contexts[stream_id]
|
||||
|
||||
try:
|
||||
# 获取未读消息
|
||||
unread_messages = context.get_unread_messages()
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
@@ -205,7 +208,7 @@ class MessageManager:
|
||||
|
||||
# 处理结果,标记消息为已读
|
||||
if results.get("success", False):
|
||||
self._clear_all_unread_messages(context)
|
||||
self._clear_all_unread_messages(stream_id)
|
||||
logger.debug(f"聊天流 {stream_id} 处理成功,清除了 {len(unread_messages)} 条未读消息")
|
||||
else:
|
||||
logger.warning(f"聊天流 {stream_id} 处理失败: {results.get('error_message', '未知错误')}")
|
||||
@@ -213,7 +216,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
|
||||
# 出现异常时也清除未读消息,避免重复处理
|
||||
self._clear_all_unread_messages(context)
|
||||
self._clear_all_unread_messages(stream_id)
|
||||
raise
|
||||
|
||||
logger.debug(f"聊天流 {stream_id} 消息处理完成")
|
||||
@@ -226,35 +229,36 @@ class MessageManager:
|
||||
|
||||
def deactivate_stream(self, stream_id: str):
|
||||
"""停用聊天流"""
|
||||
if stream_id in self.stream_contexts:
|
||||
context = self.stream_contexts[stream_id]
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.is_active = False
|
||||
|
||||
# 取消处理任务
|
||||
if context.processing_task and not context.processing_task.done():
|
||||
if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done():
|
||||
context.processing_task.cancel()
|
||||
|
||||
logger.info(f"停用聊天流: {stream_id}")
|
||||
|
||||
def activate_stream(self, stream_id: str):
|
||||
"""激活聊天流"""
|
||||
if stream_id in self.stream_contexts:
|
||||
self.stream_contexts[stream_id].is_active = True
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
context.is_active = True
|
||||
logger.info(f"激活聊天流: {stream_id}")
|
||||
|
||||
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||
"""获取聊天流统计"""
|
||||
if stream_id not in self.stream_contexts:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return None
|
||||
|
||||
context = self.stream_contexts[stream_id]
|
||||
return StreamStats(
|
||||
stream_id=stream_id,
|
||||
is_active=context.is_active,
|
||||
unread_count=len(context.get_unread_messages()),
|
||||
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
|
||||
history_count=len(context.history_messages),
|
||||
last_check_time=context.last_check_time,
|
||||
has_active_task=bool(context.processing_task and not context.processing_task.done()),
|
||||
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
|
||||
)
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
@@ -270,18 +274,9 @@ class MessageManager:
|
||||
|
||||
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||
"""清理不活跃的聊天流"""
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_hours * 3600
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
|
||||
inactive_streams.append(stream_id)
|
||||
|
||||
for stream_id in inactive_streams:
|
||||
self.deactivate_stream(stream_id)
|
||||
del self.stream_contexts[stream_id]
|
||||
logger.info(f"清理不活跃聊天流: {stream_id}")
|
||||
# 使用 context_manager 的自动清理功能
|
||||
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
|
||||
logger.info("已启动不活跃聊天流清理")
|
||||
|
||||
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
|
||||
"""检查并处理消息打断"""
|
||||
@@ -330,90 +325,29 @@ class MessageManager:
|
||||
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
||||
|
||||
def _calculate_stream_distribution_interval(self, context: StreamContext) -> float:
|
||||
"""计算单个聊天流的分发周期 - 基于阈值感知的focus_energy"""
|
||||
"""计算单个聊天流的分发周期 - 使用重构后的能量管理器"""
|
||||
if not global_config.chat.dynamic_distribution_enabled:
|
||||
return self.check_interval # 使用固定间隔
|
||||
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
try:
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
# 获取该流的focus_energy(新的阈值感知版本)
|
||||
focus_energy = 0.5 # 默认值
|
||||
avg_message_interest = 0.5 # 默认平均兴趣度
|
||||
# 获取聊天流和能量
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
if chat_stream:
|
||||
focus_energy = chat_stream.focus_energy
|
||||
# 使用能量管理器获取分发周期
|
||||
interval = energy_manager.get_distribution_interval(focus_energy)
|
||||
logger.debug(f"流 {context.stream_id} 分发周期: {interval:.2f}s (能量: {focus_energy:.3f})")
|
||||
return interval
|
||||
else:
|
||||
# 默认间隔
|
||||
return self.check_interval
|
||||
|
||||
if chat_stream:
|
||||
focus_energy = chat_stream.focus_energy
|
||||
# 获取平均消息兴趣度用于更精确的计算 - 从StreamContext获取
|
||||
history_messages = context.get_history_messages(limit=100)
|
||||
unread_messages = context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
if all_messages:
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
|
||||
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.5
|
||||
|
||||
# 获取AFC阈值用于参考,添加None值检查
|
||||
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
|
||||
# 使用配置参数
|
||||
base_interval = global_config.chat.dynamic_distribution_base_interval
|
||||
min_interval = global_config.chat.dynamic_distribution_min_interval
|
||||
max_interval = global_config.chat.dynamic_distribution_max_interval
|
||||
jitter_factor = global_config.chat.dynamic_distribution_jitter_factor
|
||||
|
||||
# 基于阈值感知的智能分发周期计算
|
||||
if avg_message_interest >= high_match_threshold:
|
||||
# 超高兴趣度:极快响应 (1-2秒)
|
||||
interval_multiplier = 0.3 + (focus_energy - 0.7) * 2.0
|
||||
elif avg_message_interest >= reply_threshold:
|
||||
# 高兴趣度:快速响应 (2-6秒)
|
||||
gap_from_reply = (avg_message_interest - reply_threshold) / (high_match_threshold - reply_threshold)
|
||||
interval_multiplier = 0.6 + gap_from_reply * 0.4
|
||||
elif avg_message_interest >= non_reply_threshold:
|
||||
# 中等兴趣度:正常响应 (6-15秒)
|
||||
gap_from_non_reply = (avg_message_interest - non_reply_threshold) / (reply_threshold - non_reply_threshold)
|
||||
interval_multiplier = 1.2 + gap_from_non_reply * 1.8
|
||||
else:
|
||||
# 低兴趣度:缓慢响应 (15-30秒)
|
||||
gap_ratio = max(0, avg_message_interest / non_reply_threshold)
|
||||
interval_multiplier = 3.0 + (1.0 - gap_ratio) * 3.0
|
||||
|
||||
# 应用focus_energy微调
|
||||
energy_adjustment = 1.0 + (focus_energy - 0.5) * 0.5
|
||||
interval = base_interval * interval_multiplier * energy_adjustment
|
||||
|
||||
# 添加随机扰动避免同步
|
||||
import random
|
||||
|
||||
jitter = random.uniform(1.0 - jitter_factor, 1.0 + jitter_factor)
|
||||
final_interval = interval * jitter
|
||||
|
||||
# 限制在合理范围内
|
||||
final_interval = max(min_interval, min(max_interval, final_interval))
|
||||
|
||||
# 根据兴趣度级别调整日志级别
|
||||
if avg_message_interest >= high_match_threshold:
|
||||
log_level = "info"
|
||||
elif avg_message_interest >= reply_threshold:
|
||||
log_level = "info"
|
||||
else:
|
||||
log_level = "debug"
|
||||
|
||||
log_msg = (
|
||||
f"流 {context.stream_id} 分发周期: {final_interval:.2f}s | "
|
||||
f"focus_energy: {focus_energy:.3f} | "
|
||||
f"avg_interest: {avg_message_interest:.3f} | "
|
||||
f"阈值参考: {non_reply_threshold:.2f}/{reply_threshold:.2f}/{high_match_threshold:.2f}"
|
||||
)
|
||||
|
||||
if log_level == "info":
|
||||
logger.info(log_msg)
|
||||
else:
|
||||
logger.debug(log_msg)
|
||||
|
||||
return final_interval
|
||||
except Exception as e:
|
||||
logger.error(f"计算分发周期失败: {e}")
|
||||
return self.check_interval
|
||||
|
||||
def _calculate_next_manager_delay(self) -> float:
|
||||
"""计算管理器下次检查的延迟时间"""
|
||||
@@ -421,8 +355,10 @@ class MessageManager:
|
||||
min_delay = float("inf")
|
||||
|
||||
# 找到最近需要检查的流
|
||||
for context in self.stream_contexts.values():
|
||||
if not context.is_active:
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
time_until_check = context.next_check_time - current_time
|
||||
@@ -444,8 +380,12 @@ class MessageManager:
|
||||
current_time = time.time()
|
||||
processed_streams = 0
|
||||
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
if not context.is_active:
|
||||
# 使用 context_manager 获取活跃的流
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
# 检查是否达到检查时间
|
||||
@@ -463,7 +403,7 @@ class MessageManager:
|
||||
context.next_check_time = current_time + context.distribution_interval
|
||||
|
||||
# 检查未读消息
|
||||
unread_messages = context.get_unread_messages()
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
processed_streams += 1
|
||||
self.stats.total_unread_messages = len(unread_messages)
|
||||
@@ -493,7 +433,7 @@ class MessageManager:
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 更新活跃流计数
|
||||
active_count = sum(1 for ctx in self.stream_contexts.values() if ctx.is_active)
|
||||
active_count = len(self.context_manager.get_active_streams())
|
||||
self.stats.active_streams = active_count
|
||||
|
||||
if processed_streams > 0:
|
||||
@@ -501,13 +441,16 @@ class MessageManager:
|
||||
|
||||
async def _check_all_streams_with_priority(self):
|
||||
"""按优先级检查所有聊天流,高focus_energy的流优先处理"""
|
||||
if not self.stream_contexts:
|
||||
if not self.context_manager.get_active_streams():
|
||||
return
|
||||
|
||||
# 获取活跃的聊天流并按focus_energy排序
|
||||
active_streams = []
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
if not context.is_active:
|
||||
active_stream_ids = self.context_manager.get_active_streams()
|
||||
|
||||
for stream_id in active_stream_ids:
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if not context or not context.is_active:
|
||||
continue
|
||||
|
||||
# 获取focus_energy,如果不存在则使用默认值
|
||||
@@ -533,12 +476,12 @@ class MessageManager:
|
||||
active_stream_count += 1
|
||||
|
||||
# 检查是否有未读消息
|
||||
unread_messages = context.get_unread_messages()
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if unread_messages:
|
||||
total_unread += len(unread_messages)
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 高优先级流的额外日志
|
||||
@@ -554,63 +497,40 @@ class MessageManager:
|
||||
self.stats.total_unread_messages = total_unread
|
||||
|
||||
def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float:
|
||||
"""计算聊天流的优先级分数"""
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||
# 基础优先级:focus_energy
|
||||
"""计算聊天流的优先级分数 - 简化版本,主要使用focus_energy"""
|
||||
# 使用重构后的能量管理器,主要依赖focus_energy
|
||||
base_priority = focus_energy
|
||||
|
||||
# 未读消息数量加权
|
||||
# 简单的未读消息加权
|
||||
unread_count = len(context.get_unread_messages())
|
||||
message_count_bonus = min(unread_count * 0.1, 0.3) # 最多30%加成
|
||||
message_bonus = min(unread_count * 0.05, 0.2) # 最多20%加成
|
||||
|
||||
# 时间加权:最近活跃的流优先级更高
|
||||
# 简单的时间加权
|
||||
current_time = time.time()
|
||||
time_since_active = current_time - context.last_check_time
|
||||
time_penalty = max(0, 1.0 - time_since_active / 3600.0) # 1小时内无惩罚
|
||||
|
||||
# 连续无回复惩罚 - 从StreamContext历史消息计算
|
||||
if chat_stream:
|
||||
# 计算连续无回复次数
|
||||
consecutive_no_reply = 0
|
||||
all_messages = context.get_history_messages(limit=50) + context.get_unread_messages()
|
||||
for msg in reversed(all_messages):
|
||||
if hasattr(msg, "should_reply") and msg.should_reply:
|
||||
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
|
||||
consecutive_no_reply += 1
|
||||
else:
|
||||
break
|
||||
no_reply_penalty = max(0, 1.0 - consecutive_no_reply * 0.05) # 每次无回复降低5%
|
||||
else:
|
||||
no_reply_penalty = 1.0
|
||||
|
||||
# 综合优先级计算
|
||||
final_priority = (
|
||||
base_priority * 0.6 # 基础兴趣度权重60%
|
||||
+ message_count_bonus * 0.2 # 消息数量权重20%
|
||||
+ time_penalty * 0.1 # 时间权重10%
|
||||
+ no_reply_penalty * 0.1 # 回复状态权重10%
|
||||
)
|
||||
time_bonus = max(0, 1.0 - time_since_active / 7200.0) * 0.1 # 2小时内衰减
|
||||
|
||||
final_priority = base_priority + message_bonus + time_bonus
|
||||
return max(0.0, min(1.0, final_priority))
|
||||
|
||||
def _clear_all_unread_messages(self, context: StreamContext):
|
||||
def _clear_all_unread_messages(self, stream_id: str):
|
||||
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
|
||||
unread_messages = context.get_unread_messages()
|
||||
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
|
||||
|
||||
# 将所有未读消息标记为已读并移动到历史记录
|
||||
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
|
||||
try:
|
||||
context.mark_message_as_read(msg.message_id)
|
||||
self.stats.total_processed_messages += 1
|
||||
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
|
||||
except Exception as e:
|
||||
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
|
||||
# 将所有未读消息标记为已读
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
if context:
|
||||
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
|
||||
try:
|
||||
context.mark_message_as_read(msg.message_id)
|
||||
self.stats.total_processed_messages += 1
|
||||
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
|
||||
except Exception as e:
|
||||
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
|
||||
@@ -120,186 +120,209 @@ class ChatStream:
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
# 简化转换,实际可能需要更完整的转换逻辑
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
group_info = getattr(message_info, "group_info", {})
|
||||
|
||||
# 提取reply_to信息(从message_segment中查找reply类型的段)
|
||||
reply_to = None
|
||||
if hasattr(message, "message_segment") and message.message_segment:
|
||||
reply_to = self._extract_reply_from_segment(message.message_segment)
|
||||
|
||||
# 完整的数据转移逻辑
|
||||
db_message = DatabaseMessages(
|
||||
# 基础消息信息
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
time=getattr(message, "time", time.time()),
|
||||
chat_id=getattr(message, "chat_id", ""),
|
||||
user_id=str(getattr(message.message_info, "user_info", {}).user_id)
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
user_nickname=getattr(message.message_info, "user_info", {}).user_nickname
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
user_platform=getattr(message.message_info, "user_info", {}).platform
|
||||
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
|
||||
else "",
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=str(getattr(message, "priority_info", None))
|
||||
if hasattr(message, "priority_info") and message.priority_info
|
||||
chat_id=self._generate_chat_id(message_info),
|
||||
reply_to=reply_to,
|
||||
# 兴趣度相关
|
||||
interest_value=getattr(message, "interest_value", 0.0),
|
||||
# 关键词
|
||||
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words", None)
|
||||
else None,
|
||||
additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None),
|
||||
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words_lite", None)
|
||||
else None,
|
||||
# 消息状态标记
|
||||
is_mentioned=getattr(message, "is_mentioned", None),
|
||||
is_at=getattr(message, "is_at", False),
|
||||
is_emoji=getattr(message, "is_emoji", False),
|
||||
is_picid=getattr(message, "is_picid", False),
|
||||
is_voice=getattr(message, "is_voice", False),
|
||||
is_video=getattr(message, "is_video", False),
|
||||
is_command=getattr(message, "is_command", False),
|
||||
is_notify=getattr(message, "is_notify", False),
|
||||
# 消息内容
|
||||
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
||||
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
|
||||
# 优先级信息
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||
if getattr(message, "priority_info", None)
|
||||
else None,
|
||||
# 额外配置
|
||||
additional_config=getattr(message_info, "additional_config", None),
|
||||
# 用户信息
|
||||
user_id=str(getattr(user_info, "user_id", "")),
|
||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
user_cardname=getattr(user_info, "user_cardname", None),
|
||||
user_platform=getattr(user_info, "platform", ""),
|
||||
# 群组信息
|
||||
chat_info_group_id=getattr(group_info, "group_id", None),
|
||||
chat_info_group_name=getattr(group_info, "group_name", None),
|
||||
chat_info_group_platform=getattr(group_info, "platform", None),
|
||||
# 聊天流信息
|
||||
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
||||
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
||||
chat_info_user_platform=getattr(user_info, "platform", ""),
|
||||
chat_info_stream_id=self.stream_id,
|
||||
chat_info_platform=self.platform,
|
||||
chat_info_create_time=self.create_time,
|
||||
chat_info_last_active_time=self.last_active_time,
|
||||
# 新增兴趣度系统字段 - 添加安全处理
|
||||
actions=self._safe_get_actions(message),
|
||||
should_reply=getattr(message, "should_reply", False),
|
||||
)
|
||||
|
||||
self.stream_context.set_current_message(db_message)
|
||||
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
# 调试日志:记录数据转移情况
|
||||
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"is_mentioned: {db_message.is_mentioned}, "
|
||||
f"is_emoji: {db_message.is_emoji}, "
|
||||
f"is_picid: {db_message.is_picid}, "
|
||||
f"interest_value: {db_message.interest_value}")
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
return None
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
return None
|
||||
|
||||
# 确保返回列表类型
|
||||
if isinstance(actions, list):
|
||||
# 过滤掉空值和非字符串元素
|
||||
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
|
||||
return filtered_actions if filtered_actions else None
|
||||
else:
|
||||
logger.warning(f"actions字段类型不支持: {type(actions)}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
# 递归搜索seglist中的reply段
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = self._extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
# 找到reply段,返回message_id
|
||||
return str(segment.data) if segment.data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
return None
|
||||
|
||||
def _generate_chat_id(self, message_info) -> str:
|
||||
"""生成chat_id,基于群组或用户信息"""
|
||||
try:
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
|
||||
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
|
||||
# 群聊:使用群组ID
|
||||
return f"{self.platform}_{group_info.group_id}"
|
||||
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
|
||||
# 私聊:使用用户ID
|
||||
return f"{self.platform}_{user_info.user_id}_private"
|
||||
else:
|
||||
# 默认:使用stream_id
|
||||
return self.stream_id
|
||||
except Exception as e:
|
||||
logger.warning(f"生成chat_id失败: {e}")
|
||||
return self.stream_id
|
||||
|
||||
@property
|
||||
def focus_energy(self) -> float:
|
||||
"""动态计算的聊天流总体兴趣度,访问时自动更新"""
|
||||
self._focus_energy = self._calculate_dynamic_focus_energy()
|
||||
return self._focus_energy
|
||||
"""使用重构后的能量管理器计算focus_energy"""
|
||||
try:
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
# 获取所有消息
|
||||
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
unread_messages = self.stream_context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
user_id = str(self.user_info.user_id)
|
||||
|
||||
# 使用能量管理器计算
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id,
|
||||
messages=all_messages,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新内部存储
|
||||
self._focus_energy = energy
|
||||
|
||||
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
|
||||
return energy
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
|
||||
# 返回缓存的值或默认值
|
||||
if hasattr(self, '_focus_energy'):
|
||||
return self._focus_energy
|
||||
else:
|
||||
return 0.5
|
||||
|
||||
@focus_energy.setter
|
||||
def focus_energy(self, value: float):
|
||||
"""设置focus_energy值(主要用于初始化或特殊场景)"""
|
||||
self._focus_energy = max(0.0, min(1.0, value))
|
||||
|
||||
def _calculate_dynamic_focus_energy(self) -> float:
|
||||
"""动态计算聊天流的总体兴趣度,使用StreamContext历史消息"""
|
||||
try:
|
||||
# 从StreamContext获取历史消息计算统计数据
|
||||
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
unread_messages = self.stream_context.get_unread_messages()
|
||||
all_messages = history_messages + unread_messages
|
||||
|
||||
# 计算基于历史消息的统计数据
|
||||
if all_messages:
|
||||
# 基础分:平均消息兴趣度
|
||||
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
|
||||
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3
|
||||
|
||||
# 动作参与度:有动作的消息比例
|
||||
messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions]
|
||||
action_rate = len(messages_with_actions) / len(all_messages)
|
||||
|
||||
# 回复活跃度:应该回复且已回复的消息比例
|
||||
should_reply_messages = [
|
||||
msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply
|
||||
]
|
||||
replied_messages = [
|
||||
msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or [])
|
||||
]
|
||||
reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0
|
||||
|
||||
# 获取最后交互时间
|
||||
if all_messages:
|
||||
self.last_interaction_time = max(msg.time for msg in all_messages)
|
||||
|
||||
# 连续无回复计算:从最近的未回复消息计数
|
||||
consecutive_no_reply = 0
|
||||
for msg in reversed(all_messages):
|
||||
if hasattr(msg, "should_reply") and msg.should_reply:
|
||||
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
|
||||
consecutive_no_reply += 1
|
||||
else:
|
||||
break
|
||||
else:
|
||||
# 没有历史消息时的默认值
|
||||
avg_message_interest = 0.3
|
||||
action_rate = 0.0
|
||||
reply_rate = 0.0
|
||||
consecutive_no_reply = 0
|
||||
self.last_interaction_time = time.time()
|
||||
|
||||
# 获取用户关系分(对于私聊,群聊无效)
|
||||
relationship_factor = self._get_user_relationship_score()
|
||||
|
||||
# 时间衰减因子:最近活跃度
|
||||
current_time = time.time()
|
||||
if not hasattr(self, "last_interaction_time") or not self.last_interaction_time:
|
||||
self.last_interaction_time = current_time
|
||||
time_since_interaction = current_time - self.last_interaction_time
|
||||
time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减
|
||||
|
||||
# 连续无回复惩罚
|
||||
no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1)
|
||||
|
||||
# 获取AFC系统阈值,添加None值检查
|
||||
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||
|
||||
# 计算与不同阈值的差距比例
|
||||
reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold)))
|
||||
non_reply_gap_ratio = max(
|
||||
0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))
|
||||
)
|
||||
high_match_gap_ratio = max(
|
||||
0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))
|
||||
)
|
||||
|
||||
# 基于阈值差距比例的基础分计算
|
||||
threshold_based_score = (
|
||||
reply_gap_ratio * 0.6 # 回复阈值差距权重60%
|
||||
+ non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20%
|
||||
+ high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
|
||||
)
|
||||
|
||||
# 动态权重调整:根据平均兴趣度水平调整权重分配
|
||||
if avg_message_interest >= high_match_threshold:
|
||||
# 高兴趣度:更注重阈值差距
|
||||
threshold_weight = 0.7
|
||||
activity_weight = 0.2
|
||||
relationship_weight = 0.1
|
||||
elif avg_message_interest >= reply_threshold:
|
||||
# 中等兴趣度:平衡权重
|
||||
threshold_weight = 0.5
|
||||
activity_weight = 0.3
|
||||
relationship_weight = 0.2
|
||||
else:
|
||||
# 低兴趣度:更注重活跃度提升
|
||||
threshold_weight = 0.3
|
||||
activity_weight = 0.5
|
||||
relationship_weight = 0.2
|
||||
|
||||
# 计算活跃度得分
|
||||
activity_score = action_rate * 0.6 + reply_rate * 0.4
|
||||
|
||||
# 综合计算:基于阈值的动态加权
|
||||
focus_energy = (
|
||||
(
|
||||
threshold_based_score * threshold_weight # 阈值差距基础分
|
||||
+ activity_score * activity_weight # 活跃度得分
|
||||
+ relationship_factor * relationship_weight # 关系得分
|
||||
+ self.base_interest_energy * 0.05 # 基础兴趣微调
|
||||
)
|
||||
* time_decay
|
||||
* no_reply_penalty
|
||||
)
|
||||
|
||||
# 确保在合理范围内
|
||||
focus_energy = max(0.1, min(1.0, focus_energy))
|
||||
|
||||
# 应用非线性变换增强区分度
|
||||
if focus_energy >= 0.7:
|
||||
# 高兴趣度区域:指数增强,更敏感
|
||||
focus_energy = 0.7 + (focus_energy - 0.7) ** 0.8
|
||||
elif focus_energy >= 0.4:
|
||||
# 中等兴趣度区域:线性保持
|
||||
pass
|
||||
else:
|
||||
# 低兴趣度区域:对数压缩,减少区分度
|
||||
focus_energy = 0.4 * (focus_energy / 0.4) ** 1.2
|
||||
|
||||
return max(0.1, min(1.0, focus_energy))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算动态focus_energy失败: {e}")
|
||||
return self.base_interest_energy
|
||||
|
||||
def _get_user_relationship_score(self) -> float:
|
||||
"""从外部系统获取用户关系分"""
|
||||
"""从新的兴趣度管理系统获取用户关系分"""
|
||||
try:
|
||||
# 尝试从兴趣评分系统获取用户关系分
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system,
|
||||
)
|
||||
# 使用新的兴趣度管理系统
|
||||
from src.chat.interest_system import interest_manager
|
||||
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id))
|
||||
user_id = str(self.user_info.user_id)
|
||||
# 获取用户交互历史作为关系分的基础
|
||||
interaction_calc = interest_manager.calculators.get(
|
||||
interest_manager.InterestSourceType.USER_INTERACTION
|
||||
)
|
||||
if interaction_calc:
|
||||
return interaction_calc.calculate({"user_id": user_id})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -378,12 +401,13 @@ class ChatStream:
|
||||
chat_info_platform=db_msg.chat_info_platform,
|
||||
chat_info_create_time=db_msg.chat_info_create_time,
|
||||
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||
# 新增的兴趣度系统字段
|
||||
interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0,
|
||||
actions=actions,
|
||||
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||
)
|
||||
|
||||
# 添加调试日志:检查从数据库加载的interest_value
|
||||
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
|
||||
|
||||
# 标记为已读并添加到历史消息
|
||||
db_message.is_read = True
|
||||
self.stream_context.history_messages.append(db_message)
|
||||
|
||||
@@ -218,4 +218,93 @@ class MessageStorage:
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
|
||||
@staticmethod
|
||||
def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
||||
"""
|
||||
更新数据库中消息的interest_value字段
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_value: 兴趣度值
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 更新消息的interest_value字段
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
|
||||
else:
|
||||
logger.warning(f"未找到消息 {message_id},无法更新interest_value")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||
"""
|
||||
修复指定聊天中interest_value为0或null的历史消息记录
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
since_time: 从指定时间开始修复(时间戳)
|
||||
|
||||
Returns:
|
||||
修复的记录数量
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
|
||||
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||
query = select(Messages).where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time >= since_time) &
|
||||
(
|
||||
(Messages.interest_value == 0) |
|
||||
(Messages.interest_value.is_(None)) |
|
||||
(Messages.interest_value < 0.1)
|
||||
)
|
||||
).limit(50) # 限制每次修复的数量,避免性能问题
|
||||
|
||||
messages_to_fix = session.execute(query).scalars().all()
|
||||
fixed_count = 0
|
||||
|
||||
for msg in messages_to_fix:
|
||||
# 为这些消息设置一个合理的默认兴趣度
|
||||
# 可以基于消息长度、内容或其他因素计算
|
||||
default_interest = 0.3 # 默认中等兴趣度
|
||||
|
||||
# 如果消息内容较长,可能是重要消息,兴趣度稍高
|
||||
if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text:
|
||||
text_length = len(msg.processed_plain_text)
|
||||
if text_length > 50: # 长消息
|
||||
default_interest = 0.4
|
||||
elif text_length > 20: # 中等长度消息
|
||||
default_interest = 0.35
|
||||
|
||||
# 如果是被@的消息,兴趣度更高
|
||||
if getattr(msg, 'is_mentioned', False):
|
||||
default_interest = min(default_interest + 0.2, 0.8)
|
||||
|
||||
# 执行更新
|
||||
update_stmt = update(Messages).where(
|
||||
Messages.message_id == msg.message_id
|
||||
).values(interest_value=default_interest)
|
||||
|
||||
result = session.execute(update_stmt)
|
||||
if result.rowcount > 0:
|
||||
fixed_count += 1
|
||||
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
|
||||
|
||||
session.commit()
|
||||
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
|
||||
return fixed_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"修复历史消息interest_value失败: {e}")
|
||||
return 0
|
||||
|
||||
@@ -96,7 +96,6 @@ class DatabaseMessages(BaseDataModel):
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
# 新增字段
|
||||
interest_degree: float = 0.0,
|
||||
actions: Optional[list] = None,
|
||||
should_reply: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -108,7 +107,6 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.interest_value = interest_value
|
||||
|
||||
# 新增字段
|
||||
self.interest_degree = interest_degree
|
||||
self.actions = actions
|
||||
self.should_reply = should_reply
|
||||
|
||||
@@ -201,7 +199,6 @@ class DatabaseMessages(BaseDataModel):
|
||||
"selected_expressions": self.selected_expressions,
|
||||
"is_read": self.is_read,
|
||||
# 新增字段
|
||||
"interest_degree": self.interest_degree,
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_id": self.user_info.user_id,
|
||||
@@ -221,17 +218,17 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
def update_message_info(self, interest_degree: float = None, actions: list = None, should_reply: bool = None):
|
||||
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
interest_degree: 兴趣度值
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
if interest_degree is not None:
|
||||
self.interest_degree = interest_degree
|
||||
if interest_value is not None:
|
||||
self.interest_value = interest_value
|
||||
if actions is not None:
|
||||
self.actions = actions
|
||||
if should_reply is not None:
|
||||
@@ -268,7 +265,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"time": self.time,
|
||||
"interest_degree": self.interest_degree,
|
||||
"interest_value": self.interest_value,
|
||||
"actions": self.actions,
|
||||
"should_reply": self.should_reply,
|
||||
"user_nickname": self.user_info.user_nickname,
|
||||
|
||||
@@ -61,27 +61,27 @@ class StreamContext(BaseDataModel):
|
||||
self._detect_chat_type(message)
|
||||
|
||||
def update_message_info(
|
||||
self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None
|
||||
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
interest_degree: 兴趣度值
|
||||
interest_value: 兴趣度值
|
||||
actions: 执行的动作列表
|
||||
should_reply: 是否应该回复
|
||||
"""
|
||||
# 在未读消息中查找并更新
|
||||
for message in self.unread_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_degree, actions, should_reply)
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
# 在历史消息中查找并更新
|
||||
for message in self.history_messages:
|
||||
if message.message_id == message_id:
|
||||
message.update_message_info(interest_degree, actions, should_reply)
|
||||
message.update_message_info(interest_value, actions, should_reply)
|
||||
break
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
|
||||
@@ -174,7 +174,6 @@ class Messages(Base):
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# 兴趣度系统字段
|
||||
interest_degree = Column(Float, nullable=True, default=0.0)
|
||||
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
||||
should_reply = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
@@ -183,7 +182,6 @@ class Messages(Base):
|
||||
Index("idx_messages_chat_id", "chat_id"),
|
||||
Index("idx_messages_time", "time"),
|
||||
Index("idx_messages_user_id", "user_id"),
|
||||
Index("idx_messages_interest_degree", "interest_degree"),
|
||||
Index("idx_messages_should_reply", "should_reply"),
|
||||
)
|
||||
|
||||
|
||||
@@ -368,41 +368,30 @@ class ChatterPlanFilter:
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system as interest_scoring_system,
|
||||
)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.interest_system import interest_manager
|
||||
|
||||
# 转换消息格式
|
||||
db_messages = []
|
||||
# 使用新的兴趣度管理系统计算评分
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
db_msg = DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
time=msg_dict.get("time", time.time()),
|
||||
chat_id=msg_dict.get("chat_id", ""),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
user_id=msg_dict.get("user_id", ""),
|
||||
user_nickname=msg_dict.get("user_nickname", ""),
|
||||
user_platform=msg_dict.get("platform", "qq"),
|
||||
chat_info_group_id=msg_dict.get("group_id", ""),
|
||||
chat_info_group_name=msg_dict.get("group_name", ""),
|
||||
chat_info_group_platform=msg_dict.get("platform", "qq"),
|
||||
# 构建计算上下文
|
||||
calc_context = {
|
||||
"stream_id": msg_dict.get("chat_id", ""),
|
||||
"user_id": msg_dict.get("user_id"),
|
||||
}
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_score = interest_manager.calculate_message_interest(
|
||||
message=msg_dict,
|
||||
context=calc_context
|
||||
)
|
||||
db_messages.append(db_msg)
|
||||
|
||||
# 构建兴趣度字典
|
||||
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换消息格式失败: {e}")
|
||||
logger.warning(f"计算消息兴趣度失败: {e}")
|
||||
continue
|
||||
|
||||
# 计算兴趣度评分
|
||||
if db_messages:
|
||||
bot_nickname = global_config.bot.nickname or "麦麦"
|
||||
scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname)
|
||||
|
||||
# 构建兴趣度字典
|
||||
for score in scores:
|
||||
interest_scores[score.message_id] = score.total_score
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取兴趣度评分失败: {e}")
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ChatterInterestScoringSystem
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
from src.chat.interest_system import interest_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
|
||||
@@ -52,14 +52,7 @@ class ChatterActionPlanner:
|
||||
self.generator = ChatterPlanGenerator(chat_id)
|
||||
self.executor = ChatterPlanExecutor(action_manager)
|
||||
|
||||
# 初始化兴趣度评分系统
|
||||
self.interest_scoring = ChatterInterestScoringSystem()
|
||||
|
||||
# 创建新的关系追踪器
|
||||
self.relationship_tracker = ChatterRelationshipTracker(self.interest_scoring)
|
||||
|
||||
# 设置执行器的关系追踪器
|
||||
self.executor.set_relationship_tracker(self.relationship_tracker)
|
||||
# 使用新的统一兴趣度管理系统
|
||||
|
||||
# 规划器统计
|
||||
self.planner_stats = {
|
||||
@@ -107,43 +100,39 @@ class ChatterActionPlanner:
|
||||
initial_plan.available_actions = self.action_manager.get_using_actions()
|
||||
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
# 2. 兴趣度评分 - 只对未读消息进行评分
|
||||
# 2. 使用新的兴趣度管理系统进行评分
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
reply_not_available = False
|
||||
|
||||
if unread_messages:
|
||||
bot_nickname = global_config.bot.nickname
|
||||
interest_scores = await self.interest_scoring.calculate_interest_scores(unread_messages, bot_nickname)
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if unread_messages[0].user_id:
|
||||
user_id = unread_messages[0].user_id
|
||||
|
||||
# 3. 根据兴趣度调整可用动作
|
||||
if interest_scores:
|
||||
latest_score = max(interest_scores, key=lambda s: s.total_score)
|
||||
latest_message = next(
|
||||
(msg for msg in unread_messages if msg.message_id == latest_score.message_id), None
|
||||
)
|
||||
should_reply, score = self.interest_scoring.should_reply(latest_score, latest_message)
|
||||
# 构建计算上下文
|
||||
calc_context = {
|
||||
"stream_id": self.chat_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
reply_not_available = False
|
||||
if not should_reply and "reply" in initial_plan.available_actions:
|
||||
logger.info(f"兴趣度不足 ({latest_score.total_score:.2f}),移除回复")
|
||||
reply_not_available = True
|
||||
|
||||
# 更新ChatStream的兴趣度数据
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
logger.debug(f"已更新聊天 {self.chat_id} 的ChatStream兴趣度,分数: {score:.3f}")
|
||||
|
||||
# 更新情绪状态和ChatStream兴趣度数据
|
||||
if latest_message and score > 0:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
await chat_mood.update_mood_by_message(latest_message, score)
|
||||
logger.debug(f"已更新聊天 {self.chat_id} 的情绪状态,兴趣度: {score:.3f}")
|
||||
|
||||
# 为所有未读消息记录兴趣度信息
|
||||
# 为每条消息计算兴趣度
|
||||
for message in unread_messages:
|
||||
# 查找对应的兴趣度评分
|
||||
message_score = next((s for s in interest_scores if s.message_id == message.message_id), None)
|
||||
if message_score:
|
||||
message.interest_degree = message_score.total_score
|
||||
message.should_reply = self.interest_scoring.should_reply(message_score, message)[0]
|
||||
logger.debug(f"已记录消息 {message.message_id} - 兴趣度: {message_score.total_score:.3f}, 应回复: {message.should_reply}")
|
||||
try:
|
||||
# 使用新的兴趣度管理器计算
|
||||
message_interest = interest_manager.calculate_message_interest(
|
||||
message=message.__dict__,
|
||||
context=calc_context
|
||||
)
|
||||
|
||||
# 更新消息的兴趣度
|
||||
message.interest_value = message_interest
|
||||
|
||||
# 简单的回复决策逻辑:兴趣度超过阈值则回复
|
||||
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
|
||||
logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
|
||||
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
@@ -151,25 +140,35 @@ class ChatterActionPlanner:
|
||||
message_manager.update_message_and_refresh_energy(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_degree=message_score.total_score,
|
||||
interest_value=message_interest,
|
||||
should_reply=message.should_reply
|
||||
)
|
||||
else:
|
||||
# 如果没有找到评分,设置默认值
|
||||
message.interest_degree = 0.0
|
||||
|
||||
# 更新数据库中的消息记录
|
||||
try:
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||
|
||||
# 更新话题兴趣度
|
||||
interest_manager.update_topic_interest(message.__dict__, message_interest)
|
||||
|
||||
# 记录最高分
|
||||
if message_interest > score:
|
||||
score = message_interest
|
||||
if message.should_reply:
|
||||
should_reply = True
|
||||
else:
|
||||
reply_not_available = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
||||
# 设置默认值
|
||||
message.interest_value = 0.0
|
||||
message.should_reply = False
|
||||
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
message_manager.update_message_and_refresh_energy(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_degree=0.0,
|
||||
should_reply=False
|
||||
)
|
||||
|
||||
# base_threshold = self.interest_scoring.reply_threshold
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
if score < non_reply_action_interest_threshold:
|
||||
@@ -191,26 +190,16 @@ class ChatterActionPlanner:
|
||||
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
||||
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
||||
|
||||
# 检查filtered_plan是否有reply动作,以便记录reply action
|
||||
has_reply_action = False
|
||||
for decision in filtered_plan.decided_actions:
|
||||
if decision.action_type == "reply":
|
||||
has_reply_action = True
|
||||
self.interest_scoring.record_reply_action(has_reply_action)
|
||||
# 检查filtered_plan是否有reply动作,用于统计
|
||||
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
|
||||
|
||||
# 5. 使用 PlanExecutor 执行 Plan
|
||||
execution_result = await self.executor.execute(filtered_plan)
|
||||
|
||||
# 6. 动作记录现在由ChatterActionManager统一处理
|
||||
# 动作记录逻辑已移至ChatterActionManager.execute_action方法中
|
||||
|
||||
# 7. 根据执行结果更新统计信息
|
||||
# 6. 根据执行结果更新统计信息
|
||||
self._update_stats_from_execution_result(execution_result)
|
||||
|
||||
# 8. 检查关系更新
|
||||
await self.relationship_tracker.check_and_update_relationships()
|
||||
|
||||
# 8. 返回结果
|
||||
# 7. 返回结果
|
||||
return self._build_return_result(filtered_plan)
|
||||
|
||||
except Exception as e:
|
||||
@@ -259,37 +248,10 @@ class ChatterActionPlanner:
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def get_user_relationship(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
return self.interest_scoring.get_user_relationship(user_id)
|
||||
|
||||
def update_interest_keywords(self, new_keywords: Dict[str, List[str]]):
|
||||
"""更新兴趣关键词(已弃用,仅保留用于兼容性)"""
|
||||
logger.info("传统关键词匹配已移除,此方法仅保留用于兼容性")
|
||||
# 此方法已弃用,因为现在完全使用embedding匹配
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, any]:
|
||||
"""获取规划器统计"""
|
||||
return self.planner_stats.copy()
|
||||
|
||||
def get_interest_scoring_stats(self) -> Dict[str, any]:
|
||||
"""获取兴趣度评分统计"""
|
||||
return {
|
||||
"no_reply_count": self.interest_scoring.no_reply_count,
|
||||
"max_no_reply_count": self.interest_scoring.max_no_reply_count,
|
||||
"reply_threshold": self.interest_scoring.reply_threshold,
|
||||
"mention_threshold": self.interest_scoring.mention_threshold,
|
||||
"user_relationships": len(self.interest_scoring.user_relationships),
|
||||
}
|
||||
|
||||
def get_relationship_stats(self) -> Dict[str, any]:
|
||||
"""获取用户关系统计"""
|
||||
return {
|
||||
"tracking_users": len(self.relationship_tracker.tracking_users),
|
||||
"relationship_history": len(self.relationship_tracker.relationship_history),
|
||||
"max_tracking_users": self.relationship_tracker.max_tracking_users,
|
||||
}
|
||||
|
||||
def get_current_mood_state(self) -> str:
|
||||
"""获取当前聊天的情绪状态"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
|
||||
Reference in New Issue
Block a user