refactor(chat): 重构消息管理器以使用集中式上下文管理和能量系统

- 将流上下文管理从MessageManager迁移到专门的ContextManager
- 使用统一的能量系统计算focus_energy和分发间隔
- 重构ChatStream的消息数据转换逻辑,支持更完整的数据字段
- 更新数据库模型,移除interest_degree字段,统一使用interest_value
- 集成新的兴趣度管理系统替代原有的评分系统
- 添加消息存储的interest_value修复功能
This commit is contained in:
Windpicker-owo
2025-09-27 14:23:48 +08:00
parent 0478be7d2a
commit c49b3f3ac4
15 changed files with 3518 additions and 495 deletions

View File

@@ -0,0 +1,28 @@
"""
能量系统模块
提供稳定、高效的聊天流能量计算和管理功能
"""
from .energy_manager import (
EnergyManager,
EnergyLevel,
EnergyComponent,
EnergyCalculator,
InterestEnergyCalculator,
ActivityEnergyCalculator,
RecencyEnergyCalculator,
RelationshipEnergyCalculator,
energy_manager
)
__all__ = [
"EnergyManager",
"EnergyLevel",
"EnergyComponent",
"EnergyCalculator",
"InterestEnergyCalculator",
"ActivityEnergyCalculator",
"RecencyEnergyCalculator",
"RelationshipEnergyCalculator",
"energy_manager"
]

View File

@@ -0,0 +1,480 @@
"""
重构后的 focus_energy 管理系统
提供稳定、高效的聊天流能量计算和管理功能
"""
import time
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("energy_system")
class EnergyLevel(Enum):
"""能量等级"""
VERY_LOW = 0.1 # 非常低
LOW = 0.3 # 低
NORMAL = 0.5 # 正常
HIGH = 0.7 # 高
VERY_HIGH = 0.9 # 非常高
@dataclass
class EnergyComponent:
"""能量组件"""
name: str
value: float
weight: float = 1.0
decay_rate: float = 0.05 # 衰减率
last_updated: float = field(default_factory=time.time)
def get_current_value(self) -> float:
"""获取当前值(考虑时间衰减)"""
age = time.time() - self.last_updated
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
return self.value * decay_factor
def update_value(self, new_value: float) -> None:
"""更新值"""
self.value = max(0.0, min(1.0, new_value))
self.last_updated = time.time()
class EnergyContext(TypedDict):
"""能量计算上下文"""
stream_id: str
messages: List[Any]
user_id: Optional[str]
class EnergyResult(TypedDict):
"""能量计算结果"""
energy: float
level: EnergyLevel
distribution_interval: float
component_scores: Dict[str, float]
cached: bool
class EnergyCalculator(ABC):
"""能量计算器抽象基类"""
@abstractmethod
def calculate(self, context: Dict[str, Any]) -> float:
"""计算能量值"""
pass
@abstractmethod
def get_weight(self) -> float:
"""获取权重"""
pass
class InterestEnergyCalculator(EnergyCalculator):
"""兴趣度能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
"""基于消息兴趣度计算能量"""
messages = context.get("messages", [])
if not messages:
return 0.3
# 计算平均兴趣度
total_interest = 0.0
valid_messages = 0
for msg in messages:
interest_value = getattr(msg, "interest_value", None)
if interest_value is not None:
try:
interest_float = float(interest_value)
if 0.0 <= interest_float <= 1.0:
total_interest += interest_float
valid_messages += 1
except (ValueError, TypeError):
continue
if valid_messages > 0:
avg_interest = total_interest / valid_messages
logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)")
return avg_interest
else:
return 0.3
def get_weight(self) -> float:
return 0.5
class ActivityEnergyCalculator(EnergyCalculator):
"""活跃度能量计算器"""
def __init__(self):
self.action_weights = {
"reply": 0.4,
"react": 0.3,
"mention": 0.2,
"other": 0.1
}
def calculate(self, context: Dict[str, Any]) -> float:
"""基于活跃度计算能量"""
messages = context.get("messages", [])
if not messages:
return 0.2
total_score = 0.0
max_possible_score = len(messages) * 0.4 # 最高可能分数
for msg in messages:
actions = getattr(msg, "actions", [])
if isinstance(actions, list) and actions:
for action in actions:
weight = self.action_weights.get(action, self.action_weights["other"])
total_score += weight
if max_possible_score > 0:
activity_score = min(1.0, total_score / max_possible_score)
logger.debug(f"活跃度分数: {activity_score:.3f}")
return activity_score
else:
return 0.2
def get_weight(self) -> float:
return 0.3
class RecencyEnergyCalculator(EnergyCalculator):
"""最近性能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
"""基于最近性计算能量"""
messages = context.get("messages", [])
if not messages:
return 0.1
# 获取最新消息时间
latest_time = 0.0
for msg in messages:
msg_time = getattr(msg, "time", None)
if msg_time and msg_time > latest_time:
latest_time = msg_time
if latest_time == 0.0:
return 0.1
# 计算时间衰减
current_time = time.time()
age = current_time - latest_time
# 时间衰减策略:
# 1小时内1.0
# 1-6小时0.8
# 6-24小时0.5
# 1-7天0.3
# 7天以上0.1
if age < 3600: # 1小时内
recency_score = 1.0
elif age < 6 * 3600: # 6小时内
recency_score = 0.8
elif age < 24 * 3600: # 24小时内
recency_score = 0.5
elif age < 7 * 24 * 3600: # 7天内
recency_score = 0.3
else:
recency_score = 0.1
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
return recency_score
def get_weight(self) -> float:
return 0.2
class RelationshipEnergyCalculator(EnergyCalculator):
"""关系能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
"""基于关系计算能量"""
user_id = context.get("user_id")
if not user_id:
return 0.3
try:
# 使用新的兴趣度管理系统获取用户关系分
from src.chat.interest_system import interest_manager
# 获取用户交互历史作为关系分的基础
interaction_calc = interest_manager.calculators.get(
interest_manager.InterestSourceType.USER_INTERACTION
)
if interaction_calc:
relationship_score = interaction_calc.calculate({"user_id": user_id})
logger.debug(f"用户关系分数: {relationship_score:.3f}")
return max(0.0, min(1.0, relationship_score))
else:
# 默认基础分
return 0.3
except Exception:
# 默认基础分
return 0.3
def get_weight(self) -> float:
return 0.1
class EnergyManager:
"""能量管理器 - 统一管理所有能量计算"""
def __init__(self) -> None:
self.calculators: List[EnergyCalculator] = [
InterestEnergyCalculator(),
ActivityEnergyCalculator(),
RecencyEnergyCalculator(),
RelationshipEnergyCalculator(),
]
# 能量缓存
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
self.cache_ttl: int = 60 # 1分钟缓存
# AFC阈值配置
self.thresholds: Dict[str, float] = {
"high_match": 0.8,
"reply": 0.4,
"non_reply": 0.2
}
# 统计信息
self.stats: Dict[str, Union[int, float, str]] = {
"total_calculations": 0,
"cache_hits": 0,
"cache_misses": 0,
"average_calculation_time": 0.0,
"last_threshold_update": time.time(),
}
# 从配置加载阈值
self._load_thresholds_from_config()
logger.info("能量管理器初始化完成")
def _load_thresholds_from_config(self) -> None:
"""从配置加载AFC阈值"""
try:
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
# 确保阈值关系合理
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
self.stats["last_threshold_update"] = time.time()
logger.info(f"加载AFC阈值: {self.thresholds}")
except Exception as e:
logger.warning(f"加载AFC阈值失败使用默认值: {e}")
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
"""计算聊天流的focus_energy"""
start_time = time.time()
# 更新统计
self.stats["total_calculations"] += 1
# 检查缓存
if stream_id in self.energy_cache:
cached_energy, cached_time = self.energy_cache[stream_id]
if time.time() - cached_time < self.cache_ttl:
self.stats["cache_hits"] += 1
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
return cached_energy
else:
self.stats["cache_misses"] += 1
# 构建计算上下文
context: EnergyContext = {
"stream_id": stream_id,
"messages": messages,
"user_id": user_id,
}
# 计算各组件能量
component_scores: Dict[str, float] = {}
total_weight = 0.0
for calculator in self.calculators:
try:
score = calculator.calculate(context)
weight = calculator.get_weight()
component_scores[calculator.__class__.__name__] = score
total_weight += weight
logger.debug(f"{calculator.__class__.__name__} 能量: {score:.3f} (权重: {weight:.3f})")
except Exception as e:
logger.warning(f"计算 {calculator.__class__.__name__} 能量失败: {e}")
# 加权计算总能量
if total_weight > 0:
total_energy = 0.0
for calculator in self.calculators:
if calculator.__class__.__name__ in component_scores:
score = component_scores[calculator.__class__.__name__]
weight = calculator.get_weight()
total_energy += score * (weight / total_weight)
else:
total_energy = 0.5
# 应用阈值调整和变换
final_energy = self._apply_threshold_adjustment(total_energy)
# 缓存结果
self.energy_cache[stream_id] = (final_energy, time.time())
# 清理过期缓存
self._cleanup_cache()
# 更新平均计算时间
calculation_time = time.time() - start_time
total_calculations = self.stats["total_calculations"]
self.stats["average_calculation_time"] = (
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
/ total_calculations
)
logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
return final_energy
def _apply_threshold_adjustment(self, energy: float) -> float:
"""应用阈值调整和变换"""
# 获取参考阈值
high_threshold = self.thresholds["high_match"]
reply_threshold = self.thresholds["reply"]
# 计算与阈值的相对位置
if energy >= high_threshold:
# 高能量区域:指数增强
adjusted = 0.7 + (energy - 0.7) ** 0.8
elif energy >= reply_threshold:
# 中等能量区域:线性保持
adjusted = energy
else:
# 低能量区域:对数压缩
adjusted = 0.4 * (energy / 0.4) ** 1.2
# 确保在合理范围内
return max(0.1, min(1.0, adjusted))
def get_energy_level(self, energy: float) -> EnergyLevel:
"""获取能量等级"""
if energy >= EnergyLevel.VERY_HIGH.value:
return EnergyLevel.VERY_HIGH
elif energy >= EnergyLevel.HIGH.value:
return EnergyLevel.HIGH
elif energy >= EnergyLevel.NORMAL.value:
return EnergyLevel.NORMAL
elif energy >= EnergyLevel.LOW.value:
return EnergyLevel.LOW
else:
return EnergyLevel.VERY_LOW
def get_distribution_interval(self, energy: float) -> float:
"""基于能量等级获取分发周期"""
energy_level = self.get_energy_level(energy)
# 根据能量等级确定基础分发周期
if energy_level == EnergyLevel.VERY_HIGH:
base_interval = 1.0 # 1秒
elif energy_level == EnergyLevel.HIGH:
base_interval = 3.0 # 3秒
elif energy_level == EnergyLevel.NORMAL:
base_interval = 8.0 # 8秒
elif energy_level == EnergyLevel.LOW:
base_interval = 15.0 # 15秒
else:
base_interval = 30.0 # 30秒
# 添加随机扰动避免同步
import random
jitter = random.uniform(0.8, 1.2)
final_interval = base_interval * jitter
# 确保在配置范围内
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
return max(min_interval, min(max_interval, final_interval))
def invalidate_cache(self, stream_id: str) -> None:
"""失效指定流的缓存"""
if stream_id in self.energy_cache:
del self.energy_cache[stream_id]
logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")
def _cleanup_cache(self) -> None:
"""清理过期缓存"""
current_time = time.time()
expired_keys = [
stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
if current_time - timestamp > self.cache_ttl
]
for key in expired_keys:
del self.energy_cache[key]
if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
def get_statistics(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
"cache_size": len(self.energy_cache),
"calculators": [calc.__class__.__name__ for calc in self.calculators],
"thresholds": self.thresholds,
"performance_stats": self.stats.copy(),
}
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
"""更新阈值"""
self.thresholds.update(new_thresholds)
# 确保阈值关系合理
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
self.stats["last_threshold_update"] = time.time()
logger.info(f"更新AFC阈值: {self.thresholds}")
def add_calculator(self, calculator: EnergyCalculator) -> None:
"""添加计算器"""
self.calculators.append(calculator)
logger.info(f"添加能量计算器: {calculator.__class__.__name__}")
def remove_calculator(self, calculator: EnergyCalculator) -> None:
"""移除计算器"""
if calculator in self.calculators:
self.calculators.remove(calculator)
logger.info(f"移除能量计算器: {calculator.__class__.__name__}")
def clear_cache(self) -> None:
"""清空缓存"""
self.energy_cache.clear()
logger.info("清空能量缓存")
def get_cache_hit_rate(self) -> float:
"""获取缓存命中率"""
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
if total_requests == 0:
return 0.0
return self.stats["cache_hits"] / total_requests
# 全局能量管理器实例
energy_manager = EnergyManager()

View File

@@ -1,12 +1,30 @@
"""
机器人兴趣标签系统
基于人设生成兴趣标签使用embedding计算匹配度
兴趣度系统模块
提供统一、稳定的消息兴趣度计算和管理功能
"""
from .interest_manager import (
InterestManager,
InterestSourceType,
InterestFactor,
InterestCalculator,
MessageContentInterestCalculator,
TopicInterestCalculator,
UserInteractionInterestCalculator,
interest_manager
)
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
__all__ = [
"InterestManager",
"InterestSourceType",
"InterestFactor",
"InterestCalculator",
"MessageContentInterestCalculator",
"TopicInterestCalculator",
"UserInteractionInterestCalculator",
"interest_manager",
"BotInterestManager",
"bot_interest_manager",
"BotInterestTag",

View File

@@ -0,0 +1,430 @@
"""
重构后的消息兴趣值计算系统
提供稳定、可靠的消息兴趣度计算和管理功能
"""
import time
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from src.common.logger import get_logger
logger = get_logger("interest_system")
class InterestSourceType(Enum):
"""兴趣度来源类型"""
MESSAGE_CONTENT = "message_content" # 消息内容
USER_INTERACTION = "user_interaction" # 用户交互
TOPIC_RELEVANCE = "topic_relevance" # 话题相关性
RELATIONSHIP_SCORE = "relationship_score" # 关系分数
HISTORICAL_PATTERN = "historical_pattern" # 历史模式
@dataclass
class InterestFactor:
"""兴趣度因子"""
source_type: InterestSourceType
value: float
weight: float = 1.0
decay_rate: float = 0.1 # 衰减率
last_updated: float = field(default_factory=time.time)
def get_current_value(self) -> float:
"""获取当前值(考虑时间衰减)"""
age = time.time() - self.last_updated
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
return self.value * decay_factor
def update_value(self, new_value: float) -> None:
"""更新值"""
self.value = max(0.0, min(1.0, new_value))
self.last_updated = time.time()
class InterestCalculator(ABC):
"""兴趣度计算器抽象基类"""
@abstractmethod
def calculate(self, context: Dict[str, Any]) -> float:
"""计算兴趣度"""
pass
@abstractmethod
def get_confidence(self) -> float:
"""获取计算置信度"""
pass
class MessageData(TypedDict):
"""消息数据类型定义"""
message_id: str
processed_plain_text: str
is_emoji: bool
is_picid: bool
is_mentioned: bool
is_command: bool
key_words: str
user_id: str
time: float
class InterestContext(TypedDict):
"""兴趣度计算上下文"""
stream_id: str
user_id: Optional[str]
message: MessageData
class InterestResult(TypedDict):
"""兴趣度计算结果"""
value: float
confidence: float
source_scores: Dict[InterestSourceType, float]
cached: bool
class MessageContentInterestCalculator(InterestCalculator):
"""消息内容兴趣度计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
"""基于消息内容计算兴趣度"""
message = context.get("message", {})
if not message:
return 0.3 # 默认值
# 提取消息特征
text_length = len(message.get("processed_plain_text", ""))
has_emoji = message.get("is_emoji", False)
has_image = message.get("is_picid", False)
is_mentioned = message.get("is_mentioned", False)
is_command = message.get("is_command", False)
# 基础分数
base_score = 0.3
# 文本长度加权
if text_length > 0:
text_score = min(0.3, text_length / 200) # 200字符为满分
base_score += text_score * 0.3
# 多媒体内容加权
if has_emoji:
base_score += 0.1
if has_image:
base_score += 0.2
# 交互特征加权
if is_mentioned:
base_score += 0.2
if is_command:
base_score += 0.1
return min(1.0, base_score)
def get_confidence(self) -> float:
return 0.8
class TopicInterestCalculator(InterestCalculator):
"""话题兴趣度计算器"""
def __init__(self):
self.topic_interests: Dict[str, float] = {}
self.topic_decay_rate = 0.05 # 话题兴趣度衰减率
def update_topic_interest(self, topic: str, interest_value: float):
"""更新话题兴趣度"""
current_interest = self.topic_interests.get(topic, 0.3)
# 平滑更新
new_interest = current_interest * 0.7 + interest_value * 0.3
self.topic_interests[topic] = max(0.0, min(1.0, new_interest))
logger.debug(f"更新话题 '{topic}' 兴趣度: {current_interest:.3f} -> {new_interest:.3f}")
def calculate(self, context: Dict[str, Any]) -> float:
"""基于话题相关性计算兴趣度"""
message = context.get("message", {})
keywords = message.get("key_words", "[]")
try:
import json
keyword_list = json.loads(keywords) if keywords else []
except (json.JSONDecodeError, TypeError):
keyword_list = []
if not keyword_list:
return 0.4 # 无关键词时的默认值
# 计算相关话题的平均兴趣度
total_interest = 0.0
relevant_topics = 0
for keyword in keyword_list[:5]: # 最多取前5个关键词
# 查找相关话题
for topic, interest in self.topic_interests.items():
if keyword.lower() in topic.lower() or topic.lower() in keyword.lower():
total_interest += interest
relevant_topics += 1
break
if relevant_topics > 0:
return min(1.0, total_interest / relevant_topics)
else:
# 新话题,给予基础兴趣度
for keyword in keyword_list[:3]:
self.topic_interests[keyword] = 0.5
return 0.5
def get_confidence(self) -> float:
return 0.7
class UserInteractionInterestCalculator(InterestCalculator):
"""用户交互兴趣度计算器"""
def __init__(self):
self.interaction_history: List[Dict] = []
self.max_history_size = 100
def add_interaction(self, user_id: str, interaction_type: str, value: float):
"""添加交互记录"""
self.interaction_history.append({
"user_id": user_id,
"type": interaction_type,
"value": value,
"timestamp": time.time()
})
# 保持历史记录大小
if len(self.interaction_history) > self.max_history_size:
self.interaction_history = self.interaction_history[-self.max_history_size:]
def calculate(self, context: Dict[str, Any]) -> float:
"""基于用户交互历史计算兴趣度"""
user_id = context.get("user_id")
if not user_id:
return 0.3
# 获取该用户的最近交互记录
user_interactions = [
interaction for interaction in self.interaction_history
if interaction["user_id"] == user_id
]
if not user_interactions:
return 0.3
# 计算加权平均(最近的交互权重更高)
total_weight = 0.0
weighted_sum = 0.0
for interaction in user_interactions[-20:]: # 最近20次交互
age = time.time() - interaction["timestamp"]
weight = max(0.1, 1.0 - age / (7 * 24 * 3600)) # 7天内衰减
weighted_sum += interaction["value"] * weight
total_weight += weight
if total_weight > 0:
return min(1.0, weighted_sum / total_weight)
else:
return 0.3
def get_confidence(self) -> float:
return 0.6
class InterestManager:
"""兴趣度管理器 - 统一管理所有兴趣度计算"""
def __init__(self) -> None:
self.calculators: Dict[InterestSourceType, InterestCalculator] = {
InterestSourceType.MESSAGE_CONTENT: MessageContentInterestCalculator(),
InterestSourceType.TOPIC_RELEVANCE: TopicInterestCalculator(),
InterestSourceType.USER_INTERACTION: UserInteractionInterestCalculator(),
}
# 权重配置
self.source_weights: Dict[InterestSourceType, float] = {
InterestSourceType.MESSAGE_CONTENT: 0.4,
InterestSourceType.TOPIC_RELEVANCE: 0.3,
InterestSourceType.USER_INTERACTION: 0.3,
}
# 兴趣度缓存
self.interest_cache: Dict[str, Tuple[float, float]] = {} # message_id -> (value, timestamp)
self.cache_ttl: int = 300 # 5分钟缓存
# 统计信息
self.stats: Dict[str, Union[int, float, List[str]]] = {
"total_calculations": 0,
"cache_hits": 0,
"cache_misses": 0,
"average_calculation_time": 0.0,
"calculator_usage": {calc_type.value: 0 for calc_type in InterestSourceType}
}
logger.info("兴趣度管理器初始化完成")
def calculate_message_interest(self, message: Dict[str, Any], context: Dict[str, Any]) -> float:
"""计算消息兴趣度"""
start_time = time.time()
message_id = message.get("message_id", "")
# 更新统计
self.stats["total_calculations"] += 1
# 检查缓存
if message_id in self.interest_cache:
cached_value, cached_time = self.interest_cache[message_id]
if time.time() - cached_time < self.cache_ttl:
self.stats["cache_hits"] += 1
logger.debug(f"使用缓存兴趣度: {message_id} = {cached_value:.3f}")
return cached_value
else:
self.stats["cache_misses"] += 1
# 构建计算上下文
calc_context: Dict[str, Any] = {
"message": message,
"user_id": message.get("user_id"),
**context
}
# 计算各来源的兴趣度
source_scores: Dict[InterestSourceType, float] = {}
total_confidence = 0.0
for source_type, calculator in self.calculators.items():
try:
score = calculator.calculate(calc_context)
confidence = calculator.get_confidence()
source_scores[source_type] = score
total_confidence += confidence
# 更新计算器使用统计
self.stats["calculator_usage"][source_type.value] += 1
logger.debug(f"{source_type.value} 兴趣度: {score:.3f} (置信度: {confidence:.3f})")
except Exception as e:
logger.warning(f"计算 {source_type.value} 兴趣度失败: {e}")
source_scores[source_type] = 0.3
# 加权计算最终兴趣度
final_interest = 0.0
total_weight = 0.0
for source_type, score in source_scores.items():
weight = self.source_weights.get(source_type, 0.0)
final_interest += score * weight
total_weight += weight
if total_weight > 0:
final_interest /= total_weight
# 确保在合理范围内
final_interest = max(0.0, min(1.0, final_interest))
# 缓存结果
self.interest_cache[message_id] = (final_interest, time.time())
# 清理过期缓存
self._cleanup_cache()
# 更新平均计算时间
calculation_time = time.time() - start_time
total_calculations = self.stats["total_calculations"]
self.stats["average_calculation_time"] = (
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
/ total_calculations
)
logger.info(f"消息 {message_id} 最终兴趣度: {final_interest:.3f} (耗时: {calculation_time:.3f}s)")
return final_interest
def update_topic_interest(self, message: Dict[str, Any], interest_value: float) -> None:
"""更新话题兴趣度"""
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
if isinstance(topic_calc, TopicInterestCalculator):
# 提取关键词作为话题
keywords = message.get("key_words", "[]")
try:
import json
keyword_list: List[str] = json.loads(keywords) if keywords else []
for keyword in keyword_list[:3]: # 更新前3个关键词
topic_calc.update_topic_interest(keyword, interest_value)
except (json.JSONDecodeError, TypeError):
pass
def add_user_interaction(self, user_id: str, interaction_type: str, value: float) -> None:
"""添加用户交互记录"""
interaction_calc = self.calculators.get(InterestSourceType.USER_INTERACTION)
if isinstance(interaction_calc, UserInteractionInterestCalculator):
interaction_calc.add_interaction(user_id, interaction_type, value)
def get_topic_interests(self) -> Dict[str, float]:
"""获取所有话题兴趣度"""
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
if isinstance(topic_calc, TopicInterestCalculator):
return topic_calc.topic_interests.copy()
return {}
def _cleanup_cache(self) -> None:
"""清理过期缓存"""
current_time = time.time()
expired_keys = [
message_id for message_id, (_, timestamp) in self.interest_cache.items()
if current_time - timestamp > self.cache_ttl
]
for key in expired_keys:
del self.interest_cache[key]
if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期兴趣度缓存")
def get_statistics(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
"cache_size": len(self.interest_cache),
"topic_count": len(self.get_topic_interests()),
"calculators": list(self.calculators.keys()),
"performance_stats": self.stats.copy(),
}
def add_calculator(self, source_type: InterestSourceType, calculator: InterestCalculator) -> None:
"""添加自定义计算器"""
self.calculators[source_type] = calculator
logger.info(f"添加计算器: {source_type.value}")
def remove_calculator(self, source_type: InterestSourceType) -> None:
"""移除计算器"""
if source_type in self.calculators:
del self.calculators[source_type]
logger.info(f"移除计算器: {source_type.value}")
def set_source_weight(self, source_type: InterestSourceType, weight: float) -> None:
"""设置来源权重"""
self.source_weights[source_type] = max(0.0, min(1.0, weight))
logger.info(f"设置 {source_type.value} 权重: {weight}")
def clear_cache(self) -> None:
"""清空缓存"""
self.interest_cache.clear()
logger.info("清空兴趣度缓存")
def get_cache_hit_rate(self) -> float:
"""获取缓存命中率"""
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
if total_requests == 0:
return 0.0
return self.stats["cache_hits"] / total_requests
# 全局兴趣度管理器实例
interest_manager = InterestManager()

View File

@@ -1,14 +1,26 @@
"""
消息管理模块
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
消息管理模块
提供统一的消息管理、上下文管理和分发调度功能
"""
from .message_manager import MessageManager, message_manager
from src.common.data_models.message_manager_data_model import (
StreamContext,
MessageStatus,
MessageManagerStats,
StreamStats,
from .context_manager import StreamContextManager, context_manager
from .distribution_manager import (
DistributionManager,
DistributionPriority,
DistributionTask,
StreamDistributionState,
distribution_manager
)
__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]
__all__ = [
"MessageManager",
"message_manager",
"StreamContextManager",
"context_manager",
"DistributionManager",
"DistributionPriority",
"DistributionTask",
"StreamDistributionState",
"distribution_manager"
]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,7 @@ from src.plugin_system.base.component_types import ChatMode
from .sleep_manager.sleep_manager import SleepManager
from .sleep_manager.wakeup_manager import WakeUpManager
from src.config.config import global_config
from . import context_manager
if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
@@ -29,7 +30,6 @@ class MessageManager:
"""消息管理器"""
def __init__(self, check_interval: float = 5.0):
self.stream_contexts: Dict[str, StreamContext] = {}
self.check_interval = check_interval # 检查间隔(秒)
self.is_running = False
self.manager_task: Optional[asyncio.Task] = None
@@ -45,6 +45,9 @@ class MessageManager:
self.sleep_manager = SleepManager()
self.wakeup_manager = WakeUpManager(self.sleep_manager)
# 初始化上下文管理器
self.context_manager = context_manager.context_manager
async def start(self):
"""启动消息管理器"""
if self.is_running:
@@ -54,6 +57,7 @@ class MessageManager:
self.is_running = True
self.manager_task = asyncio.create_task(self._manager_loop())
await self.wakeup_manager.start()
await self.context_manager.start()
logger.info("消息管理器已启动")
async def stop(self):
@@ -64,48 +68,44 @@ class MessageManager:
self.is_running = False
# 停止所有流处理任务
for context in self.stream_contexts.values():
if context.processing_task and not context.processing_task.done():
context.processing_task.cancel()
# 停止管理器任务
# 注意context_manager 会自己清理任务
if self.manager_task and not self.manager_task.done():
self.manager_task.cancel()
await self.wakeup_manager.stop()
await self.context_manager.stop()
logger.info("消息管理器已停止")
def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流"""
# 获取或创建流上下文
if stream_id not in self.stream_contexts:
self.stream_contexts[stream_id] = StreamContext(stream_id=stream_id)
self.stats.total_streams += 1
context = self.stream_contexts[stream_id]
context.set_chat_mode(ChatMode.FOCUS)
context.add_message(message)
# 使用 context_manager 添加消息
success = self.context_manager.add_message_to_context(stream_id, message)
if success:
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
else:
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
def update_message_and_refresh_energy(
self,
stream_id: str,
message_id: str,
interest_degree: float = None,
interest_value: float = None,
actions: list = None,
should_reply: bool = None,
):
"""更新消息信息"""
if stream_id in self.stream_contexts:
context = self.stream_contexts[stream_id]
context.update_message_info(message_id, interest_degree, actions, should_reply)
# 使用 context_manager 更新消息信息
context = self.context_manager.get_stream_context(stream_id)
if context:
context.update_message_info(message_id, interest_value, actions, should_reply)
def add_action_and_refresh_energy(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
if stream_id in self.stream_contexts:
context = self.stream_contexts[stream_id]
# 使用 context_manager 添加动作到消息
context = self.context_manager.get_stream_context(stream_id)
if context:
context.add_action_to_message(message_id, action)
async def _manager_loop(self):
@@ -136,19 +136,23 @@ class MessageManager:
active_streams = 0
total_unread = 0
for stream_id, context in self.stream_contexts.items():
if not context.is_active:
# 使用 context_manager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context:
continue
active_streams += 1
# 检查是否有未读消息
unread_messages = context.get_unread_messages()
unread_messages = self.context_manager.get_unread_messages(stream_id)
if unread_messages:
total_unread += len(unread_messages)
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 更新统计
@@ -157,14 +161,13 @@ class MessageManager:
async def _process_stream_messages(self, stream_id: str):
"""处理指定聊天流的消息"""
if stream_id not in self.stream_contexts:
context = self.context_manager.get_stream_context(stream_id)
if not context:
return
context = self.stream_contexts[stream_id]
try:
# 获取未读消息
unread_messages = context.get_unread_messages()
unread_messages = self.context_manager.get_unread_messages(stream_id)
if not unread_messages:
return
@@ -205,7 +208,7 @@ class MessageManager:
# 处理结果,标记消息为已读
if results.get("success", False):
self._clear_all_unread_messages(context)
self._clear_all_unread_messages(stream_id)
logger.debug(f"聊天流 {stream_id} 处理成功,清除了 {len(unread_messages)} 条未读消息")
else:
logger.warning(f"聊天流 {stream_id} 处理失败: {results.get('error_message', '未知错误')}")
@@ -213,7 +216,7 @@ class MessageManager:
except Exception as e:
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
# 出现异常时也清除未读消息,避免重复处理
self._clear_all_unread_messages(context)
self._clear_all_unread_messages(stream_id)
raise
logger.debug(f"聊天流 {stream_id} 消息处理完成")
@@ -226,35 +229,36 @@ class MessageManager:
def deactivate_stream(self, stream_id: str):
"""停用聊天流"""
if stream_id in self.stream_contexts:
context = self.stream_contexts[stream_id]
context = self.context_manager.get_stream_context(stream_id)
if context:
context.is_active = False
# 取消处理任务
if context.processing_task and not context.processing_task.done():
if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done():
context.processing_task.cancel()
logger.info(f"停用聊天流: {stream_id}")
def activate_stream(self, stream_id: str):
"""激活聊天流"""
if stream_id in self.stream_contexts:
self.stream_contexts[stream_id].is_active = True
context = self.context_manager.get_stream_context(stream_id)
if context:
context.is_active = True
logger.info(f"激活聊天流: {stream_id}")
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
"""获取聊天流统计"""
if stream_id not in self.stream_contexts:
context = self.context_manager.get_stream_context(stream_id)
if not context:
return None
context = self.stream_contexts[stream_id]
return StreamStats(
stream_id=stream_id,
is_active=context.is_active,
unread_count=len(context.get_unread_messages()),
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=bool(context.processing_task and not context.processing_task.done()),
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
)
def get_manager_stats(self) -> Dict[str, Any]:
@@ -270,18 +274,9 @@ class MessageManager:
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
"""清理不活跃的聊天流"""
current_time = time.time()
max_inactive_seconds = max_inactive_hours * 3600
inactive_streams = []
for stream_id, context in self.stream_contexts.items():
if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
inactive_streams.append(stream_id)
for stream_id in inactive_streams:
self.deactivate_stream(stream_id)
del self.stream_contexts[stream_id]
logger.info(f"清理不活跃聊天流: {stream_id}")
# 使用 context_manager 的自动清理功能
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
logger.info("已启动不活跃聊天流清理")
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
"""检查并处理消息打断"""
@@ -330,90 +325,29 @@ class MessageManager:
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
def _calculate_stream_distribution_interval(self, context: StreamContext) -> float:
"""计算单个聊天流的分发周期 - 基于阈值感知的focus_energy"""
"""计算单个聊天流的分发周期 - 使用重构后的能量管理器"""
if not global_config.chat.dynamic_distribution_enabled:
return self.check_interval # 使用固定间隔
try:
from src.chat.energy_system import energy_manager
from src.plugin_system.apis.chat_api import get_chat_manager
# 获取聊天流和能量
chat_stream = get_chat_manager().get_stream(context.stream_id)
# 获取该流的focus_energy新的阈值感知版本
focus_energy = 0.5 # 默认值
avg_message_interest = 0.5 # 默认平均兴趣度
if chat_stream:
focus_energy = chat_stream.focus_energy
# 获取平均消息兴趣度用于更精确的计算 - 从StreamContext获取
history_messages = context.get_history_messages(limit=100)
unread_messages = context.get_unread_messages()
all_messages = history_messages + unread_messages
if all_messages:
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.5
# 获取AFC阈值用于参考添加None值检查
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
# 使用配置参数
base_interval = global_config.chat.dynamic_distribution_base_interval
min_interval = global_config.chat.dynamic_distribution_min_interval
max_interval = global_config.chat.dynamic_distribution_max_interval
jitter_factor = global_config.chat.dynamic_distribution_jitter_factor
# 基于阈值感知的智能分发周期计算
if avg_message_interest >= high_match_threshold:
# 超高兴趣度:极快响应 (1-2秒)
interval_multiplier = 0.3 + (focus_energy - 0.7) * 2.0
elif avg_message_interest >= reply_threshold:
# 高兴趣度:快速响应 (2-6秒)
gap_from_reply = (avg_message_interest - reply_threshold) / (high_match_threshold - reply_threshold)
interval_multiplier = 0.6 + gap_from_reply * 0.4
elif avg_message_interest >= non_reply_threshold:
# 中等兴趣度:正常响应 (6-15秒)
gap_from_non_reply = (avg_message_interest - non_reply_threshold) / (reply_threshold - non_reply_threshold)
interval_multiplier = 1.2 + gap_from_non_reply * 1.8
# 使用能量管理器获取分发周期
interval = energy_manager.get_distribution_interval(focus_energy)
logger.debug(f"{context.stream_id} 分发周期: {interval:.2f}s (能量: {focus_energy:.3f})")
return interval
else:
# 低兴趣度:缓慢响应 (15-30秒)
gap_ratio = max(0, avg_message_interest / non_reply_threshold)
interval_multiplier = 3.0 + (1.0 - gap_ratio) * 3.0
# 默认间隔
return self.check_interval
# 应用focus_energy微调
energy_adjustment = 1.0 + (focus_energy - 0.5) * 0.5
interval = base_interval * interval_multiplier * energy_adjustment
# 添加随机扰动避免同步
import random
jitter = random.uniform(1.0 - jitter_factor, 1.0 + jitter_factor)
final_interval = interval * jitter
# 限制在合理范围内
final_interval = max(min_interval, min(max_interval, final_interval))
# 根据兴趣度级别调整日志级别
if avg_message_interest >= high_match_threshold:
log_level = "info"
elif avg_message_interest >= reply_threshold:
log_level = "info"
else:
log_level = "debug"
log_msg = (
f"{context.stream_id} 分发周期: {final_interval:.2f}s | "
f"focus_energy: {focus_energy:.3f} | "
f"avg_interest: {avg_message_interest:.3f} | "
f"阈值参考: {non_reply_threshold:.2f}/{reply_threshold:.2f}/{high_match_threshold:.2f}"
)
if log_level == "info":
logger.info(log_msg)
else:
logger.debug(log_msg)
return final_interval
except Exception as e:
logger.error(f"计算分发周期失败: {e}")
return self.check_interval
def _calculate_next_manager_delay(self) -> float:
"""计算管理器下次检查的延迟时间"""
@@ -421,8 +355,10 @@ class MessageManager:
min_delay = float("inf")
# 找到最近需要检查的流
for context in self.stream_contexts.values():
if not context.is_active:
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active:
continue
time_until_check = context.next_check_time - current_time
@@ -444,8 +380,12 @@ class MessageManager:
current_time = time.time()
processed_streams = 0
for stream_id, context in self.stream_contexts.items():
if not context.is_active:
# 使用 context_manager 获取活跃的流
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active:
continue
# 检查是否达到检查时间
@@ -463,7 +403,7 @@ class MessageManager:
context.next_check_time = current_time + context.distribution_interval
# 检查未读消息
unread_messages = context.get_unread_messages()
unread_messages = self.context_manager.get_unread_messages(stream_id)
if unread_messages:
processed_streams += 1
self.stats.total_unread_messages = len(unread_messages)
@@ -493,7 +433,7 @@ class MessageManager:
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 更新活跃流计数
active_count = sum(1 for ctx in self.stream_contexts.values() if ctx.is_active)
active_count = len(self.context_manager.get_active_streams())
self.stats.active_streams = active_count
if processed_streams > 0:
@@ -501,13 +441,16 @@ class MessageManager:
async def _check_all_streams_with_priority(self):
"""按优先级检查所有聊天流高focus_energy的流优先处理"""
if not self.stream_contexts:
if not self.context_manager.get_active_streams():
return
# 获取活跃的聊天流并按focus_energy排序
active_streams = []
for stream_id, context in self.stream_contexts.items():
if not context.is_active:
active_stream_ids = self.context_manager.get_active_streams()
for stream_id in active_stream_ids:
context = self.context_manager.get_stream_context(stream_id)
if not context or not context.is_active:
continue
# 获取focus_energy如果不存在则使用默认值
@@ -533,12 +476,12 @@ class MessageManager:
active_stream_count += 1
# 检查是否有未读消息
unread_messages = context.get_unread_messages()
unread_messages = self.context_manager.get_unread_messages(stream_id)
if unread_messages:
total_unread += len(unread_messages)
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 高优先级流的额外日志
@@ -554,56 +497,33 @@ class MessageManager:
self.stats.total_unread_messages = total_unread
def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float:
"""计算聊天流的优先级分数"""
from src.plugin_system.apis.chat_api import get_chat_manager
chat_stream = get_chat_manager().get_stream(context.stream_id)
# 基础优先级focus_energy
"""计算聊天流的优先级分数 - 简化版本主要使用focus_energy"""
# 使用重构后的能量管理器主要依赖focus_energy
base_priority = focus_energy
# 未读消息数量加权
# 简单的未读消息加权
unread_count = len(context.get_unread_messages())
message_count_bonus = min(unread_count * 0.1, 0.3) # 最多30%加成
message_bonus = min(unread_count * 0.05, 0.2) # 最多20%加成
# 时间加权:最近活跃的流优先级更高
# 简单的时间加权
current_time = time.time()
time_since_active = current_time - context.last_check_time
time_penalty = max(0, 1.0 - time_since_active / 3600.0) # 1小时内无惩罚
# 连续无回复惩罚 - 从StreamContext历史消息计算
if chat_stream:
# 计算连续无回复次数
consecutive_no_reply = 0
all_messages = context.get_history_messages(limit=50) + context.get_unread_messages()
for msg in reversed(all_messages):
if hasattr(msg, "should_reply") and msg.should_reply:
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
consecutive_no_reply += 1
else:
break
no_reply_penalty = max(0, 1.0 - consecutive_no_reply * 0.05) # 每次无回复降低5%
else:
no_reply_penalty = 1.0
# 综合优先级计算
final_priority = (
base_priority * 0.6 # 基础兴趣度权重60%
+ message_count_bonus * 0.2 # 消息数量权重20%
+ time_penalty * 0.1 # 时间权重10%
+ no_reply_penalty * 0.1 # 回复状态权重10%
)
time_bonus = max(0, 1.0 - time_since_active / 7200.0) * 0.1 # 2小时内衰减
final_priority = base_priority + message_bonus + time_bonus
return max(0.0, min(1.0, final_priority))
def _clear_all_unread_messages(self, context: StreamContext):
def _clear_all_unread_messages(self, stream_id: str):
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
unread_messages = context.get_unread_messages()
unread_messages = self.context_manager.get_unread_messages(stream_id)
if not unread_messages:
return
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 将所有未读消息标记为已读并移动到历史记录
# 将所有未读消息标记为已读
context = self.context_manager.get_stream_context(stream_id)
if context:
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
try:
context.mark_message_as_read(msg.message_id)

View File

@@ -120,186 +120,209 @@ class ChatStream:
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
from src.common.data_models.database_data_model import DatabaseMessages
import json
# 简化转换,实际可能需要更完整的转换逻辑
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
group_info = getattr(message_info, "group_info", {})
# 提取reply_to信息从message_segment中查找reply类型的段
reply_to = None
if hasattr(message, "message_segment") and message.message_segment:
reply_to = self._extract_reply_from_segment(message.message_segment)
# 完整的数据转移逻辑
db_message = DatabaseMessages(
# 基础消息信息
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=getattr(message, "chat_id", ""),
user_id=str(getattr(message.message_info, "user_info", {}).user_id)
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
user_nickname=getattr(message.message_info, "user_info", {}).user_nickname
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
user_platform=getattr(message.message_info, "user_info", {}).platform
if hasattr(message, "message_info") and hasattr(message.message_info, "user_info")
else "",
priority_mode=getattr(message, "priority_mode", None),
priority_info=str(getattr(message, "priority_info", None))
if hasattr(message, "priority_info") and message.priority_info
chat_id=self._generate_chat_id(message_info),
reply_to=reply_to,
# 兴趣度相关
interest_value=getattr(message, "interest_value", 0.0),
# 关键词
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
if getattr(message, "key_words", None)
else None,
additional_config=getattr(getattr(message, "message_info", {}), "additional_config", None),
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
if getattr(message, "key_words_lite", None)
else None,
# 消息状态标记
is_mentioned=getattr(message, "is_mentioned", None),
is_at=getattr(message, "is_at", False),
is_emoji=getattr(message, "is_emoji", False),
is_picid=getattr(message, "is_picid", False),
is_voice=getattr(message, "is_voice", False),
is_video=getattr(message, "is_video", False),
is_command=getattr(message, "is_command", False),
is_notify=getattr(message, "is_notify", False),
# 消息内容
processed_plain_text=getattr(message, "processed_plain_text", ""),
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
# 优先级信息
priority_mode=getattr(message, "priority_mode", None),
priority_info=json.dumps(getattr(message, "priority_info", None))
if getattr(message, "priority_info", None)
else None,
# 额外配置
additional_config=getattr(message_info, "additional_config", None),
# 用户信息
user_id=str(getattr(user_info, "user_id", "")),
user_nickname=getattr(user_info, "user_nickname", ""),
user_cardname=getattr(user_info, "user_cardname", None),
user_platform=getattr(user_info, "platform", ""),
# 群组信息
chat_info_group_id=getattr(group_info, "group_id", None),
chat_info_group_name=getattr(group_info, "group_name", None),
chat_info_group_platform=getattr(group_info, "platform", None),
# 聊天流信息
chat_info_user_id=str(getattr(user_info, "user_id", "")),
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
chat_info_user_platform=getattr(user_info, "platform", ""),
chat_info_stream_id=self.stream_id,
chat_info_platform=self.platform,
chat_info_create_time=self.create_time,
chat_info_last_active_time=self.last_active_time,
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
)
self.stream_context.set_current_message(db_message)
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}")
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
"""安全获取消息的actions字段"""
try:
actions = getattr(message, "actions", None)
if actions is None:
return None
# 如果是字符串尝试解析为JSON
if isinstance(actions, str):
try:
import json
actions = json.loads(actions)
except json.JSONDecodeError:
logger.warning(f"无法解析actions JSON字符串: {actions}")
return None
# 确保返回列表类型
if isinstance(actions, list):
# 过滤掉空值和非字符串元素
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
return filtered_actions if filtered_actions else None
else:
logger.warning(f"actions字段类型不支持: {type(actions)}")
return None
except Exception as e:
logger.warning(f"获取actions字段失败: {e}")
return None
def _extract_reply_from_segment(self, segment) -> Optional[str]:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = self._extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
def _generate_chat_id(self, message_info) -> str:
"""生成chat_id基于群组或用户信息"""
try:
group_info = getattr(message_info, "group_info", None)
user_info = getattr(message_info, "user_info", None)
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
# 群聊使用群组ID
return f"{self.platform}_{group_info.group_id}"
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
# 私聊使用用户ID
return f"{self.platform}_{user_info.user_id}_private"
else:
# 默认使用stream_id
return self.stream_id
except Exception as e:
logger.warning(f"生成chat_id失败: {e}")
return self.stream_id
@property
def focus_energy(self) -> float:
"""动态计算的聊天流总体兴趣度,访问时自动更新"""
self._focus_energy = self._calculate_dynamic_focus_energy()
"""使用重构后的能量管理器计算focus_energy"""
try:
from src.chat.energy_system import energy_manager
# 获取所有消息
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
unread_messages = self.stream_context.get_unread_messages()
all_messages = history_messages + unread_messages
# 获取用户ID
user_id = None
if self.user_info and hasattr(self.user_info, "user_id"):
user_id = str(self.user_info.user_id)
# 使用能量管理器计算
energy = energy_manager.calculate_focus_energy(
stream_id=self.stream_id,
messages=all_messages,
user_id=user_id
)
# 更新内部存储
self._focus_energy = energy
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
return energy
except Exception as e:
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
# 返回缓存的值或默认值
if hasattr(self, '_focus_energy'):
return self._focus_energy
else:
return 0.5
@focus_energy.setter
def focus_energy(self, value: float):
"""设置focus_energy值主要用于初始化或特殊场景"""
self._focus_energy = max(0.0, min(1.0, value))
def _calculate_dynamic_focus_energy(self) -> float:
"""动态计算聊天流的总体兴趣度使用StreamContext历史消息"""
try:
# 从StreamContext获取历史消息计算统计数据
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
unread_messages = self.stream_context.get_unread_messages()
all_messages = history_messages + unread_messages
# 计算基于历史消息的统计数据
if all_messages:
# 基础分:平均消息兴趣度
message_interests = [msg.interest_degree for msg in all_messages if hasattr(msg, "interest_degree")]
avg_message_interest = sum(message_interests) / len(message_interests) if message_interests else 0.3
# 动作参与度:有动作的消息比例
messages_with_actions = [msg for msg in all_messages if hasattr(msg, "actions") and msg.actions]
action_rate = len(messages_with_actions) / len(all_messages)
# 回复活跃度:应该回复且已回复的消息比例
should_reply_messages = [
msg for msg in all_messages if hasattr(msg, "should_reply") and msg.should_reply
]
replied_messages = [
msg for msg in should_reply_messages if hasattr(msg, "actions") and "reply" in (msg.actions or [])
]
reply_rate = len(replied_messages) / len(should_reply_messages) if should_reply_messages else 0.0
# 获取最后交互时间
if all_messages:
self.last_interaction_time = max(msg.time for msg in all_messages)
# 连续无回复计算:从最近的未回复消息计数
consecutive_no_reply = 0
for msg in reversed(all_messages):
if hasattr(msg, "should_reply") and msg.should_reply:
if not (hasattr(msg, "actions") and "reply" in (msg.actions or [])):
consecutive_no_reply += 1
else:
break
else:
# 没有历史消息时的默认值
avg_message_interest = 0.3
action_rate = 0.0
reply_rate = 0.0
consecutive_no_reply = 0
self.last_interaction_time = time.time()
# 获取用户关系分(对于私聊,群聊无效)
relationship_factor = self._get_user_relationship_score()
# 时间衰减因子:最近活跃度
current_time = time.time()
if not hasattr(self, "last_interaction_time") or not self.last_interaction_time:
self.last_interaction_time = current_time
time_since_interaction = current_time - self.last_interaction_time
time_decay = max(0.3, 1.0 - min(time_since_interaction / (7 * 24 * 3600), 0.7)) # 7天衰减
# 连续无回复惩罚
no_reply_penalty = max(0.1, 1.0 - consecutive_no_reply * 0.1)
# 获取AFC系统阈值添加None值检查
reply_threshold = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
non_reply_threshold = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
high_match_threshold = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
# 计算与不同阈值的差距比例
reply_gap_ratio = max(0, (avg_message_interest - reply_threshold) / max(0.1, (1.0 - reply_threshold)))
non_reply_gap_ratio = max(
0, (avg_message_interest - non_reply_threshold) / max(0.1, (1.0 - non_reply_threshold))
)
high_match_gap_ratio = max(
0, (avg_message_interest - high_match_threshold) / max(0.1, (1.0 - high_match_threshold))
)
# 基于阈值差距比例的基础分计算
threshold_based_score = (
reply_gap_ratio * 0.6 # 回复阈值差距权重60%
+ non_reply_gap_ratio * 0.2 # 非回复阈值差距权重20%
+ high_match_gap_ratio * 0.2 # 高匹配阈值差距权重20%
)
# 动态权重调整:根据平均兴趣度水平调整权重分配
if avg_message_interest >= high_match_threshold:
# 高兴趣度:更注重阈值差距
threshold_weight = 0.7
activity_weight = 0.2
relationship_weight = 0.1
elif avg_message_interest >= reply_threshold:
# 中等兴趣度:平衡权重
threshold_weight = 0.5
activity_weight = 0.3
relationship_weight = 0.2
else:
# 低兴趣度:更注重活跃度提升
threshold_weight = 0.3
activity_weight = 0.5
relationship_weight = 0.2
# 计算活跃度得分
activity_score = action_rate * 0.6 + reply_rate * 0.4
# 综合计算:基于阈值的动态加权
focus_energy = (
(
threshold_based_score * threshold_weight # 阈值差距基础分
+ activity_score * activity_weight # 活跃度得分
+ relationship_factor * relationship_weight # 关系得分
+ self.base_interest_energy * 0.05 # 基础兴趣微调
)
* time_decay
* no_reply_penalty
)
# 确保在合理范围内
focus_energy = max(0.1, min(1.0, focus_energy))
# 应用非线性变换增强区分度
if focus_energy >= 0.7:
# 高兴趣度区域:指数增强,更敏感
focus_energy = 0.7 + (focus_energy - 0.7) ** 0.8
elif focus_energy >= 0.4:
# 中等兴趣度区域:线性保持
pass
else:
# 低兴趣度区域:对数压缩,减少区分度
focus_energy = 0.4 * (focus_energy / 0.4) ** 1.2
return max(0.1, min(1.0, focus_energy))
except Exception as e:
logger.error(f"计算动态focus_energy失败: {e}")
return self.base_interest_energy
def _get_user_relationship_score(self) -> float:
"""外部系统获取用户关系分"""
"""新的兴趣度管理系统获取用户关系分"""
try:
# 尝试从兴趣评分系统获取用户关系分
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
# 使用新的兴趣度管理系统
from src.chat.interest_system import interest_manager
if self.user_info and hasattr(self.user_info, "user_id"):
return chatter_interest_scoring_system.get_user_relationship(str(self.user_info.user_id))
user_id = str(self.user_info.user_id)
# 获取用户交互历史作为关系分的基础
interaction_calc = interest_manager.calculators.get(
interest_manager.InterestSourceType.USER_INTERACTION
)
if interaction_calc:
return interaction_calc.calculate({"user_id": user_id})
except Exception:
pass
@@ -378,12 +401,13 @@ class ChatStream:
chat_info_platform=db_msg.chat_info_platform,
chat_info_create_time=db_msg.chat_info_create_time,
chat_info_last_active_time=db_msg.chat_info_last_active_time,
# 新增的兴趣度系统字段
interest_degree=getattr(db_msg, "interest_degree", 0.0) or 0.0,
actions=actions,
should_reply=getattr(db_msg, "should_reply", False) or False,
)
# 添加调试日志检查从数据库加载的interest_value
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
# 标记为已读并添加到历史消息
db_message.is_read = True
self.stream_context.history_messages.append(db_message)

View File

@@ -218,4 +218,93 @@ class MessageStorage:
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
@staticmethod
def update_message_interest_value(message_id: str, interest_value: float) -> None:
"""
更新数据库中消息的interest_value字段
Args:
message_id: 消息ID
interest_value: 兴趣度值
"""
try:
with get_db_session() as session:
# 更新消息的interest_value字段
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
result = session.execute(stmt)
session.commit()
if result.rowcount > 0:
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
else:
logger.warning(f"未找到消息 {message_id}无法更新interest_value")
except Exception as e:
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
raise
@staticmethod
def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
"""
修复指定聊天中interest_value为0或null的历史消息记录
Args:
chat_id: 聊天ID
since_time: 从指定时间开始修复(时间戳)
Returns:
修复的记录数量
"""
try:
with get_db_session() as session:
from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages
# 查找需要修复的记录interest_value为0、null或很小的值
query = select(Messages).where(
(Messages.chat_id == chat_id) &
(Messages.time >= since_time) &
(
(Messages.interest_value == 0) |
(Messages.interest_value.is_(None)) |
(Messages.interest_value < 0.1)
)
).limit(50) # 限制每次修复的数量,避免性能问题
messages_to_fix = session.execute(query).scalars().all()
fixed_count = 0
for msg in messages_to_fix:
# 为这些消息设置一个合理的默认兴趣度
# 可以基于消息长度、内容或其他因素计算
default_interest = 0.3 # 默认中等兴趣度
# 如果消息内容较长,可能是重要消息,兴趣度稍高
if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text:
text_length = len(msg.processed_plain_text)
if text_length > 50: # 长消息
default_interest = 0.4
elif text_length > 20: # 中等长度消息
default_interest = 0.35
# 如果是被@的消息,兴趣度更高
if getattr(msg, 'is_mentioned', False):
default_interest = min(default_interest + 0.2, 0.8)
# 执行更新
update_stmt = update(Messages).where(
Messages.message_id == msg.message_id
).values(interest_value=default_interest)
result = session.execute(update_stmt)
if result.rowcount > 0:
fixed_count += 1
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
session.commit()
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
return fixed_count
except Exception as e:
logger.error(f"修复历史消息interest_value失败: {e}")
return 0

View File

@@ -96,7 +96,6 @@ class DatabaseMessages(BaseDataModel):
chat_info_create_time: float = 0.0,
chat_info_last_active_time: float = 0.0,
# 新增字段
interest_degree: float = 0.0,
actions: Optional[list] = None,
should_reply: bool = False,
**kwargs: Any,
@@ -108,7 +107,6 @@ class DatabaseMessages(BaseDataModel):
self.interest_value = interest_value
# 新增字段
self.interest_degree = interest_degree
self.actions = actions
self.should_reply = should_reply
@@ -201,7 +199,6 @@ class DatabaseMessages(BaseDataModel):
"selected_expressions": self.selected_expressions,
"is_read": self.is_read,
# 新增字段
"interest_degree": self.interest_degree,
"actions": self.actions,
"should_reply": self.should_reply,
"user_id": self.user_info.user_id,
@@ -221,17 +218,17 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
def update_message_info(self, interest_degree: float = None, actions: list = None, should_reply: bool = None):
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
"""
更新消息信息
Args:
interest_degree: 兴趣度值
interest_value: 兴趣度值
actions: 执行的动作列表
should_reply: 是否应该回复
"""
if interest_degree is not None:
self.interest_degree = interest_degree
if interest_value is not None:
self.interest_value = interest_value
if actions is not None:
self.actions = actions
if should_reply is not None:
@@ -268,7 +265,7 @@ class DatabaseMessages(BaseDataModel):
return {
"message_id": self.message_id,
"time": self.time,
"interest_degree": self.interest_degree,
"interest_value": self.interest_value,
"actions": self.actions,
"should_reply": self.should_reply,
"user_nickname": self.user_info.user_nickname,

View File

@@ -61,27 +61,27 @@ class StreamContext(BaseDataModel):
self._detect_chat_type(message)
def update_message_info(
self, message_id: str, interest_degree: float = None, actions: list = None, should_reply: bool = None
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
):
"""
更新消息信息
Args:
message_id: 消息ID
interest_degree: 兴趣度值
interest_value: 兴趣度值
actions: 执行的动作列表
should_reply: 是否应该回复
"""
# 在未读消息中查找并更新
for message in self.unread_messages:
if message.message_id == message_id:
message.update_message_info(interest_degree, actions, should_reply)
message.update_message_info(interest_value, actions, should_reply)
break
# 在历史消息中查找并更新
for message in self.history_messages:
if message.message_id == message_id:
message.update_message_info(interest_degree, actions, should_reply)
message.update_message_info(interest_value, actions, should_reply)
break
def add_action_to_message(self, message_id: str, action: str):

View File

@@ -174,7 +174,6 @@ class Messages(Base):
is_notify = Column(Boolean, nullable=False, default=False)
# 兴趣度系统字段
interest_degree = Column(Float, nullable=True, default=0.0)
actions = Column(Text, nullable=True) # JSON格式存储动作列表
should_reply = Column(Boolean, nullable=True, default=False)
@@ -183,7 +182,6 @@ class Messages(Base):
Index("idx_messages_chat_id", "chat_id"),
Index("idx_messages_time", "time"),
Index("idx_messages_user_id", "user_id"),
Index("idx_messages_interest_degree", "interest_degree"),
Index("idx_messages_should_reply", "should_reply"),
)

View File

@@ -368,40 +368,29 @@ class ChatterPlanFilter:
interest_scores = {}
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system as interest_scoring_system,
)
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.interest_system import interest_manager
# 转换消息格式
db_messages = []
# 使用新的兴趣度管理系统计算评分
for msg_dict in messages:
try:
db_msg = DatabaseMessages(
message_id=msg_dict.get("message_id", ""),
time=msg_dict.get("time", time.time()),
chat_id=msg_dict.get("chat_id", ""),
processed_plain_text=msg_dict.get("processed_plain_text", ""),
user_id=msg_dict.get("user_id", ""),
user_nickname=msg_dict.get("user_nickname", ""),
user_platform=msg_dict.get("platform", "qq"),
chat_info_group_id=msg_dict.get("group_id", ""),
chat_info_group_name=msg_dict.get("group_name", ""),
chat_info_group_platform=msg_dict.get("platform", "qq"),
)
db_messages.append(db_msg)
except Exception as e:
logger.warning(f"转换消息格式失败: {e}")
continue
# 构建计算上下文
calc_context = {
"stream_id": msg_dict.get("chat_id", ""),
"user_id": msg_dict.get("user_id"),
}
# 计算兴趣度评分
if db_messages:
bot_nickname = global_config.bot.nickname or "麦麦"
scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname)
# 计算消息兴趣度
interest_score = interest_manager.calculate_message_interest(
message=msg_dict,
context=calc_context
)
# 构建兴趣度字典
for score in scores:
interest_scores[score.message_id] = score.total_score
interest_scores[msg_dict.get("message_id", "")] = interest_score
except Exception as e:
logger.warning(f"计算消息兴趣度失败: {e}")
continue
except Exception as e:
logger.warning(f"获取兴趣度评分失败: {e}")

View File

@@ -4,13 +4,13 @@
"""
from dataclasses import asdict
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import ChatterInterestScoringSystem
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
from src.chat.interest_system import interest_manager
from src.mood.mood_manager import mood_manager
@@ -52,14 +52,7 @@ class ChatterActionPlanner:
self.generator = ChatterPlanGenerator(chat_id)
self.executor = ChatterPlanExecutor(action_manager)
# 初始化兴趣度评分系统
self.interest_scoring = ChatterInterestScoringSystem()
# 创建新的关系追踪器
self.relationship_tracker = ChatterRelationshipTracker(self.interest_scoring)
# 设置执行器的关系追踪器
self.executor.set_relationship_tracker(self.relationship_tracker)
# 使用新的统一兴趣度管理系统
# 规划器统计
self.planner_stats = {
@@ -107,43 +100,39 @@ class ChatterActionPlanner:
initial_plan.available_actions = self.action_manager.get_using_actions()
unread_messages = context.get_unread_messages() if context else []
# 2. 兴趣度评分 - 只对未读消息进行评分
if unread_messages:
bot_nickname = global_config.bot.nickname
interest_scores = await self.interest_scoring.calculate_interest_scores(unread_messages, bot_nickname)
# 3. 根据兴趣度调整可用动作
if interest_scores:
latest_score = max(interest_scores, key=lambda s: s.total_score)
latest_message = next(
(msg for msg in unread_messages if msg.message_id == latest_score.message_id), None
)
should_reply, score = self.interest_scoring.should_reply(latest_score, latest_message)
# 2. 使用新的兴趣度管理系统进行评分
score = 0.0
should_reply = False
reply_not_available = False
if not should_reply and "reply" in initial_plan.available_actions:
logger.info(f"兴趣度不足 ({latest_score.total_score:.2f}),移除回复")
reply_not_available = True
# 更新ChatStream的兴趣度数据
from src.plugin_system.apis.chat_api import get_chat_manager
chat_stream = get_chat_manager().get_stream(self.chat_id)
logger.debug(f"已更新聊天 {self.chat_id} 的ChatStream兴趣度分数: {score:.3f}")
if unread_messages:
# 获取用户ID
user_id = None
if unread_messages[0].user_id:
user_id = unread_messages[0].user_id
# 更新情绪状态和ChatStream兴趣度数据
if latest_message and score > 0:
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
await chat_mood.update_mood_by_message(latest_message, score)
logger.debug(f"已更新聊天 {self.chat_id} 的情绪状态,兴趣度: {score:.3f}")
# 构建计算上下文
calc_context = {
"stream_id": self.chat_id,
"user_id": user_id,
}
# 为所有未读消息记录兴趣度信息
# 为每条消息计算兴趣度
for message in unread_messages:
# 查找对应的兴趣度评分
message_score = next((s for s in interest_scores if s.message_id == message.message_id), None)
if message_score:
message.interest_degree = message_score.total_score
message.should_reply = self.interest_scoring.should_reply(message_score, message)[0]
logger.debug(f"已记录消息 {message.message_id} - 兴趣度: {message_score.total_score:.3f}, 应回复: {message.should_reply}")
try:
# 使用新的兴趣度管理器计算
message_interest = interest_manager.calculate_message_interest(
message=message.__dict__,
context=calc_context
)
# 更新消息的兴趣度
message.interest_value = message_interest
# 简单的回复决策逻辑:兴趣度超过阈值则回复
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
# 更新StreamContext中的消息信息并刷新focus_energy
if context:
@@ -151,25 +140,35 @@ class ChatterActionPlanner:
message_manager.update_message_and_refresh_energy(
stream_id=self.chat_id,
message_id=message.message_id,
interest_degree=message_score.total_score,
interest_value=message_interest,
should_reply=message.should_reply
)
# 更新数据库中的消息记录
try:
from src.chat.message_receive.storage import MessageStorage
MessageStorage.update_message_interest_value(message.message_id, message_interest)
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
except Exception as e:
logger.warning(f"更新数据库消息兴趣度失败: {e}")
# 更新话题兴趣度
interest_manager.update_topic_interest(message.__dict__, message_interest)
# 记录最高分
if message_interest > score:
score = message_interest
if message.should_reply:
should_reply = True
else:
# 如果没有找到评分,设置默认值
message.interest_degree = 0.0
reply_not_available = True
except Exception as e:
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
# 设置默认值
message.interest_value = 0.0
message.should_reply = False
# 更新StreamContext中的消息信息并刷新focus_energy
if context:
from src.chat.message_manager.message_manager import message_manager
message_manager.update_message_and_refresh_energy(
stream_id=self.chat_id,
message_id=message.message_id,
interest_degree=0.0,
should_reply=False
)
# base_threshold = self.interest_scoring.reply_threshold
# 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
if score < non_reply_action_interest_threshold:
@@ -191,26 +190,16 @@ class ChatterActionPlanner:
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
# 检查filtered_plan是否有reply动作以便记录reply action
has_reply_action = False
for decision in filtered_plan.decided_actions:
if decision.action_type == "reply":
has_reply_action = True
self.interest_scoring.record_reply_action(has_reply_action)
# 检查filtered_plan是否有reply动作用于统计
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
# 5. 使用 PlanExecutor 执行 Plan
execution_result = await self.executor.execute(filtered_plan)
# 6. 动作记录现在由ChatterActionManager统一处理
# 动作记录逻辑已移至ChatterActionManager.execute_action方法中
# 7. 根据执行结果更新统计信息
# 6. 根据执行结果更新统计信息
self._update_stats_from_execution_result(execution_result)
# 8. 检查关系更新
await self.relationship_tracker.check_and_update_relationships()
# 8. 返回结果
# 7. 返回结果
return self._build_return_result(filtered_plan)
except Exception as e:
@@ -259,37 +248,10 @@ class ChatterActionPlanner:
return final_actions_dict, final_target_message_dict
def get_user_relationship(self, user_id: str) -> float:
"""获取用户关系分"""
return self.interest_scoring.get_user_relationship(user_id)
def update_interest_keywords(self, new_keywords: Dict[str, List[str]]):
"""更新兴趣关键词(已弃用,仅保留用于兼容性)"""
logger.info("传统关键词匹配已移除,此方法仅保留用于兼容性")
# 此方法已弃用因为现在完全使用embedding匹配
def get_planner_stats(self) -> Dict[str, any]:
"""获取规划器统计"""
return self.planner_stats.copy()
def get_interest_scoring_stats(self) -> Dict[str, any]:
"""获取兴趣度评分统计"""
return {
"no_reply_count": self.interest_scoring.no_reply_count,
"max_no_reply_count": self.interest_scoring.max_no_reply_count,
"reply_threshold": self.interest_scoring.reply_threshold,
"mention_threshold": self.interest_scoring.mention_threshold,
"user_relationships": len(self.interest_scoring.user_relationships),
}
def get_relationship_stats(self) -> Dict[str, any]:
"""获取用户关系统计"""
return {
"tracking_users": len(self.relationship_tracker.tracking_users),
"relationship_history": len(self.relationship_tracker.relationship_history),
"max_tracking_users": self.relationship_tracker.max_tracking_users,
}
def get_current_mood_state(self) -> str:
"""获取当前聊天的情绪状态"""
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)