refactor(interest-system): 移除旧兴趣度管理系统,迁移到插件内部实现
移除旧的集中式兴趣度管理系统(interest_manager.py),将兴趣度计算功能迁移到affinity_flow_chatter插件内部实现。主要包括: - 删除interest_manager.py及其相关导入引用 - 修改RelationshipEnergyCalculator使用插件内部的关系分计算 - 重构StreamContextManager使用插件内部的兴趣度评分系统 - 更新ChatStream、PlanFilter、Planner等组件使用新的插件接口 - 简化上下文管理器,移除事件系统和验证器相关代码 此次重构提高了模块独立性,减少了核心代码对插件功能的直接依赖,符合"高内聚低耦合"的设计原则。
This commit is contained in:
@@ -204,24 +204,17 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
return 0.3
|
return 0.3
|
||||||
|
|
||||||
|
# 使用插件内部的兴趣度评分系统获取关系分
|
||||||
try:
|
try:
|
||||||
# 使用新的兴趣度管理系统获取用户关系分
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
from src.chat.interest_system import interest_manager
|
|
||||||
|
|
||||||
# 获取用户交互历史作为关系分的基础
|
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||||
interaction_calc = interest_manager.calculators.get(
|
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
|
||||||
interest_manager.InterestSourceType.USER_INTERACTION
|
|
||||||
)
|
|
||||||
if interaction_calc:
|
|
||||||
relationship_score = interaction_calc.calculate({"user_id": user_id})
|
|
||||||
logger.debug(f"用户关系分数: {relationship_score:.3f}")
|
|
||||||
return max(0.0, min(1.0, relationship_score))
|
return max(0.0, min(1.0, relationship_score))
|
||||||
else:
|
|
||||||
# 默认基础分
|
except Exception as e:
|
||||||
return 0.3
|
logger.warning(f"插件内部关系分计算失败,使用默认值: {e}")
|
||||||
except Exception:
|
return 0.3 # 默认基础分
|
||||||
# 默认基础分
|
|
||||||
return 0.3
|
|
||||||
|
|
||||||
def get_weight(self) -> float:
|
def get_weight(self) -> float:
|
||||||
return 0.1
|
return 0.1
|
||||||
|
|||||||
@@ -1,30 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
兴趣度系统模块
|
兴趣度系统模块
|
||||||
提供统一、稳定的消息兴趣度计算和管理功能
|
提供机器人兴趣标签和智能匹配功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .interest_manager import (
|
|
||||||
InterestManager,
|
|
||||||
InterestSourceType,
|
|
||||||
InterestFactor,
|
|
||||||
InterestCalculator,
|
|
||||||
MessageContentInterestCalculator,
|
|
||||||
TopicInterestCalculator,
|
|
||||||
UserInteractionInterestCalculator,
|
|
||||||
interest_manager
|
|
||||||
)
|
|
||||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InterestManager",
|
|
||||||
"InterestSourceType",
|
|
||||||
"InterestFactor",
|
|
||||||
"InterestCalculator",
|
|
||||||
"MessageContentInterestCalculator",
|
|
||||||
"TopicInterestCalculator",
|
|
||||||
"UserInteractionInterestCalculator",
|
|
||||||
"interest_manager",
|
|
||||||
"BotInterestManager",
|
"BotInterestManager",
|
||||||
"bot_interest_manager",
|
"bot_interest_manager",
|
||||||
"BotInterestTag",
|
"BotInterestTag",
|
||||||
|
|||||||
@@ -1,430 +0,0 @@
|
|||||||
"""
|
|
||||||
重构后的消息兴趣值计算系统
|
|
||||||
提供稳定、可靠的消息兴趣度计算和管理功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("interest_system")
|
|
||||||
|
|
||||||
|
|
||||||
class InterestSourceType(Enum):
|
|
||||||
"""兴趣度来源类型"""
|
|
||||||
MESSAGE_CONTENT = "message_content" # 消息内容
|
|
||||||
USER_INTERACTION = "user_interaction" # 用户交互
|
|
||||||
TOPIC_RELEVANCE = "topic_relevance" # 话题相关性
|
|
||||||
RELATIONSHIP_SCORE = "relationship_score" # 关系分数
|
|
||||||
HISTORICAL_PATTERN = "historical_pattern" # 历史模式
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InterestFactor:
|
|
||||||
"""兴趣度因子"""
|
|
||||||
source_type: InterestSourceType
|
|
||||||
value: float
|
|
||||||
weight: float = 1.0
|
|
||||||
decay_rate: float = 0.1 # 衰减率
|
|
||||||
last_updated: float = field(default_factory=time.time)
|
|
||||||
|
|
||||||
def get_current_value(self) -> float:
|
|
||||||
"""获取当前值(考虑时间衰减)"""
|
|
||||||
age = time.time() - self.last_updated
|
|
||||||
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
|
|
||||||
return self.value * decay_factor
|
|
||||||
|
|
||||||
def update_value(self, new_value: float) -> None:
|
|
||||||
"""更新值"""
|
|
||||||
self.value = max(0.0, min(1.0, new_value))
|
|
||||||
self.last_updated = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
class InterestCalculator(ABC):
|
|
||||||
"""兴趣度计算器抽象基类"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def calculate(self, context: Dict[str, Any]) -> float:
|
|
||||||
"""计算兴趣度"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
"""获取计算置信度"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MessageData(TypedDict):
|
|
||||||
"""消息数据类型定义"""
|
|
||||||
message_id: str
|
|
||||||
processed_plain_text: str
|
|
||||||
is_emoji: bool
|
|
||||||
is_picid: bool
|
|
||||||
is_mentioned: bool
|
|
||||||
is_command: bool
|
|
||||||
key_words: str
|
|
||||||
user_id: str
|
|
||||||
time: float
|
|
||||||
|
|
||||||
|
|
||||||
class InterestContext(TypedDict):
|
|
||||||
"""兴趣度计算上下文"""
|
|
||||||
stream_id: str
|
|
||||||
user_id: Optional[str]
|
|
||||||
message: MessageData
|
|
||||||
|
|
||||||
|
|
||||||
class InterestResult(TypedDict):
|
|
||||||
"""兴趣度计算结果"""
|
|
||||||
value: float
|
|
||||||
confidence: float
|
|
||||||
source_scores: Dict[InterestSourceType, float]
|
|
||||||
cached: bool
|
|
||||||
|
|
||||||
|
|
||||||
class MessageContentInterestCalculator(InterestCalculator):
|
|
||||||
"""消息内容兴趣度计算器"""
|
|
||||||
|
|
||||||
def calculate(self, context: Dict[str, Any]) -> float:
|
|
||||||
"""基于消息内容计算兴趣度"""
|
|
||||||
message = context.get("message", {})
|
|
||||||
if not message:
|
|
||||||
return 0.3 # 默认值
|
|
||||||
|
|
||||||
# 提取消息特征
|
|
||||||
text_length = len(message.get("processed_plain_text", ""))
|
|
||||||
has_emoji = message.get("is_emoji", False)
|
|
||||||
has_image = message.get("is_picid", False)
|
|
||||||
is_mentioned = message.get("is_mentioned", False)
|
|
||||||
is_command = message.get("is_command", False)
|
|
||||||
|
|
||||||
# 基础分数
|
|
||||||
base_score = 0.3
|
|
||||||
|
|
||||||
# 文本长度加权
|
|
||||||
if text_length > 0:
|
|
||||||
text_score = min(0.3, text_length / 200) # 200字符为满分
|
|
||||||
base_score += text_score * 0.3
|
|
||||||
|
|
||||||
# 多媒体内容加权
|
|
||||||
if has_emoji:
|
|
||||||
base_score += 0.1
|
|
||||||
if has_image:
|
|
||||||
base_score += 0.2
|
|
||||||
|
|
||||||
# 交互特征加权
|
|
||||||
if is_mentioned:
|
|
||||||
base_score += 0.2
|
|
||||||
if is_command:
|
|
||||||
base_score += 0.1
|
|
||||||
|
|
||||||
return min(1.0, base_score)
|
|
||||||
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
return 0.8
|
|
||||||
|
|
||||||
|
|
||||||
class TopicInterestCalculator(InterestCalculator):
|
|
||||||
"""话题兴趣度计算器"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.topic_interests: Dict[str, float] = {}
|
|
||||||
self.topic_decay_rate = 0.05 # 话题兴趣度衰减率
|
|
||||||
|
|
||||||
def update_topic_interest(self, topic: str, interest_value: float):
|
|
||||||
"""更新话题兴趣度"""
|
|
||||||
current_interest = self.topic_interests.get(topic, 0.3)
|
|
||||||
# 平滑更新
|
|
||||||
new_interest = current_interest * 0.7 + interest_value * 0.3
|
|
||||||
self.topic_interests[topic] = max(0.0, min(1.0, new_interest))
|
|
||||||
|
|
||||||
logger.debug(f"更新话题 '{topic}' 兴趣度: {current_interest:.3f} -> {new_interest:.3f}")
|
|
||||||
|
|
||||||
def calculate(self, context: Dict[str, Any]) -> float:
|
|
||||||
"""基于话题相关性计算兴趣度"""
|
|
||||||
message = context.get("message", {})
|
|
||||||
keywords = message.get("key_words", "[]")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
keyword_list = json.loads(keywords) if keywords else []
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
keyword_list = []
|
|
||||||
|
|
||||||
if not keyword_list:
|
|
||||||
return 0.4 # 无关键词时的默认值
|
|
||||||
|
|
||||||
# 计算相关话题的平均兴趣度
|
|
||||||
total_interest = 0.0
|
|
||||||
relevant_topics = 0
|
|
||||||
|
|
||||||
for keyword in keyword_list[:5]: # 最多取前5个关键词
|
|
||||||
# 查找相关话题
|
|
||||||
for topic, interest in self.topic_interests.items():
|
|
||||||
if keyword.lower() in topic.lower() or topic.lower() in keyword.lower():
|
|
||||||
total_interest += interest
|
|
||||||
relevant_topics += 1
|
|
||||||
break
|
|
||||||
|
|
||||||
if relevant_topics > 0:
|
|
||||||
return min(1.0, total_interest / relevant_topics)
|
|
||||||
else:
|
|
||||||
# 新话题,给予基础兴趣度
|
|
||||||
for keyword in keyword_list[:3]:
|
|
||||||
self.topic_interests[keyword] = 0.5
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
return 0.7
|
|
||||||
|
|
||||||
|
|
||||||
class UserInteractionInterestCalculator(InterestCalculator):
|
|
||||||
"""用户交互兴趣度计算器"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.interaction_history: List[Dict] = []
|
|
||||||
self.max_history_size = 100
|
|
||||||
|
|
||||||
def add_interaction(self, user_id: str, interaction_type: str, value: float):
|
|
||||||
"""添加交互记录"""
|
|
||||||
self.interaction_history.append({
|
|
||||||
"user_id": user_id,
|
|
||||||
"type": interaction_type,
|
|
||||||
"value": value,
|
|
||||||
"timestamp": time.time()
|
|
||||||
})
|
|
||||||
|
|
||||||
# 保持历史记录大小
|
|
||||||
if len(self.interaction_history) > self.max_history_size:
|
|
||||||
self.interaction_history = self.interaction_history[-self.max_history_size:]
|
|
||||||
|
|
||||||
def calculate(self, context: Dict[str, Any]) -> float:
|
|
||||||
"""基于用户交互历史计算兴趣度"""
|
|
||||||
user_id = context.get("user_id")
|
|
||||||
if not user_id:
|
|
||||||
return 0.3
|
|
||||||
|
|
||||||
# 获取该用户的最近交互记录
|
|
||||||
user_interactions = [
|
|
||||||
interaction for interaction in self.interaction_history
|
|
||||||
if interaction["user_id"] == user_id
|
|
||||||
]
|
|
||||||
|
|
||||||
if not user_interactions:
|
|
||||||
return 0.3
|
|
||||||
|
|
||||||
# 计算加权平均(最近的交互权重更高)
|
|
||||||
total_weight = 0.0
|
|
||||||
weighted_sum = 0.0
|
|
||||||
|
|
||||||
for interaction in user_interactions[-20:]: # 最近20次交互
|
|
||||||
age = time.time() - interaction["timestamp"]
|
|
||||||
weight = max(0.1, 1.0 - age / (7 * 24 * 3600)) # 7天内衰减
|
|
||||||
|
|
||||||
weighted_sum += interaction["value"] * weight
|
|
||||||
total_weight += weight
|
|
||||||
|
|
||||||
if total_weight > 0:
|
|
||||||
return min(1.0, weighted_sum / total_weight)
|
|
||||||
else:
|
|
||||||
return 0.3
|
|
||||||
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
return 0.6
|
|
||||||
|
|
||||||
|
|
||||||
class InterestManager:
|
|
||||||
"""兴趣度管理器 - 统一管理所有兴趣度计算"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.calculators: Dict[InterestSourceType, InterestCalculator] = {
|
|
||||||
InterestSourceType.MESSAGE_CONTENT: MessageContentInterestCalculator(),
|
|
||||||
InterestSourceType.TOPIC_RELEVANCE: TopicInterestCalculator(),
|
|
||||||
InterestSourceType.USER_INTERACTION: UserInteractionInterestCalculator(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 权重配置
|
|
||||||
self.source_weights: Dict[InterestSourceType, float] = {
|
|
||||||
InterestSourceType.MESSAGE_CONTENT: 0.4,
|
|
||||||
InterestSourceType.TOPIC_RELEVANCE: 0.3,
|
|
||||||
InterestSourceType.USER_INTERACTION: 0.3,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 兴趣度缓存
|
|
||||||
self.interest_cache: Dict[str, Tuple[float, float]] = {} # message_id -> (value, timestamp)
|
|
||||||
self.cache_ttl: int = 300 # 5分钟缓存
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
self.stats: Dict[str, Union[int, float, List[str]]] = {
|
|
||||||
"total_calculations": 0,
|
|
||||||
"cache_hits": 0,
|
|
||||||
"cache_misses": 0,
|
|
||||||
"average_calculation_time": 0.0,
|
|
||||||
"calculator_usage": {calc_type.value: 0 for calc_type in InterestSourceType}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info("兴趣度管理器初始化完成")
|
|
||||||
|
|
||||||
def calculate_message_interest(self, message: Dict[str, Any], context: Dict[str, Any]) -> float:
|
|
||||||
"""计算消息兴趣度"""
|
|
||||||
start_time = time.time()
|
|
||||||
message_id = message.get("message_id", "")
|
|
||||||
|
|
||||||
# 更新统计
|
|
||||||
self.stats["total_calculations"] += 1
|
|
||||||
|
|
||||||
# 检查缓存
|
|
||||||
if message_id in self.interest_cache:
|
|
||||||
cached_value, cached_time = self.interest_cache[message_id]
|
|
||||||
if time.time() - cached_time < self.cache_ttl:
|
|
||||||
self.stats["cache_hits"] += 1
|
|
||||||
logger.debug(f"使用缓存兴趣度: {message_id} = {cached_value:.3f}")
|
|
||||||
return cached_value
|
|
||||||
else:
|
|
||||||
self.stats["cache_misses"] += 1
|
|
||||||
|
|
||||||
# 构建计算上下文
|
|
||||||
calc_context: Dict[str, Any] = {
|
|
||||||
"message": message,
|
|
||||||
"user_id": message.get("user_id"),
|
|
||||||
**context
|
|
||||||
}
|
|
||||||
|
|
||||||
# 计算各来源的兴趣度
|
|
||||||
source_scores: Dict[InterestSourceType, float] = {}
|
|
||||||
total_confidence = 0.0
|
|
||||||
|
|
||||||
for source_type, calculator in self.calculators.items():
|
|
||||||
try:
|
|
||||||
score = calculator.calculate(calc_context)
|
|
||||||
confidence = calculator.get_confidence()
|
|
||||||
|
|
||||||
source_scores[source_type] = score
|
|
||||||
total_confidence += confidence
|
|
||||||
|
|
||||||
# 更新计算器使用统计
|
|
||||||
self.stats["calculator_usage"][source_type.value] += 1
|
|
||||||
|
|
||||||
logger.debug(f"{source_type.value} 兴趣度: {score:.3f} (置信度: {confidence:.3f})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"计算 {source_type.value} 兴趣度失败: {e}")
|
|
||||||
source_scores[source_type] = 0.3
|
|
||||||
|
|
||||||
# 加权计算最终兴趣度
|
|
||||||
final_interest = 0.0
|
|
||||||
total_weight = 0.0
|
|
||||||
|
|
||||||
for source_type, score in source_scores.items():
|
|
||||||
weight = self.source_weights.get(source_type, 0.0)
|
|
||||||
final_interest += score * weight
|
|
||||||
total_weight += weight
|
|
||||||
|
|
||||||
if total_weight > 0:
|
|
||||||
final_interest /= total_weight
|
|
||||||
|
|
||||||
# 确保在合理范围内
|
|
||||||
final_interest = max(0.0, min(1.0, final_interest))
|
|
||||||
|
|
||||||
# 缓存结果
|
|
||||||
self.interest_cache[message_id] = (final_interest, time.time())
|
|
||||||
|
|
||||||
# 清理过期缓存
|
|
||||||
self._cleanup_cache()
|
|
||||||
|
|
||||||
# 更新平均计算时间
|
|
||||||
calculation_time = time.time() - start_time
|
|
||||||
total_calculations = self.stats["total_calculations"]
|
|
||||||
self.stats["average_calculation_time"] = (
|
|
||||||
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
|
||||||
/ total_calculations
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"消息 {message_id} 最终兴趣度: {final_interest:.3f} (耗时: {calculation_time:.3f}s)")
|
|
||||||
return final_interest
|
|
||||||
|
|
||||||
def update_topic_interest(self, message: Dict[str, Any], interest_value: float) -> None:
|
|
||||||
"""更新话题兴趣度"""
|
|
||||||
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
|
|
||||||
if isinstance(topic_calc, TopicInterestCalculator):
|
|
||||||
# 提取关键词作为话题
|
|
||||||
keywords = message.get("key_words", "[]")
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
keyword_list: List[str] = json.loads(keywords) if keywords else []
|
|
||||||
for keyword in keyword_list[:3]: # 更新前3个关键词
|
|
||||||
topic_calc.update_topic_interest(keyword, interest_value)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def add_user_interaction(self, user_id: str, interaction_type: str, value: float) -> None:
|
|
||||||
"""添加用户交互记录"""
|
|
||||||
interaction_calc = self.calculators.get(InterestSourceType.USER_INTERACTION)
|
|
||||||
if isinstance(interaction_calc, UserInteractionInterestCalculator):
|
|
||||||
interaction_calc.add_interaction(user_id, interaction_type, value)
|
|
||||||
|
|
||||||
def get_topic_interests(self) -> Dict[str, float]:
|
|
||||||
"""获取所有话题兴趣度"""
|
|
||||||
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
|
|
||||||
if isinstance(topic_calc, TopicInterestCalculator):
|
|
||||||
return topic_calc.topic_interests.copy()
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _cleanup_cache(self) -> None:
|
|
||||||
"""清理过期缓存"""
|
|
||||||
current_time = time.time()
|
|
||||||
expired_keys = [
|
|
||||||
message_id for message_id, (_, timestamp) in self.interest_cache.items()
|
|
||||||
if current_time - timestamp > self.cache_ttl
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in expired_keys:
|
|
||||||
del self.interest_cache[key]
|
|
||||||
|
|
||||||
if expired_keys:
|
|
||||||
logger.debug(f"清理了 {len(expired_keys)} 个过期兴趣度缓存")
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取统计信息"""
|
|
||||||
return {
|
|
||||||
"cache_size": len(self.interest_cache),
|
|
||||||
"topic_count": len(self.get_topic_interests()),
|
|
||||||
"calculators": list(self.calculators.keys()),
|
|
||||||
"performance_stats": self.stats.copy(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def add_calculator(self, source_type: InterestSourceType, calculator: InterestCalculator) -> None:
|
|
||||||
"""添加自定义计算器"""
|
|
||||||
self.calculators[source_type] = calculator
|
|
||||||
logger.info(f"添加计算器: {source_type.value}")
|
|
||||||
|
|
||||||
def remove_calculator(self, source_type: InterestSourceType) -> None:
|
|
||||||
"""移除计算器"""
|
|
||||||
if source_type in self.calculators:
|
|
||||||
del self.calculators[source_type]
|
|
||||||
logger.info(f"移除计算器: {source_type.value}")
|
|
||||||
|
|
||||||
def set_source_weight(self, source_type: InterestSourceType, weight: float) -> None:
|
|
||||||
"""设置来源权重"""
|
|
||||||
self.source_weights[source_type] = max(0.0, min(1.0, weight))
|
|
||||||
logger.info(f"设置 {source_type.value} 权重: {weight}")
|
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
|
||||||
"""清空缓存"""
|
|
||||||
self.interest_cache.clear()
|
|
||||||
logger.info("清空兴趣度缓存")
|
|
||||||
|
|
||||||
def get_cache_hit_rate(self) -> float:
|
|
||||||
"""获取缓存命中率"""
|
|
||||||
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
|
||||||
if total_requests == 0:
|
|
||||||
return 0.0
|
|
||||||
return self.stats["cache_hits"] / total_requests
|
|
||||||
|
|
||||||
|
|
||||||
# 全局兴趣度管理器实例
|
|
||||||
interest_manager = InterestManager()
|
|
||||||
@@ -5,103 +5,18 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Any, Callable, Union, Tuple
|
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.interest_system import interest_manager
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.chat.energy_system import energy_manager
|
from src.chat.energy_system import energy_manager
|
||||||
from .distribution_manager import distribution_manager
|
from .distribution_manager import distribution_manager
|
||||||
|
|
||||||
logger = get_logger("context_manager")
|
logger = get_logger("context_manager")
|
||||||
|
|
||||||
|
|
||||||
class ContextEventType(Enum):
|
|
||||||
"""上下文事件类型"""
|
|
||||||
MESSAGE_ADDED = "message_added"
|
|
||||||
MESSAGE_UPDATED = "message_updated"
|
|
||||||
ENERGY_CHANGED = "energy_changed"
|
|
||||||
STREAM_ACTIVATED = "stream_activated"
|
|
||||||
STREAM_DEACTIVATED = "stream_deactivated"
|
|
||||||
CONTEXT_CLEARED = "context_cleared"
|
|
||||||
VALIDATION_FAILED = "validation_failed"
|
|
||||||
CLEANUP_COMPLETED = "cleanup_completed"
|
|
||||||
INTEGRITY_CHECK = "integrity_check"
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"ContextEventType.{self.name}"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ContextEvent:
|
|
||||||
"""上下文事件"""
|
|
||||||
event_type: ContextEventType
|
|
||||||
stream_id: str
|
|
||||||
data: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
timestamp: float = field(default_factory=time.time)
|
|
||||||
event_id: str = field(default_factory=lambda: f"event_{time.time()}_{id(object())}")
|
|
||||||
priority: int = 0 # 事件优先级,数字越大优先级越高
|
|
||||||
source: str = "system" # 事件来源
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"ContextEvent({self.event_type}, {self.stream_id}, ts={self.timestamp:.3f})"
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"ContextEvent(event_type={self.event_type}, stream_id={self.stream_id}, timestamp={self.timestamp}, event_id={self.event_id})"
|
|
||||||
|
|
||||||
def get_age(self) -> float:
|
|
||||||
"""获取事件年龄(秒)"""
|
|
||||||
return time.time() - self.timestamp
|
|
||||||
|
|
||||||
def is_expired(self, max_age: float = 3600.0) -> bool:
|
|
||||||
"""检查事件是否已过期
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_age: 最大年龄(秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否已过期
|
|
||||||
"""
|
|
||||||
return self.get_age() > max_age
|
|
||||||
|
|
||||||
|
|
||||||
class ContextValidator(ABC):
|
|
||||||
"""上下文验证器抽象基类"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
|
|
||||||
"""验证上下文
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 流ID
|
|
||||||
context: 上下文对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultContextValidator(ContextValidator):
|
|
||||||
"""默认上下文验证器"""
|
|
||||||
|
|
||||||
def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
|
|
||||||
"""验证上下文基本完整性"""
|
|
||||||
if not hasattr(context, 'stream_id'):
|
|
||||||
return False, "缺少 stream_id 属性"
|
|
||||||
if not hasattr(context, 'unread_messages'):
|
|
||||||
return False, "缺少 unread_messages 属性"
|
|
||||||
if not hasattr(context, 'history_messages'):
|
|
||||||
return False, "缺少 history_messages 属性"
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
|
|
||||||
class StreamContextManager:
|
class StreamContextManager:
|
||||||
"""流上下文管理器 - 统一管理所有聊天流上下文"""
|
"""流上下文管理器 - 统一管理所有聊天流上下文"""
|
||||||
|
|
||||||
@@ -110,14 +25,6 @@ class StreamContextManager:
|
|||||||
self.stream_contexts: Dict[str, Any] = {}
|
self.stream_contexts: Dict[str, Any] = {}
|
||||||
self.context_metadata: Dict[str, Dict[str, Any]] = {}
|
self.context_metadata: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
# 事件监听器
|
|
||||||
self.event_listeners: Dict[ContextEventType, List[Callable]] = {}
|
|
||||||
self.event_history: List[ContextEvent] = []
|
|
||||||
self.max_event_history = 1000
|
|
||||||
|
|
||||||
# 验证器
|
|
||||||
self.validators: List[ContextValidator] = [DefaultContextValidator()]
|
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.stats: Dict[str, Union[int, float, str, Dict]] = {
|
self.stats: Dict[str, Union[int, float, str, Dict]] = {
|
||||||
"total_messages": 0,
|
"total_messages": 0,
|
||||||
@@ -126,16 +33,6 @@ class StreamContextManager:
|
|||||||
"inactive_streams": 0,
|
"inactive_streams": 0,
|
||||||
"last_activity": time.time(),
|
"last_activity": time.time(),
|
||||||
"creation_time": time.time(),
|
"creation_time": time.time(),
|
||||||
"validation_stats": {
|
|
||||||
"total_validations": 0,
|
|
||||||
"validation_failures": 0,
|
|
||||||
"last_validation_time": 0.0,
|
|
||||||
},
|
|
||||||
"event_stats": {
|
|
||||||
"total_events": 0,
|
|
||||||
"events_by_type": {},
|
|
||||||
"last_event_time": 0.0,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 配置参数
|
# 配置参数
|
||||||
@@ -166,17 +63,6 @@ class StreamContextManager:
|
|||||||
logger.warning(f"流上下文已存在: {stream_id}")
|
logger.warning(f"流上下文已存在: {stream_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 验证上下文
|
|
||||||
if self.enable_validation:
|
|
||||||
is_valid, error_msg = self._validate_context(stream_id, context)
|
|
||||||
if not is_valid:
|
|
||||||
logger.error(f"上下文验证失败: {stream_id} - {error_msg}")
|
|
||||||
self._emit_event(ContextEventType.VALIDATION_FAILED, stream_id, {
|
|
||||||
"error": error_msg,
|
|
||||||
"context_type": type(context).__name__
|
|
||||||
})
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 添加上下文
|
# 添加上下文
|
||||||
self.stream_contexts[stream_id] = context
|
self.stream_contexts[stream_id] = context
|
||||||
|
|
||||||
@@ -185,7 +71,6 @@ class StreamContextManager:
|
|||||||
"created_time": time.time(),
|
"created_time": time.time(),
|
||||||
"last_access_time": time.time(),
|
"last_access_time": time.time(),
|
||||||
"access_count": 0,
|
"access_count": 0,
|
||||||
"validation_errors": 0,
|
|
||||||
"last_validation_time": 0.0,
|
"last_validation_time": 0.0,
|
||||||
"custom_metadata": metadata or {},
|
"custom_metadata": metadata or {},
|
||||||
}
|
}
|
||||||
@@ -195,13 +80,6 @@ class StreamContextManager:
|
|||||||
self.stats["active_streams"] += 1
|
self.stats["active_streams"] += 1
|
||||||
self.stats["last_activity"] = time.time()
|
self.stats["last_activity"] = time.time()
|
||||||
|
|
||||||
# 触发事件
|
|
||||||
self._emit_event(ContextEventType.STREAM_ACTIVATED, stream_id, {
|
|
||||||
"context": context,
|
|
||||||
"context_type": type(context).__name__,
|
|
||||||
"metadata": metadata
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
|
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -226,19 +104,11 @@ class StreamContextManager:
|
|||||||
self.stats["inactive_streams"] += 1
|
self.stats["inactive_streams"] += 1
|
||||||
self.stats["last_activity"] = time.time()
|
self.stats["last_activity"] = time.time()
|
||||||
|
|
||||||
# 触发事件
|
|
||||||
self._emit_event(ContextEventType.STREAM_DEACTIVATED, stream_id, {
|
|
||||||
"context": context,
|
|
||||||
"context_type": type(context).__name__,
|
|
||||||
"metadata": metadata,
|
|
||||||
"uptime": time.time() - metadata.get("created_time", time.time())
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
|
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[Any]:
|
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
|
||||||
"""获取流上下文
|
"""获取流上下文
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -284,7 +154,7 @@ class StreamContextManager:
|
|||||||
self.context_metadata[stream_id].update(updates)
|
self.context_metadata[stream_id].update(updates)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_message_to_context(self, stream_id: str, message: Any, skip_energy_update: bool = False) -> bool:
|
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||||
"""添加消息到上下文
|
"""添加消息到上下文
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -302,30 +172,16 @@ class StreamContextManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 添加消息到上下文
|
# 添加消息到上下文
|
||||||
if hasattr(context, 'add_message'):
|
|
||||||
context.add_message(message)
|
context.add_message(message)
|
||||||
else:
|
|
||||||
logger.error(f"上下文对象缺少 add_message 方法: {stream_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 计算消息兴趣度
|
# 计算消息兴趣度
|
||||||
interest_value = self._calculate_message_interest(message)
|
interest_value = self._calculate_message_interest(message)
|
||||||
if hasattr(message, 'interest_value'):
|
|
||||||
message.interest_value = interest_value
|
message.interest_value = interest_value
|
||||||
|
|
||||||
# 更新统计
|
# 更新统计
|
||||||
self.stats["total_messages"] += 1
|
self.stats["total_messages"] += 1
|
||||||
self.stats["last_activity"] = time.time()
|
self.stats["last_activity"] = time.time()
|
||||||
|
|
||||||
# 触发事件
|
|
||||||
event_data = {
|
|
||||||
"message": message,
|
|
||||||
"interest_value": interest_value,
|
|
||||||
"message_type": type(message).__name__,
|
|
||||||
"message_id": getattr(message, "message_id", None),
|
|
||||||
}
|
|
||||||
self._emit_event(ContextEventType.MESSAGE_ADDED, stream_id, event_data)
|
|
||||||
|
|
||||||
# 更新能量和分发
|
# 更新能量和分发
|
||||||
if not skip_energy_update:
|
if not skip_energy_update:
|
||||||
self._update_stream_energy(stream_id)
|
self._update_stream_energy(stream_id)
|
||||||
@@ -356,18 +212,7 @@ class StreamContextManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 更新消息信息
|
# 更新消息信息
|
||||||
if hasattr(context, 'update_message_info'):
|
|
||||||
context.update_message_info(message_id, **updates)
|
context.update_message_info(message_id, **updates)
|
||||||
else:
|
|
||||||
logger.error(f"上下文对象缺少 update_message_info 方法: {stream_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 触发事件
|
|
||||||
self._emit_event(ContextEventType.MESSAGE_UPDATED, stream_id, {
|
|
||||||
"message_id": message_id,
|
|
||||||
"updates": updates,
|
|
||||||
"update_time": time.time(),
|
|
||||||
})
|
|
||||||
|
|
||||||
# 如果更新了兴趣度,重新计算能量
|
# 如果更新了兴趣度,重新计算能量
|
||||||
if "interest_value" in updates:
|
if "interest_value" in updates:
|
||||||
@@ -380,7 +225,7 @@ class StreamContextManager:
|
|||||||
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
|
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[Any]:
|
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||||
"""获取上下文消息
|
"""获取上下文消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -397,10 +242,9 @@ class StreamContextManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
messages = []
|
messages = []
|
||||||
if include_unread and hasattr(context, 'get_unread_messages'):
|
if include_unread:
|
||||||
messages.extend(context.get_unread_messages())
|
messages.extend(context.get_unread_messages())
|
||||||
|
|
||||||
if hasattr(context, 'get_history_messages'):
|
|
||||||
if limit:
|
if limit:
|
||||||
messages.extend(context.get_history_messages(limit=limit))
|
messages.extend(context.get_history_messages(limit=limit))
|
||||||
else:
|
else:
|
||||||
@@ -419,7 +263,7 @@ class StreamContextManager:
|
|||||||
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
|
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_unread_messages(self, stream_id: str) -> List[Any]:
|
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
|
||||||
"""获取未读消息
|
"""获取未读消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -433,11 +277,7 @@ class StreamContextManager:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(context, 'get_unread_messages'):
|
|
||||||
return context.get_unread_messages()
|
return context.get_unread_messages()
|
||||||
else:
|
|
||||||
logger.warning(f"上下文对象缺少 get_unread_messages 方法: {stream_id}")
|
|
||||||
return []
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
|
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
@@ -507,12 +347,6 @@ class StreamContextManager:
|
|||||||
else:
|
else:
|
||||||
setattr(context, attr, time.time())
|
setattr(context, attr, time.time())
|
||||||
|
|
||||||
# 触发事件
|
|
||||||
self._emit_event(ContextEventType.CONTEXT_CLEARED, stream_id, {
|
|
||||||
"clear_time": time.time(),
|
|
||||||
"reset_attributes": reset_attrs,
|
|
||||||
})
|
|
||||||
|
|
||||||
# 重新计算能量
|
# 重新计算能量
|
||||||
self._update_stream_energy(stream_id)
|
self._update_stream_energy(stream_id)
|
||||||
|
|
||||||
@@ -523,22 +357,33 @@ class StreamContextManager:
|
|||||||
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
|
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _calculate_message_interest(self, message: Any) -> float:
|
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||||
"""计算消息兴趣度"""
|
"""计算消息兴趣度"""
|
||||||
try:
|
try:
|
||||||
# 将消息转换为字典格式
|
# 使用插件内部的兴趣度评分系统
|
||||||
message_dict = self._message_to_dict(message)
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
|
||||||
# 使用兴趣度管理器计算
|
# 使用插件内部的兴趣度评分系统计算(同步方式)
|
||||||
context = {
|
try:
|
||||||
"stream_id": getattr(message, 'chat_info_stream_id', ''),
|
loop = asyncio.get_event_loop()
|
||||||
"user_id": getattr(message, 'user_id', ''),
|
except RuntimeError:
|
||||||
}
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
interest_value = interest_manager.calculate_message_interest(message_dict, context)
|
interest_score = loop.run_until_complete(
|
||||||
|
chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
|
message=message,
|
||||||
|
bot_nickname=global_config.bot.nickname
|
||||||
|
)
|
||||||
|
)
|
||||||
|
interest_value = interest_score.total_score
|
||||||
|
|
||||||
# 更新话题兴趣度
|
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||||
interest_manager.update_topic_interest(message_dict, interest_value)
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
|
||||||
|
interest_value = 0.5 # 默认中等兴趣度
|
||||||
|
|
||||||
return interest_value
|
return interest_value
|
||||||
|
|
||||||
@@ -546,31 +391,6 @@ class StreamContextManager:
|
|||||||
logger.error(f"计算消息兴趣度失败: {e}")
|
logger.error(f"计算消息兴趣度失败: {e}")
|
||||||
return 0.5
|
return 0.5
|
||||||
|
|
||||||
def _message_to_dict(self, message: Any) -> Dict[str, Any]:
|
|
||||||
"""将消息对象转换为字典"""
|
|
||||||
try:
|
|
||||||
# 获取user_id,优先从user_info.user_id获取,其次从user_id属性获取
|
|
||||||
user_id = ""
|
|
||||||
if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id'):
|
|
||||||
user_id = getattr(message.user_info, 'user_id', "")
|
|
||||||
else:
|
|
||||||
user_id = getattr(message, 'user_id', "")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"message_id": getattr(message, "message_id", ""),
|
|
||||||
"processed_plain_text": getattr(message, "processed_plain_text", ""),
|
|
||||||
"is_emoji": getattr(message, "is_emoji", False),
|
|
||||||
"is_picid": getattr(message, "is_picid", False),
|
|
||||||
"is_mentioned": getattr(message, "is_mentioned", False),
|
|
||||||
"is_command": getattr(message, "is_command", False),
|
|
||||||
"key_words": getattr(message, "key_words", "[]"),
|
|
||||||
"user_id": user_id,
|
|
||||||
"time": getattr(message, "time", time.time()),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换消息为字典失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _update_stream_energy(self, stream_id: str):
|
def _update_stream_energy(self, stream_id: str):
|
||||||
"""更新流能量"""
|
"""更新流能量"""
|
||||||
try:
|
try:
|
||||||
@@ -583,7 +403,7 @@ class StreamContextManager:
|
|||||||
user_id = None
|
user_id = None
|
||||||
if combined_messages:
|
if combined_messages:
|
||||||
last_message = combined_messages[-1]
|
last_message = combined_messages[-1]
|
||||||
user_id = getattr(last_message, "user_id", None)
|
user_id = last_message.user_info.user_id
|
||||||
|
|
||||||
# 计算能量
|
# 计算能量
|
||||||
energy = energy_manager.calculate_focus_energy(
|
energy = energy_manager.calculate_focus_energy(
|
||||||
@@ -595,91 +415,9 @@ class StreamContextManager:
|
|||||||
# 更新分发管理器
|
# 更新分发管理器
|
||||||
distribution_manager.update_stream_energy(stream_id, energy)
|
distribution_manager.update_stream_energy(stream_id, energy)
|
||||||
|
|
||||||
# 触发事件
|
|
||||||
self._emit_event(ContextEventType.ENERGY_CHANGED, stream_id, {
|
|
||||||
"energy": energy,
|
|
||||||
"message_count": len(combined_messages),
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新流能量失败 {stream_id}: {e}")
|
logger.error(f"更新流能量失败 {stream_id}: {e}")
|
||||||
|
|
||||||
def add_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool:
|
|
||||||
"""添加事件监听器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_type: 事件类型
|
|
||||||
listener: 监听器函数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功添加
|
|
||||||
"""
|
|
||||||
if not callable(listener):
|
|
||||||
logger.error(f"监听器必须是可调用对象: {type(listener)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if event_type not in self.event_listeners:
|
|
||||||
self.event_listeners[event_type] = []
|
|
||||||
|
|
||||||
if listener not in self.event_listeners[event_type]:
|
|
||||||
self.event_listeners[event_type].append(listener)
|
|
||||||
logger.debug(f"添加事件监听器: {event_type} -> {getattr(listener, '__name__', 'anonymous')}")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def remove_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool:
|
|
||||||
"""移除事件监听器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_type: 事件类型
|
|
||||||
listener: 监听器函数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功移除
|
|
||||||
"""
|
|
||||||
if event_type in self.event_listeners:
|
|
||||||
try:
|
|
||||||
self.event_listeners[event_type].remove(listener)
|
|
||||||
logger.debug(f"移除事件监听器: {event_type}")
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _emit_event(self, event_type: ContextEventType, stream_id: str, data: Optional[Dict] = None, priority: int = 0) -> None:
|
|
||||||
"""触发事件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_type: 事件类型
|
|
||||||
stream_id: 流ID
|
|
||||||
data: 事件数据
|
|
||||||
priority: 事件优先级
|
|
||||||
"""
|
|
||||||
if data is None:
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
event = ContextEvent(event_type, stream_id, data, priority=priority)
|
|
||||||
|
|
||||||
# 添加到事件历史
|
|
||||||
self.event_history.append(event)
|
|
||||||
if len(self.event_history) > self.max_event_history:
|
|
||||||
self.event_history = self.event_history[-self.max_event_history:]
|
|
||||||
|
|
||||||
# 更新事件统计
|
|
||||||
event_stats = self.stats["event_stats"]
|
|
||||||
event_stats["total_events"] += 1
|
|
||||||
event_stats["last_event_time"] = time.time()
|
|
||||||
event_type_str = str(event_type)
|
|
||||||
event_stats["events_by_type"][event_type_str] = event_stats["events_by_type"].get(event_type_str, 0) + 1
|
|
||||||
|
|
||||||
# 通知监听器
|
|
||||||
if event_type in self.event_listeners:
|
|
||||||
for listener in self.event_listeners[event_type]:
|
|
||||||
try:
|
|
||||||
listener(event)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"事件监听器执行失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""获取流统计信息
|
"""获取流统计信息
|
||||||
|
|
||||||
@@ -718,7 +456,6 @@ class StreamContextManager:
|
|||||||
"access_count": access_count,
|
"access_count": access_count,
|
||||||
"uptime_seconds": current_time - created_time,
|
"uptime_seconds": current_time - created_time,
|
||||||
"idle_seconds": current_time - last_access_time,
|
"idle_seconds": current_time - last_access_time,
|
||||||
"validation_errors": metadata.get("validation_errors", 0),
|
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
|
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
|
||||||
@@ -733,31 +470,11 @@ class StreamContextManager:
|
|||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
uptime = current_time - self.stats.get("creation_time", current_time)
|
uptime = current_time - self.stats.get("creation_time", current_time)
|
||||||
|
|
||||||
# 计算验证统计
|
|
||||||
validation_stats = self.stats["validation_stats"]
|
|
||||||
validation_success_rate = (
|
|
||||||
(validation_stats.get("total_validations", 0) - validation_stats.get("validation_failures", 0)) /
|
|
||||||
max(1, validation_stats.get("total_validations", 1))
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计算事件统计
|
|
||||||
event_stats = self.stats["event_stats"]
|
|
||||||
events_by_type = event_stats.get("events_by_type", {})
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**self.stats,
|
**self.stats,
|
||||||
"uptime_hours": uptime / 3600,
|
"uptime_hours": uptime / 3600,
|
||||||
"stream_count": len(self.stream_contexts),
|
"stream_count": len(self.stream_contexts),
|
||||||
"metadata_count": len(self.context_metadata),
|
"metadata_count": len(self.context_metadata),
|
||||||
"event_history_size": len(self.event_history),
|
|
||||||
"validators_count": len(self.validators),
|
|
||||||
"event_listeners": {
|
|
||||||
str(event_type): len(listeners)
|
|
||||||
for event_type, listeners in self.event_listeners.items()
|
|
||||||
},
|
|
||||||
"validation_success_rate": validation_success_rate,
|
|
||||||
"event_distribution": events_by_type,
|
|
||||||
"max_event_history": self.max_event_history,
|
|
||||||
"auto_cleanup_enabled": self.auto_cleanup,
|
"auto_cleanup_enabled": self.auto_cleanup,
|
||||||
"cleanup_interval": self.cleanup_interval,
|
"cleanup_interval": self.cleanup_interval,
|
||||||
}
|
}
|
||||||
@@ -840,31 +557,6 @@ class StreamContextManager:
|
|||||||
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
|
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
|
|
||||||
"""验证上下文完整性
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 流ID
|
|
||||||
context: 上下文对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
|
||||||
"""
|
|
||||||
validation_stats = self.stats["validation_stats"]
|
|
||||||
validation_stats["total_validations"] += 1
|
|
||||||
validation_stats["last_validation_time"] = time.time()
|
|
||||||
|
|
||||||
for validator in self.validators:
|
|
||||||
try:
|
|
||||||
is_valid, error_msg = validator.validate_context(stream_id, context)
|
|
||||||
if not is_valid:
|
|
||||||
validation_stats["validation_failures"] += 1
|
|
||||||
return False, error_msg
|
|
||||||
except Exception as e:
|
|
||||||
validation_stats["validation_failures"] += 1
|
|
||||||
return False, f"验证器执行失败: {e}"
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""启动上下文管理器"""
|
"""启动上下文管理器"""
|
||||||
if self.is_running:
|
if self.is_running:
|
||||||
@@ -924,7 +616,6 @@ class StreamContextManager:
|
|||||||
try:
|
try:
|
||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
self.cleanup_inactive_contexts()
|
self.cleanup_inactive_contexts()
|
||||||
self._cleanup_event_history()
|
|
||||||
self._cleanup_expired_contexts()
|
self._cleanup_expired_contexts()
|
||||||
logger.debug("自动清理完成")
|
logger.debug("自动清理完成")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -933,20 +624,6 @@ class StreamContextManager:
|
|||||||
logger.error(f"清理循环出错: {e}", exc_info=True)
|
logger.error(f"清理循环出错: {e}", exc_info=True)
|
||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
def _cleanup_event_history(self) -> None:
|
|
||||||
"""清理事件历史"""
|
|
||||||
max_age = 24 * 3600 # 24小时
|
|
||||||
|
|
||||||
# 清理过期事件
|
|
||||||
self.event_history = [
|
|
||||||
event for event in self.event_history
|
|
||||||
if not event.is_expired(max_age)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 保持历史大小限制
|
|
||||||
if len(self.event_history) > self.max_event_history:
|
|
||||||
self.event_history = self.event_history[-self.max_event_history:]
|
|
||||||
|
|
||||||
def _cleanup_expired_contexts(self) -> None:
|
def _cleanup_expired_contexts(self) -> None:
|
||||||
"""清理过期上下文"""
|
"""清理过期上下文"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
@@ -963,21 +640,6 @@ class StreamContextManager:
|
|||||||
if expired_contexts:
|
if expired_contexts:
|
||||||
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
|
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
|
||||||
|
|
||||||
def get_event_history(self, limit: int = 100, event_type: Optional[ContextEventType] = None) -> List[ContextEvent]:
|
|
||||||
"""获取事件历史
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit: 返回数量限制
|
|
||||||
event_type: 过滤事件类型
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[ContextEvent]: 事件列表
|
|
||||||
"""
|
|
||||||
events = self.event_history
|
|
||||||
if event_type:
|
|
||||||
events = [event for event in events if event.event_type == event_type]
|
|
||||||
return events[-limit:]
|
|
||||||
|
|
||||||
def get_active_streams(self) -> List[str]:
|
def get_active_streams(self) -> List[str]:
|
||||||
"""获取活跃流列表
|
"""获取活跃流列表
|
||||||
|
|
||||||
@@ -986,111 +648,6 @@ class StreamContextManager:
|
|||||||
"""
|
"""
|
||||||
return list(self.stream_contexts.keys())
|
return list(self.stream_contexts.keys())
|
||||||
|
|
||||||
def get_context_summary(self) -> Dict[str, Any]:
|
|
||||||
"""获取上下文摘要
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: 上下文摘要信息
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
uptime = current_time - self.stats.get("creation_time", current_time)
|
|
||||||
|
|
||||||
# 计算平均访问次数
|
|
||||||
total_access = sum(meta.get("access_count", 0) for meta in self.context_metadata.values())
|
|
||||||
avg_access = total_access / max(1, len(self.context_metadata))
|
|
||||||
|
|
||||||
# 计算验证成功率
|
|
||||||
validation_stats = self.stats["validation_stats"]
|
|
||||||
total_validations = validation_stats.get("total_validations", 0)
|
|
||||||
validation_success_rate = (
|
|
||||||
(total_validations - validation_stats.get("validation_failures", 0)) /
|
|
||||||
max(1, total_validations)
|
|
||||||
) if total_validations > 0 else 1.0
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_streams": len(self.stream_contexts),
|
|
||||||
"active_streams": len(self.stream_contexts),
|
|
||||||
"total_messages": self.stats.get("total_messages", 0),
|
|
||||||
"uptime_hours": uptime / 3600,
|
|
||||||
"average_access_count": avg_access,
|
|
||||||
"validation_success_rate": validation_success_rate,
|
|
||||||
"event_history_size": len(self.event_history),
|
|
||||||
"validators_count": len(self.validators),
|
|
||||||
"auto_cleanup_enabled": self.auto_cleanup,
|
|
||||||
"cleanup_interval": self.cleanup_interval,
|
|
||||||
"last_activity": self.stats.get("last_activity", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
def force_validation(self, stream_id: str) -> Tuple[bool, Optional[str]]:
|
|
||||||
"""强制验证上下文
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 流ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
|
||||||
"""
|
|
||||||
context = self.get_stream_context(stream_id)
|
|
||||||
if not context:
|
|
||||||
return False, "上下文不存在"
|
|
||||||
|
|
||||||
return self._validate_context(stream_id, context)
|
|
||||||
|
|
||||||
def reset_statistics(self) -> None:
|
|
||||||
"""重置统计信息"""
|
|
||||||
# 重置基本统计
|
|
||||||
self.stats.update({
|
|
||||||
"total_messages": 0,
|
|
||||||
"total_streams": len(self.stream_contexts),
|
|
||||||
"active_streams": len(self.stream_contexts),
|
|
||||||
"inactive_streams": 0,
|
|
||||||
"last_activity": time.time(),
|
|
||||||
"creation_time": time.time(),
|
|
||||||
})
|
|
||||||
|
|
||||||
# 重置验证统计
|
|
||||||
self.stats["validation_stats"].update({
|
|
||||||
"total_validations": 0,
|
|
||||||
"validation_failures": 0,
|
|
||||||
"last_validation_time": 0.0,
|
|
||||||
})
|
|
||||||
|
|
||||||
# 重置事件统计
|
|
||||||
self.stats["event_stats"].update({
|
|
||||||
"total_events": 0,
|
|
||||||
"events_by_type": {},
|
|
||||||
"last_event_time": 0.0,
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info("上下文管理器统计信息已重置")
|
|
||||||
|
|
||||||
def export_context_data(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""导出上下文数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 流ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Dict[str, Any]]: 导出的数据
|
|
||||||
"""
|
|
||||||
context = self.get_stream_context(stream_id, update_access=False)
|
|
||||||
if not context:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
"stream_id": stream_id,
|
|
||||||
"context_type": type(context).__name__,
|
|
||||||
"metadata": self.context_metadata.get(stream_id, {}),
|
|
||||||
"statistics": self.get_stream_statistics(stream_id),
|
|
||||||
"export_time": time.time(),
|
|
||||||
"unread_message_count": len(getattr(context, "unread_messages", [])),
|
|
||||||
"history_message_count": len(getattr(context, "history_messages", [])),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"导出上下文数据失败 {stream_id}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# 全局上下文管理器实例
|
# 全局上下文管理器实例
|
||||||
context_manager = StreamContextManager()
|
context_manager = StreamContextManager()
|
||||||
@@ -17,7 +17,7 @@ from src.chat.planner_actions.action_manager import ChatterActionManager
|
|||||||
from .sleep_manager.sleep_manager import SleepManager
|
from .sleep_manager.sleep_manager import SleepManager
|
||||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from . import context_manager
|
from .context_manager import context_manager
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
@@ -45,7 +45,7 @@ class MessageManager:
|
|||||||
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
||||||
|
|
||||||
# 初始化上下文管理器
|
# 初始化上下文管理器
|
||||||
self.context_manager = context_manager.context_manager
|
self.context_manager = context_manager
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""启动消息管理器"""
|
"""启动消息管理器"""
|
||||||
@@ -83,11 +83,9 @@ class MessageManager:
|
|||||||
if not context:
|
if not context:
|
||||||
# 创建新的流上下文
|
# 创建新的流上下文
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
new_context = StreamContext(stream_id=stream_id)
|
context = StreamContext(stream_id=stream_id)
|
||||||
success = self.context_manager.add_stream_context(stream_id, new_context)
|
# 将创建的上下文添加到管理器
|
||||||
if not success:
|
self.context_manager.add_stream_context(stream_id, context)
|
||||||
logger.error(f"无法为流 {stream_id} 创建上下文")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 使用 context_manager 添加消息
|
# 使用 context_manager 添加消息
|
||||||
success = self.context_manager.add_message_to_context(stream_id, message)
|
success = self.context_manager.add_message_to_context(stream_id, message)
|
||||||
@@ -97,7 +95,7 @@ class MessageManager:
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||||
|
|
||||||
def update_message_and_refresh_energy(
|
def update_message(
|
||||||
self,
|
self,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
@@ -111,7 +109,7 @@ class MessageManager:
|
|||||||
if context:
|
if context:
|
||||||
context.update_message_info(message_id, interest_value, actions, should_reply)
|
context.update_message_info(message_id, interest_value, actions, should_reply)
|
||||||
|
|
||||||
def add_action_and_refresh_energy(self, stream_id: str, message_id: str, action: str):
|
def add_action(self, stream_id: str, message_id: str, action: str):
|
||||||
"""添加动作到消息"""
|
"""添加动作到消息"""
|
||||||
# 使用 context_manager 添加动作到消息
|
# 使用 context_manager 添加动作到消息
|
||||||
context = self.context_manager.get_stream_context(stream_id)
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
|||||||
@@ -310,21 +310,19 @@ class ChatStream:
|
|||||||
self._focus_energy = max(0.0, min(1.0, value))
|
self._focus_energy = max(0.0, min(1.0, value))
|
||||||
|
|
||||||
def _get_user_relationship_score(self) -> float:
|
def _get_user_relationship_score(self) -> float:
|
||||||
"""从新的兴趣度管理系统获取用户关系分"""
|
"""获取用户关系分"""
|
||||||
|
# 使用插件内部的兴趣度评分系统
|
||||||
try:
|
try:
|
||||||
# 使用新的兴趣度管理系统
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
from src.chat.interest_system import interest_manager
|
|
||||||
|
|
||||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||||
user_id = str(self.user_info.user_id)
|
user_id = str(self.user_info.user_id)
|
||||||
# 获取用户交互历史作为关系分的基础
|
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||||
interaction_calc = interest_manager.calculators.get(
|
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
|
||||||
interest_manager.InterestSourceType.USER_INTERACTION
|
return max(0.0, min(1.0, relationship_score))
|
||||||
)
|
|
||||||
if interaction_calc:
|
except Exception as e:
|
||||||
return interaction_calc.calculate({"user_id": user_id})
|
logger.warning(f"ChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 默认基础分
|
# 默认基础分
|
||||||
return 0.3
|
return 0.3
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ class ChatterActionManager:
|
|||||||
|
|
||||||
# 通过message_manager更新消息的动作记录并刷新focus_energy
|
# 通过message_manager更新消息的动作记录并刷新focus_energy
|
||||||
if chat_stream.stream_id in message_manager.stream_contexts:
|
if chat_stream.stream_id in message_manager.stream_contexts:
|
||||||
message_manager.add_action_and_refresh_energy(
|
message_manager.add_action(
|
||||||
stream_id=chat_stream.stream_id,
|
stream_id=chat_stream.stream_id,
|
||||||
message_id=target_message_id,
|
message_id=target_message_id,
|
||||||
action=action_name
|
action=action_name
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from src.common.data_models.info_data_model import InterestScore
|
|||||||
from src.chat.interest_system import bot_interest_manager
|
from src.chat.interest_system import bot_interest_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||||
logger = get_logger("chatter_interest_scoring")
|
logger = get_logger("chatter_interest_scoring")
|
||||||
|
|
||||||
# 定义颜色
|
# 定义颜色
|
||||||
|
|||||||
@@ -387,22 +387,27 @@ class ChatterPlanFilter:
|
|||||||
interest_scores = {}
|
interest_scores = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.chat.interest_system import interest_manager
|
from .interest_scoring import chatter_interest_scoring_system
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
# 使用新的兴趣度管理系统计算评分
|
# 使用插件内部的兴趣度评分系统计算评分
|
||||||
for msg_dict in messages:
|
for msg_dict in messages:
|
||||||
try:
|
try:
|
||||||
# 构建计算上下文
|
# 将字典转换为DatabaseMessages对象
|
||||||
calc_context = {
|
db_message = DatabaseMessages(
|
||||||
"stream_id": msg_dict.get("chat_id", ""),
|
message_id=msg_dict.get("message_id", ""),
|
||||||
"user_id": msg_dict.get("user_id"),
|
user_info=msg_dict.get("user_info", {}),
|
||||||
}
|
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||||
|
key_words=msg_dict.get("key_words", "[]"),
|
||||||
|
is_mentioned=msg_dict.get("is_mentioned", False)
|
||||||
|
)
|
||||||
|
|
||||||
# 计算消息兴趣度
|
# 计算消息兴趣度
|
||||||
interest_score = interest_manager.calculate_message_interest(
|
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
message=msg_dict,
|
message=db_message,
|
||||||
context=calc_context
|
bot_nickname=global_config.bot.nickname
|
||||||
)
|
)
|
||||||
|
interest_score = interest_score_obj.total_score
|
||||||
|
|
||||||
# 构建兴趣度字典
|
# 构建兴趣度字典
|
||||||
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
|||||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||||
from src.chat.interest_system import interest_manager
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
|
||||||
|
|
||||||
@@ -109,10 +109,7 @@ class ChatterActionPlanner:
|
|||||||
# 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取
|
# 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取
|
||||||
user_id = None
|
user_id = None
|
||||||
first_message = unread_messages[0]
|
first_message = unread_messages[0]
|
||||||
if hasattr(first_message, 'user_info') and hasattr(first_message.user_info, 'user_id'):
|
user_id = first_message.user_info.user_id
|
||||||
user_id = getattr(first_message.user_info, 'user_id', None)
|
|
||||||
elif hasattr(first_message, 'user_id'):
|
|
||||||
user_id = getattr(first_message, 'user_id', None)
|
|
||||||
|
|
||||||
# 构建计算上下文
|
# 构建计算上下文
|
||||||
calc_context = {
|
calc_context = {
|
||||||
@@ -123,11 +120,12 @@ class ChatterActionPlanner:
|
|||||||
# 为每条消息计算兴趣度
|
# 为每条消息计算兴趣度
|
||||||
for message in unread_messages:
|
for message in unread_messages:
|
||||||
try:
|
try:
|
||||||
# 使用新的兴趣度管理器计算
|
# 使用插件内部的兴趣度评分系统计算
|
||||||
message_interest = interest_manager.calculate_message_interest(
|
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
message=message.__dict__,
|
message=message,
|
||||||
context=calc_context
|
bot_nickname=global_config.bot.nickname
|
||||||
)
|
)
|
||||||
|
message_interest = interest_score.total_score
|
||||||
|
|
||||||
# 更新消息的兴趣度
|
# 更新消息的兴趣度
|
||||||
message.interest_value = message_interest
|
message.interest_value = message_interest
|
||||||
@@ -140,7 +138,7 @@ class ChatterActionPlanner:
|
|||||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||||
if context:
|
if context:
|
||||||
from src.chat.message_manager.message_manager import message_manager
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
message_manager.update_message_and_refresh_energy(
|
message_manager.update_message(
|
||||||
stream_id=self.chat_id,
|
stream_id=self.chat_id,
|
||||||
message_id=message.message_id,
|
message_id=message.message_id,
|
||||||
interest_value=message_interest,
|
interest_value=message_interest,
|
||||||
@@ -155,9 +153,6 @@ class ChatterActionPlanner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||||
|
|
||||||
# 更新话题兴趣度
|
|
||||||
interest_manager.update_topic_interest(message.__dict__, message_interest)
|
|
||||||
|
|
||||||
# 记录最高分
|
# 记录最高分
|
||||||
if message_interest > score:
|
if message_interest > score:
|
||||||
score = message_interest
|
score = message_interest
|
||||||
|
|||||||
@@ -29,9 +29,6 @@ class ChatterRelationshipTracker:
|
|||||||
self.relationship_history: List[Dict] = []
|
self.relationship_history: List[Dict] = []
|
||||||
self.interest_scoring_system = interest_scoring_system
|
self.interest_scoring_system = interest_scoring_system
|
||||||
|
|
||||||
# 数据库访问 - 使用SQLAlchemy
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||||
self.user_relationship_cache: Dict[str, Dict] = {}
|
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||||
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||||
|
|||||||
Reference in New Issue
Block a user