refactor(interest-system): 移除旧兴趣度管理系统,迁移到插件内部实现
移除旧的集中式兴趣度管理系统(interest_manager.py),将兴趣度计算功能迁移到affinity_flow_chatter插件内部实现。主要包括: - 删除interest_manager.py及其相关导入引用 - 修改RelationshipEnergyCalculator使用插件内部的关系分计算 - 重构StreamContextManager使用插件内部的兴趣度评分系统 - 更新ChatStream、PlanFilter、Planner等组件使用新的插件接口 - 简化上下文管理器,移除事件系统和验证器相关代码 此次重构提高了模块独立性,减少了核心代码对插件功能的直接依赖,符合"高内聚低耦合"的设计原则。
This commit is contained in:
@@ -204,24 +204,17 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
if not user_id:
|
||||
return 0.3
|
||||
|
||||
# 使用插件内部的兴趣度评分系统获取关系分
|
||||
try:
|
||||
# 使用新的兴趣度管理系统获取用户关系分
|
||||
from src.chat.interest_system import interest_manager
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
# 获取用户交互历史作为关系分的基础
|
||||
interaction_calc = interest_manager.calculators.get(
|
||||
interest_manager.InterestSourceType.USER_INTERACTION
|
||||
)
|
||||
if interaction_calc:
|
||||
relationship_score = interaction_calc.calculate({"user_id": user_id})
|
||||
logger.debug(f"用户关系分数: {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
else:
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
except Exception:
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部关系分计算失败,使用默认值: {e}")
|
||||
return 0.3 # 默认基础分
|
||||
|
||||
def get_weight(self) -> float:
|
||||
return 0.1
|
||||
|
||||
@@ -1,30 +1,12 @@
|
||||
"""
|
||||
兴趣度系统模块
|
||||
提供统一、稳定的消息兴趣度计算和管理功能
|
||||
提供机器人兴趣标签和智能匹配功能
|
||||
"""
|
||||
|
||||
from .interest_manager import (
|
||||
InterestManager,
|
||||
InterestSourceType,
|
||||
InterestFactor,
|
||||
InterestCalculator,
|
||||
MessageContentInterestCalculator,
|
||||
TopicInterestCalculator,
|
||||
UserInteractionInterestCalculator,
|
||||
interest_manager
|
||||
)
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
__all__ = [
|
||||
"InterestManager",
|
||||
"InterestSourceType",
|
||||
"InterestFactor",
|
||||
"InterestCalculator",
|
||||
"MessageContentInterestCalculator",
|
||||
"TopicInterestCalculator",
|
||||
"UserInteractionInterestCalculator",
|
||||
"interest_manager",
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
|
||||
@@ -1,430 +0,0 @@
|
||||
"""
|
||||
重构后的消息兴趣值计算系统
|
||||
提供稳定、可靠的消息兴趣度计算和管理功能
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("interest_system")
|
||||
|
||||
|
||||
class InterestSourceType(Enum):
|
||||
"""兴趣度来源类型"""
|
||||
MESSAGE_CONTENT = "message_content" # 消息内容
|
||||
USER_INTERACTION = "user_interaction" # 用户交互
|
||||
TOPIC_RELEVANCE = "topic_relevance" # 话题相关性
|
||||
RELATIONSHIP_SCORE = "relationship_score" # 关系分数
|
||||
HISTORICAL_PATTERN = "historical_pattern" # 历史模式
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestFactor:
|
||||
"""兴趣度因子"""
|
||||
source_type: InterestSourceType
|
||||
value: float
|
||||
weight: float = 1.0
|
||||
decay_rate: float = 0.1 # 衰减率
|
||||
last_updated: float = field(default_factory=time.time)
|
||||
|
||||
def get_current_value(self) -> float:
|
||||
"""获取当前值(考虑时间衰减)"""
|
||||
age = time.time() - self.last_updated
|
||||
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
|
||||
return self.value * decay_factor
|
||||
|
||||
def update_value(self, new_value: float) -> None:
|
||||
"""更新值"""
|
||||
self.value = max(0.0, min(1.0, new_value))
|
||||
self.last_updated = time.time()
|
||||
|
||||
|
||||
class InterestCalculator(ABC):
|
||||
"""兴趣度计算器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""计算兴趣度"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_confidence(self) -> float:
|
||||
"""获取计算置信度"""
|
||||
pass
|
||||
|
||||
|
||||
class MessageData(TypedDict):
|
||||
"""消息数据类型定义"""
|
||||
message_id: str
|
||||
processed_plain_text: str
|
||||
is_emoji: bool
|
||||
is_picid: bool
|
||||
is_mentioned: bool
|
||||
is_command: bool
|
||||
key_words: str
|
||||
user_id: str
|
||||
time: float
|
||||
|
||||
|
||||
class InterestContext(TypedDict):
|
||||
"""兴趣度计算上下文"""
|
||||
stream_id: str
|
||||
user_id: Optional[str]
|
||||
message: MessageData
|
||||
|
||||
|
||||
class InterestResult(TypedDict):
|
||||
"""兴趣度计算结果"""
|
||||
value: float
|
||||
confidence: float
|
||||
source_scores: Dict[InterestSourceType, float]
|
||||
cached: bool
|
||||
|
||||
|
||||
class MessageContentInterestCalculator(InterestCalculator):
|
||||
"""消息内容兴趣度计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于消息内容计算兴趣度"""
|
||||
message = context.get("message", {})
|
||||
if not message:
|
||||
return 0.3 # 默认值
|
||||
|
||||
# 提取消息特征
|
||||
text_length = len(message.get("processed_plain_text", ""))
|
||||
has_emoji = message.get("is_emoji", False)
|
||||
has_image = message.get("is_picid", False)
|
||||
is_mentioned = message.get("is_mentioned", False)
|
||||
is_command = message.get("is_command", False)
|
||||
|
||||
# 基础分数
|
||||
base_score = 0.3
|
||||
|
||||
# 文本长度加权
|
||||
if text_length > 0:
|
||||
text_score = min(0.3, text_length / 200) # 200字符为满分
|
||||
base_score += text_score * 0.3
|
||||
|
||||
# 多媒体内容加权
|
||||
if has_emoji:
|
||||
base_score += 0.1
|
||||
if has_image:
|
||||
base_score += 0.2
|
||||
|
||||
# 交互特征加权
|
||||
if is_mentioned:
|
||||
base_score += 0.2
|
||||
if is_command:
|
||||
base_score += 0.1
|
||||
|
||||
return min(1.0, base_score)
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
return 0.8
|
||||
|
||||
|
||||
class TopicInterestCalculator(InterestCalculator):
|
||||
"""话题兴趣度计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.topic_interests: Dict[str, float] = {}
|
||||
self.topic_decay_rate = 0.05 # 话题兴趣度衰减率
|
||||
|
||||
def update_topic_interest(self, topic: str, interest_value: float):
|
||||
"""更新话题兴趣度"""
|
||||
current_interest = self.topic_interests.get(topic, 0.3)
|
||||
# 平滑更新
|
||||
new_interest = current_interest * 0.7 + interest_value * 0.3
|
||||
self.topic_interests[topic] = max(0.0, min(1.0, new_interest))
|
||||
|
||||
logger.debug(f"更新话题 '{topic}' 兴趣度: {current_interest:.3f} -> {new_interest:.3f}")
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于话题相关性计算兴趣度"""
|
||||
message = context.get("message", {})
|
||||
keywords = message.get("key_words", "[]")
|
||||
|
||||
try:
|
||||
import json
|
||||
keyword_list = json.loads(keywords) if keywords else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
keyword_list = []
|
||||
|
||||
if not keyword_list:
|
||||
return 0.4 # 无关键词时的默认值
|
||||
|
||||
# 计算相关话题的平均兴趣度
|
||||
total_interest = 0.0
|
||||
relevant_topics = 0
|
||||
|
||||
for keyword in keyword_list[:5]: # 最多取前5个关键词
|
||||
# 查找相关话题
|
||||
for topic, interest in self.topic_interests.items():
|
||||
if keyword.lower() in topic.lower() or topic.lower() in keyword.lower():
|
||||
total_interest += interest
|
||||
relevant_topics += 1
|
||||
break
|
||||
|
||||
if relevant_topics > 0:
|
||||
return min(1.0, total_interest / relevant_topics)
|
||||
else:
|
||||
# 新话题,给予基础兴趣度
|
||||
for keyword in keyword_list[:3]:
|
||||
self.topic_interests[keyword] = 0.5
|
||||
return 0.5
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
return 0.7
|
||||
|
||||
|
||||
class UserInteractionInterestCalculator(InterestCalculator):
|
||||
"""用户交互兴趣度计算器"""
|
||||
|
||||
def __init__(self):
|
||||
self.interaction_history: List[Dict] = []
|
||||
self.max_history_size = 100
|
||||
|
||||
def add_interaction(self, user_id: str, interaction_type: str, value: float):
|
||||
"""添加交互记录"""
|
||||
self.interaction_history.append({
|
||||
"user_id": user_id,
|
||||
"type": interaction_type,
|
||||
"value": value,
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
# 保持历史记录大小
|
||||
if len(self.interaction_history) > self.max_history_size:
|
||||
self.interaction_history = self.interaction_history[-self.max_history_size:]
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
"""基于用户交互历史计算兴趣度"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
return 0.3
|
||||
|
||||
# 获取该用户的最近交互记录
|
||||
user_interactions = [
|
||||
interaction for interaction in self.interaction_history
|
||||
if interaction["user_id"] == user_id
|
||||
]
|
||||
|
||||
if not user_interactions:
|
||||
return 0.3
|
||||
|
||||
# 计算加权平均(最近的交互权重更高)
|
||||
total_weight = 0.0
|
||||
weighted_sum = 0.0
|
||||
|
||||
for interaction in user_interactions[-20:]: # 最近20次交互
|
||||
age = time.time() - interaction["timestamp"]
|
||||
weight = max(0.1, 1.0 - age / (7 * 24 * 3600)) # 7天内衰减
|
||||
|
||||
weighted_sum += interaction["value"] * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
return min(1.0, weighted_sum / total_weight)
|
||||
else:
|
||||
return 0.3
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
return 0.6
|
||||
|
||||
|
||||
class InterestManager:
|
||||
"""兴趣度管理器 - 统一管理所有兴趣度计算"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calculators: Dict[InterestSourceType, InterestCalculator] = {
|
||||
InterestSourceType.MESSAGE_CONTENT: MessageContentInterestCalculator(),
|
||||
InterestSourceType.TOPIC_RELEVANCE: TopicInterestCalculator(),
|
||||
InterestSourceType.USER_INTERACTION: UserInteractionInterestCalculator(),
|
||||
}
|
||||
|
||||
# 权重配置
|
||||
self.source_weights: Dict[InterestSourceType, float] = {
|
||||
InterestSourceType.MESSAGE_CONTENT: 0.4,
|
||||
InterestSourceType.TOPIC_RELEVANCE: 0.3,
|
||||
InterestSourceType.USER_INTERACTION: 0.3,
|
||||
}
|
||||
|
||||
# 兴趣度缓存
|
||||
self.interest_cache: Dict[str, Tuple[float, float]] = {} # message_id -> (value, timestamp)
|
||||
self.cache_ttl: int = 300 # 5分钟缓存
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, List[str]]] = {
|
||||
"total_calculations": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"average_calculation_time": 0.0,
|
||||
"calculator_usage": {calc_type.value: 0 for calc_type in InterestSourceType}
|
||||
}
|
||||
|
||||
logger.info("兴趣度管理器初始化完成")
|
||||
|
||||
def calculate_message_interest(self, message: Dict[str, Any], context: Dict[str, Any]) -> float:
|
||||
"""计算消息兴趣度"""
|
||||
start_time = time.time()
|
||||
message_id = message.get("message_id", "")
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_calculations"] += 1
|
||||
|
||||
# 检查缓存
|
||||
if message_id in self.interest_cache:
|
||||
cached_value, cached_time = self.interest_cache[message_id]
|
||||
if time.time() - cached_time < self.cache_ttl:
|
||||
self.stats["cache_hits"] += 1
|
||||
logger.debug(f"使用缓存兴趣度: {message_id} = {cached_value:.3f}")
|
||||
return cached_value
|
||||
else:
|
||||
self.stats["cache_misses"] += 1
|
||||
|
||||
# 构建计算上下文
|
||||
calc_context: Dict[str, Any] = {
|
||||
"message": message,
|
||||
"user_id": message.get("user_id"),
|
||||
**context
|
||||
}
|
||||
|
||||
# 计算各来源的兴趣度
|
||||
source_scores: Dict[InterestSourceType, float] = {}
|
||||
total_confidence = 0.0
|
||||
|
||||
for source_type, calculator in self.calculators.items():
|
||||
try:
|
||||
score = calculator.calculate(calc_context)
|
||||
confidence = calculator.get_confidence()
|
||||
|
||||
source_scores[source_type] = score
|
||||
total_confidence += confidence
|
||||
|
||||
# 更新计算器使用统计
|
||||
self.stats["calculator_usage"][source_type.value] += 1
|
||||
|
||||
logger.debug(f"{source_type.value} 兴趣度: {score:.3f} (置信度: {confidence:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算 {source_type.value} 兴趣度失败: {e}")
|
||||
source_scores[source_type] = 0.3
|
||||
|
||||
# 加权计算最终兴趣度
|
||||
final_interest = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for source_type, score in source_scores.items():
|
||||
weight = self.source_weights.get(source_type, 0.0)
|
||||
final_interest += score * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
final_interest /= total_weight
|
||||
|
||||
# 确保在合理范围内
|
||||
final_interest = max(0.0, min(1.0, final_interest))
|
||||
|
||||
# 缓存结果
|
||||
self.interest_cache[message_id] = (final_interest, time.time())
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_cache()
|
||||
|
||||
# 更新平均计算时间
|
||||
calculation_time = time.time() - start_time
|
||||
total_calculations = self.stats["total_calculations"]
|
||||
self.stats["average_calculation_time"] = (
|
||||
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
||||
/ total_calculations
|
||||
)
|
||||
|
||||
logger.info(f"消息 {message_id} 最终兴趣度: {final_interest:.3f} (耗时: {calculation_time:.3f}s)")
|
||||
return final_interest
|
||||
|
||||
def update_topic_interest(self, message: Dict[str, Any], interest_value: float) -> None:
|
||||
"""更新话题兴趣度"""
|
||||
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
|
||||
if isinstance(topic_calc, TopicInterestCalculator):
|
||||
# 提取关键词作为话题
|
||||
keywords = message.get("key_words", "[]")
|
||||
try:
|
||||
import json
|
||||
keyword_list: List[str] = json.loads(keywords) if keywords else []
|
||||
for keyword in keyword_list[:3]: # 更新前3个关键词
|
||||
topic_calc.update_topic_interest(keyword, interest_value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
def add_user_interaction(self, user_id: str, interaction_type: str, value: float) -> None:
|
||||
"""添加用户交互记录"""
|
||||
interaction_calc = self.calculators.get(InterestSourceType.USER_INTERACTION)
|
||||
if isinstance(interaction_calc, UserInteractionInterestCalculator):
|
||||
interaction_calc.add_interaction(user_id, interaction_type, value)
|
||||
|
||||
def get_topic_interests(self) -> Dict[str, float]:
|
||||
"""获取所有话题兴趣度"""
|
||||
topic_calc = self.calculators.get(InterestSourceType.TOPIC_RELEVANCE)
|
||||
if isinstance(topic_calc, TopicInterestCalculator):
|
||||
return topic_calc.topic_interests.copy()
|
||||
return {}
|
||||
|
||||
def _cleanup_cache(self) -> None:
|
||||
"""清理过期缓存"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
message_id for message_id, (_, timestamp) in self.interest_cache.items()
|
||||
if current_time - timestamp > self.cache_ttl
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self.interest_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期兴趣度缓存")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cache_size": len(self.interest_cache),
|
||||
"topic_count": len(self.get_topic_interests()),
|
||||
"calculators": list(self.calculators.keys()),
|
||||
"performance_stats": self.stats.copy(),
|
||||
}
|
||||
|
||||
def add_calculator(self, source_type: InterestSourceType, calculator: InterestCalculator) -> None:
|
||||
"""添加自定义计算器"""
|
||||
self.calculators[source_type] = calculator
|
||||
logger.info(f"添加计算器: {source_type.value}")
|
||||
|
||||
def remove_calculator(self, source_type: InterestSourceType) -> None:
|
||||
"""移除计算器"""
|
||||
if source_type in self.calculators:
|
||||
del self.calculators[source_type]
|
||||
logger.info(f"移除计算器: {source_type.value}")
|
||||
|
||||
def set_source_weight(self, source_type: InterestSourceType, weight: float) -> None:
|
||||
"""设置来源权重"""
|
||||
self.source_weights[source_type] = max(0.0, min(1.0, weight))
|
||||
logger.info(f"设置 {source_type.value} 权重: {weight}")
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
self.interest_cache.clear()
|
||||
logger.info("清空兴趣度缓存")
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""获取缓存命中率"""
|
||||
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
||||
if total_requests == 0:
|
||||
return 0.0
|
||||
return self.stats["cache_hits"] / total_requests
|
||||
|
||||
|
||||
# 全局兴趣度管理器实例
|
||||
interest_manager = InterestManager()
|
||||
@@ -5,103 +5,18 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Callable, Union, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.interest_system import interest_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.energy_system import energy_manager
|
||||
from .distribution_manager import distribution_manager
|
||||
|
||||
logger = get_logger("context_manager")
|
||||
|
||||
|
||||
class ContextEventType(Enum):
|
||||
"""上下文事件类型"""
|
||||
MESSAGE_ADDED = "message_added"
|
||||
MESSAGE_UPDATED = "message_updated"
|
||||
ENERGY_CHANGED = "energy_changed"
|
||||
STREAM_ACTIVATED = "stream_activated"
|
||||
STREAM_DEACTIVATED = "stream_deactivated"
|
||||
CONTEXT_CLEARED = "context_cleared"
|
||||
VALIDATION_FAILED = "validation_failed"
|
||||
CLEANUP_COMPLETED = "cleanup_completed"
|
||||
INTEGRITY_CHECK = "integrity_check"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ContextEventType.{self.name}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextEvent:
|
||||
"""上下文事件"""
|
||||
event_type: ContextEventType
|
||||
stream_id: str
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
event_id: str = field(default_factory=lambda: f"event_{time.time()}_{id(object())}")
|
||||
priority: int = 0 # 事件优先级,数字越大优先级越高
|
||||
source: str = "system" # 事件来源
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextEvent({self.event_type}, {self.stream_id}, ts={self.timestamp:.3f})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ContextEvent(event_type={self.event_type}, stream_id={self.stream_id}, timestamp={self.timestamp}, event_id={self.event_id})"
|
||||
|
||||
def get_age(self) -> float:
|
||||
"""获取事件年龄(秒)"""
|
||||
return time.time() - self.timestamp
|
||||
|
||||
def is_expired(self, max_age: float = 3600.0) -> bool:
|
||||
"""检查事件是否已过期
|
||||
|
||||
Args:
|
||||
max_age: 最大年龄(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否已过期
|
||||
"""
|
||||
return self.get_age() > max_age
|
||||
|
||||
|
||||
class ContextValidator(ABC):
|
||||
"""上下文验证器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
|
||||
"""验证上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
context: 上下文对象
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultContextValidator(ContextValidator):
|
||||
"""默认上下文验证器"""
|
||||
|
||||
def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
|
||||
"""验证上下文基本完整性"""
|
||||
if not hasattr(context, 'stream_id'):
|
||||
return False, "缺少 stream_id 属性"
|
||||
if not hasattr(context, 'unread_messages'):
|
||||
return False, "缺少 unread_messages 属性"
|
||||
if not hasattr(context, 'history_messages'):
|
||||
return False, "缺少 history_messages 属性"
|
||||
return True, None
|
||||
|
||||
|
||||
class StreamContextManager:
|
||||
"""流上下文管理器 - 统一管理所有聊天流上下文"""
|
||||
|
||||
@@ -110,14 +25,6 @@ class StreamContextManager:
|
||||
self.stream_contexts: Dict[str, Any] = {}
|
||||
self.context_metadata: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 事件监听器
|
||||
self.event_listeners: Dict[ContextEventType, List[Callable]] = {}
|
||||
self.event_history: List[ContextEvent] = []
|
||||
self.max_event_history = 1000
|
||||
|
||||
# 验证器
|
||||
self.validators: List[ContextValidator] = [DefaultContextValidator()]
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str, Dict]] = {
|
||||
"total_messages": 0,
|
||||
@@ -126,16 +33,6 @@ class StreamContextManager:
|
||||
"inactive_streams": 0,
|
||||
"last_activity": time.time(),
|
||||
"creation_time": time.time(),
|
||||
"validation_stats": {
|
||||
"total_validations": 0,
|
||||
"validation_failures": 0,
|
||||
"last_validation_time": 0.0,
|
||||
},
|
||||
"event_stats": {
|
||||
"total_events": 0,
|
||||
"events_by_type": {},
|
||||
"last_event_time": 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
# 配置参数
|
||||
@@ -166,17 +63,6 @@ class StreamContextManager:
|
||||
logger.warning(f"流上下文已存在: {stream_id}")
|
||||
return False
|
||||
|
||||
# 验证上下文
|
||||
if self.enable_validation:
|
||||
is_valid, error_msg = self._validate_context(stream_id, context)
|
||||
if not is_valid:
|
||||
logger.error(f"上下文验证失败: {stream_id} - {error_msg}")
|
||||
self._emit_event(ContextEventType.VALIDATION_FAILED, stream_id, {
|
||||
"error": error_msg,
|
||||
"context_type": type(context).__name__
|
||||
})
|
||||
return False
|
||||
|
||||
# 添加上下文
|
||||
self.stream_contexts[stream_id] = context
|
||||
|
||||
@@ -185,7 +71,6 @@ class StreamContextManager:
|
||||
"created_time": time.time(),
|
||||
"last_access_time": time.time(),
|
||||
"access_count": 0,
|
||||
"validation_errors": 0,
|
||||
"last_validation_time": 0.0,
|
||||
"custom_metadata": metadata or {},
|
||||
}
|
||||
@@ -195,13 +80,6 @@ class StreamContextManager:
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
# 触发事件
|
||||
self._emit_event(ContextEventType.STREAM_ACTIVATED, stream_id, {
|
||||
"context": context,
|
||||
"context_type": type(context).__name__,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||
return True
|
||||
|
||||
@@ -226,19 +104,11 @@ class StreamContextManager:
|
||||
self.stats["inactive_streams"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
# 触发事件
|
||||
self._emit_event(ContextEventType.STREAM_DEACTIVATED, stream_id, {
|
||||
"context": context,
|
||||
"context_type": type(context).__name__,
|
||||
"metadata": metadata,
|
||||
"uptime": time.time() - metadata.get("created_time", time.time())
|
||||
})
|
||||
|
||||
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[Any]:
|
||||
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
|
||||
"""获取流上下文
|
||||
|
||||
Args:
|
||||
@@ -284,7 +154,7 @@ class StreamContextManager:
|
||||
self.context_metadata[stream_id].update(updates)
|
||||
return True
|
||||
|
||||
def add_message_to_context(self, stream_id: str, message: Any, skip_energy_update: bool = False) -> bool:
|
||||
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||
"""添加消息到上下文
|
||||
|
||||
Args:
|
||||
@@ -302,30 +172,16 @@ class StreamContextManager:
|
||||
|
||||
try:
|
||||
# 添加消息到上下文
|
||||
if hasattr(context, 'add_message'):
|
||||
context.add_message(message)
|
||||
else:
|
||||
logger.error(f"上下文对象缺少 add_message 方法: {stream_id}")
|
||||
return False
|
||||
context.add_message(message)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_value = self._calculate_message_interest(message)
|
||||
if hasattr(message, 'interest_value'):
|
||||
message.interest_value = interest_value
|
||||
message.interest_value = interest_value
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_messages"] += 1
|
||||
self.stats["last_activity"] = time.time()
|
||||
|
||||
# 触发事件
|
||||
event_data = {
|
||||
"message": message,
|
||||
"interest_value": interest_value,
|
||||
"message_type": type(message).__name__,
|
||||
"message_id": getattr(message, "message_id", None),
|
||||
}
|
||||
self._emit_event(ContextEventType.MESSAGE_ADDED, stream_id, event_data)
|
||||
|
||||
# 更新能量和分发
|
||||
if not skip_energy_update:
|
||||
self._update_stream_energy(stream_id)
|
||||
@@ -356,18 +212,7 @@ class StreamContextManager:
|
||||
|
||||
try:
|
||||
# 更新消息信息
|
||||
if hasattr(context, 'update_message_info'):
|
||||
context.update_message_info(message_id, **updates)
|
||||
else:
|
||||
logger.error(f"上下文对象缺少 update_message_info 方法: {stream_id}")
|
||||
return False
|
||||
|
||||
# 触发事件
|
||||
self._emit_event(ContextEventType.MESSAGE_UPDATED, stream_id, {
|
||||
"message_id": message_id,
|
||||
"updates": updates,
|
||||
"update_time": time.time(),
|
||||
})
|
||||
context.update_message_info(message_id, **updates)
|
||||
|
||||
# 如果更新了兴趣度,重新计算能量
|
||||
if "interest_value" in updates:
|
||||
@@ -380,7 +225,7 @@ class StreamContextManager:
|
||||
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[Any]:
|
||||
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
"""获取上下文消息
|
||||
|
||||
Args:
|
||||
@@ -397,14 +242,13 @@ class StreamContextManager:
|
||||
|
||||
try:
|
||||
messages = []
|
||||
if include_unread and hasattr(context, 'get_unread_messages'):
|
||||
if include_unread:
|
||||
messages.extend(context.get_unread_messages())
|
||||
|
||||
if hasattr(context, 'get_history_messages'):
|
||||
if limit:
|
||||
messages.extend(context.get_history_messages(limit=limit))
|
||||
else:
|
||||
messages.extend(context.get_history_messages())
|
||||
if limit:
|
||||
messages.extend(context.get_history_messages(limit=limit))
|
||||
else:
|
||||
messages.extend(context.get_history_messages())
|
||||
|
||||
# 按时间排序
|
||||
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
|
||||
@@ -419,7 +263,7 @@ class StreamContextManager:
|
||||
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self, stream_id: str) -> List[Any]:
|
||||
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
|
||||
"""获取未读消息
|
||||
|
||||
Args:
|
||||
@@ -433,11 +277,7 @@ class StreamContextManager:
|
||||
return []
|
||||
|
||||
try:
|
||||
if hasattr(context, 'get_unread_messages'):
|
||||
return context.get_unread_messages()
|
||||
else:
|
||||
logger.warning(f"上下文对象缺少 get_unread_messages 方法: {stream_id}")
|
||||
return []
|
||||
return context.get_unread_messages()
|
||||
except Exception as e:
|
||||
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
@@ -507,12 +347,6 @@ class StreamContextManager:
|
||||
else:
|
||||
setattr(context, attr, time.time())
|
||||
|
||||
# 触发事件
|
||||
self._emit_event(ContextEventType.CONTEXT_CLEARED, stream_id, {
|
||||
"clear_time": time.time(),
|
||||
"reset_attributes": reset_attrs,
|
||||
})
|
||||
|
||||
# 重新计算能量
|
||||
self._update_stream_energy(stream_id)
|
||||
|
||||
@@ -523,22 +357,33 @@ class StreamContextManager:
|
||||
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _calculate_message_interest(self, message: Any) -> float:
|
||||
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||
"""计算消息兴趣度"""
|
||||
try:
|
||||
# 将消息转换为字典格式
|
||||
message_dict = self._message_to_dict(message)
|
||||
# 使用插件内部的兴趣度评分系统
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
# 使用兴趣度管理器计算
|
||||
context = {
|
||||
"stream_id": getattr(message, 'chat_info_stream_id', ''),
|
||||
"user_id": getattr(message, 'user_id', ''),
|
||||
}
|
||||
# 使用插件内部的兴趣度评分系统计算(同步方式)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
interest_value = interest_manager.calculate_message_interest(message_dict, context)
|
||||
interest_score = loop.run_until_complete(
|
||||
chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
)
|
||||
interest_value = interest_score.total_score
|
||||
|
||||
# 更新话题兴趣度
|
||||
interest_manager.update_topic_interest(message_dict, interest_value)
|
||||
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
|
||||
interest_value = 0.5 # 默认中等兴趣度
|
||||
|
||||
return interest_value
|
||||
|
||||
@@ -546,31 +391,6 @@ class StreamContextManager:
|
||||
logger.error(f"计算消息兴趣度失败: {e}")
|
||||
return 0.5
|
||||
|
||||
def _message_to_dict(self, message: Any) -> Dict[str, Any]:
|
||||
"""将消息对象转换为字典"""
|
||||
try:
|
||||
# 获取user_id,优先从user_info.user_id获取,其次从user_id属性获取
|
||||
user_id = ""
|
||||
if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id'):
|
||||
user_id = getattr(message.user_info, 'user_id', "")
|
||||
else:
|
||||
user_id = getattr(message, 'user_id', "")
|
||||
|
||||
return {
|
||||
"message_id": getattr(message, "message_id", ""),
|
||||
"processed_plain_text": getattr(message, "processed_plain_text", ""),
|
||||
"is_emoji": getattr(message, "is_emoji", False),
|
||||
"is_picid": getattr(message, "is_picid", False),
|
||||
"is_mentioned": getattr(message, "is_mentioned", False),
|
||||
"is_command": getattr(message, "is_command", False),
|
||||
"key_words": getattr(message, "key_words", "[]"),
|
||||
"user_id": user_id,
|
||||
"time": getattr(message, "time", time.time()),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息为字典失败: {e}")
|
||||
return {}
|
||||
|
||||
def _update_stream_energy(self, stream_id: str):
|
||||
"""更新流能量"""
|
||||
try:
|
||||
@@ -583,7 +403,7 @@ class StreamContextManager:
|
||||
user_id = None
|
||||
if combined_messages:
|
||||
last_message = combined_messages[-1]
|
||||
user_id = getattr(last_message, "user_id", None)
|
||||
user_id = last_message.user_info.user_id
|
||||
|
||||
# 计算能量
|
||||
energy = energy_manager.calculate_focus_energy(
|
||||
@@ -595,91 +415,9 @@ class StreamContextManager:
|
||||
# 更新分发管理器
|
||||
distribution_manager.update_stream_energy(stream_id, energy)
|
||||
|
||||
# 触发事件
|
||||
self._emit_event(ContextEventType.ENERGY_CHANGED, stream_id, {
|
||||
"energy": energy,
|
||||
"message_count": len(combined_messages),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新流能量失败 {stream_id}: {e}")
|
||||
|
||||
def add_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool:
|
||||
"""添加事件监听器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
listener: 监听器函数
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
if not callable(listener):
|
||||
logger.error(f"监听器必须是可调用对象: {type(listener)}")
|
||||
return False
|
||||
|
||||
if event_type not in self.event_listeners:
|
||||
self.event_listeners[event_type] = []
|
||||
|
||||
if listener not in self.event_listeners[event_type]:
|
||||
self.event_listeners[event_type].append(listener)
|
||||
logger.debug(f"添加事件监听器: {event_type} -> {getattr(listener, '__name__', 'anonymous')}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool:
|
||||
"""移除事件监听器
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
listener: 监听器函数
|
||||
|
||||
Returns:
|
||||
bool: 是否成功移除
|
||||
"""
|
||||
if event_type in self.event_listeners:
|
||||
try:
|
||||
self.event_listeners[event_type].remove(listener)
|
||||
logger.debug(f"移除事件监听器: {event_type}")
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _emit_event(self, event_type: ContextEventType, stream_id: str, data: Optional[Dict] = None, priority: int = 0) -> None:
|
||||
"""触发事件
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
stream_id: 流ID
|
||||
data: 事件数据
|
||||
priority: 事件优先级
|
||||
"""
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
event = ContextEvent(event_type, stream_id, data, priority=priority)
|
||||
|
||||
# 添加到事件历史
|
||||
self.event_history.append(event)
|
||||
if len(self.event_history) > self.max_event_history:
|
||||
self.event_history = self.event_history[-self.max_event_history:]
|
||||
|
||||
# 更新事件统计
|
||||
event_stats = self.stats["event_stats"]
|
||||
event_stats["total_events"] += 1
|
||||
event_stats["last_event_time"] = time.time()
|
||||
event_type_str = str(event_type)
|
||||
event_stats["events_by_type"][event_type_str] = event_stats["events_by_type"].get(event_type_str, 0) + 1
|
||||
|
||||
# 通知监听器
|
||||
if event_type in self.event_listeners:
|
||||
for listener in self.event_listeners[event_type]:
|
||||
try:
|
||||
listener(event)
|
||||
except Exception as e:
|
||||
logger.error(f"事件监听器执行失败: {e}", exc_info=True)
|
||||
|
||||
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取流统计信息
|
||||
|
||||
@@ -718,7 +456,6 @@ class StreamContextManager:
|
||||
"access_count": access_count,
|
||||
"uptime_seconds": current_time - created_time,
|
||||
"idle_seconds": current_time - last_access_time,
|
||||
"validation_errors": metadata.get("validation_errors", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
|
||||
@@ -733,31 +470,11 @@ class StreamContextManager:
|
||||
current_time = time.time()
|
||||
uptime = current_time - self.stats.get("creation_time", current_time)
|
||||
|
||||
# 计算验证统计
|
||||
validation_stats = self.stats["validation_stats"]
|
||||
validation_success_rate = (
|
||||
(validation_stats.get("total_validations", 0) - validation_stats.get("validation_failures", 0)) /
|
||||
max(1, validation_stats.get("total_validations", 1))
|
||||
)
|
||||
|
||||
# 计算事件统计
|
||||
event_stats = self.stats["event_stats"]
|
||||
events_by_type = event_stats.get("events_by_type", {})
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
"uptime_hours": uptime / 3600,
|
||||
"stream_count": len(self.stream_contexts),
|
||||
"metadata_count": len(self.context_metadata),
|
||||
"event_history_size": len(self.event_history),
|
||||
"validators_count": len(self.validators),
|
||||
"event_listeners": {
|
||||
str(event_type): len(listeners)
|
||||
for event_type, listeners in self.event_listeners.items()
|
||||
},
|
||||
"validation_success_rate": validation_success_rate,
|
||||
"event_distribution": events_by_type,
|
||||
"max_event_history": self.max_event_history,
|
||||
"auto_cleanup_enabled": self.auto_cleanup,
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
}
|
||||
@@ -840,31 +557,6 @@ class StreamContextManager:
|
||||
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
|
||||
return False
|
||||
|
||||
def _validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
|
||||
"""验证上下文完整性
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
context: 上下文对象
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
||||
"""
|
||||
validation_stats = self.stats["validation_stats"]
|
||||
validation_stats["total_validations"] += 1
|
||||
validation_stats["last_validation_time"] = time.time()
|
||||
|
||||
for validator in self.validators:
|
||||
try:
|
||||
is_valid, error_msg = validator.validate_context(stream_id, context)
|
||||
if not is_valid:
|
||||
validation_stats["validation_failures"] += 1
|
||||
return False, error_msg
|
||||
except Exception as e:
|
||||
validation_stats["validation_failures"] += 1
|
||||
return False, f"验证器执行失败: {e}"
|
||||
return True, None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动上下文管理器"""
|
||||
if self.is_running:
|
||||
@@ -924,7 +616,6 @@ class StreamContextManager:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
self.cleanup_inactive_contexts()
|
||||
self._cleanup_event_history()
|
||||
self._cleanup_expired_contexts()
|
||||
logger.debug("自动清理完成")
|
||||
except asyncio.CancelledError:
|
||||
@@ -933,20 +624,6 @@ class StreamContextManager:
|
||||
logger.error(f"清理循环出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def _cleanup_event_history(self) -> None:
|
||||
"""清理事件历史"""
|
||||
max_age = 24 * 3600 # 24小时
|
||||
|
||||
# 清理过期事件
|
||||
self.event_history = [
|
||||
event for event in self.event_history
|
||||
if not event.is_expired(max_age)
|
||||
]
|
||||
|
||||
# 保持历史大小限制
|
||||
if len(self.event_history) > self.max_event_history:
|
||||
self.event_history = self.event_history[-self.max_event_history:]
|
||||
|
||||
def _cleanup_expired_contexts(self) -> None:
|
||||
"""清理过期上下文"""
|
||||
current_time = time.time()
|
||||
@@ -963,21 +640,6 @@ class StreamContextManager:
|
||||
if expired_contexts:
|
||||
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
|
||||
|
||||
def get_event_history(self, limit: int = 100, event_type: Optional[ContextEventType] = None) -> List[ContextEvent]:
|
||||
"""获取事件历史
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
event_type: 过滤事件类型
|
||||
|
||||
Returns:
|
||||
List[ContextEvent]: 事件列表
|
||||
"""
|
||||
events = self.event_history
|
||||
if event_type:
|
||||
events = [event for event in events if event.event_type == event_type]
|
||||
return events[-limit:]
|
||||
|
||||
def get_active_streams(self) -> List[str]:
|
||||
"""获取活跃流列表
|
||||
|
||||
@@ -986,111 +648,6 @@ class StreamContextManager:
|
||||
"""
|
||||
return list(self.stream_contexts.keys())
|
||||
|
||||
def get_context_summary(self) -> Dict[str, Any]:
|
||||
"""获取上下文摘要
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 上下文摘要信息
|
||||
"""
|
||||
current_time = time.time()
|
||||
uptime = current_time - self.stats.get("creation_time", current_time)
|
||||
|
||||
# 计算平均访问次数
|
||||
total_access = sum(meta.get("access_count", 0) for meta in self.context_metadata.values())
|
||||
avg_access = total_access / max(1, len(self.context_metadata))
|
||||
|
||||
# 计算验证成功率
|
||||
validation_stats = self.stats["validation_stats"]
|
||||
total_validations = validation_stats.get("total_validations", 0)
|
||||
validation_success_rate = (
|
||||
(total_validations - validation_stats.get("validation_failures", 0)) /
|
||||
max(1, total_validations)
|
||||
) if total_validations > 0 else 1.0
|
||||
|
||||
return {
|
||||
"total_streams": len(self.stream_contexts),
|
||||
"active_streams": len(self.stream_contexts),
|
||||
"total_messages": self.stats.get("total_messages", 0),
|
||||
"uptime_hours": uptime / 3600,
|
||||
"average_access_count": avg_access,
|
||||
"validation_success_rate": validation_success_rate,
|
||||
"event_history_size": len(self.event_history),
|
||||
"validators_count": len(self.validators),
|
||||
"auto_cleanup_enabled": self.auto_cleanup,
|
||||
"cleanup_interval": self.cleanup_interval,
|
||||
"last_activity": self.stats.get("last_activity", 0),
|
||||
}
|
||||
|
||||
def force_validation(self, stream_id: str) -> Tuple[bool, Optional[str]]:
|
||||
"""强制验证上下文
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
||||
"""
|
||||
context = self.get_stream_context(stream_id)
|
||||
if not context:
|
||||
return False, "上下文不存在"
|
||||
|
||||
return self._validate_context(stream_id, context)
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""重置统计信息"""
|
||||
# 重置基本统计
|
||||
self.stats.update({
|
||||
"total_messages": 0,
|
||||
"total_streams": len(self.stream_contexts),
|
||||
"active_streams": len(self.stream_contexts),
|
||||
"inactive_streams": 0,
|
||||
"last_activity": time.time(),
|
||||
"creation_time": time.time(),
|
||||
})
|
||||
|
||||
# 重置验证统计
|
||||
self.stats["validation_stats"].update({
|
||||
"total_validations": 0,
|
||||
"validation_failures": 0,
|
||||
"last_validation_time": 0.0,
|
||||
})
|
||||
|
||||
# 重置事件统计
|
||||
self.stats["event_stats"].update({
|
||||
"total_events": 0,
|
||||
"events_by_type": {},
|
||||
"last_event_time": 0.0,
|
||||
})
|
||||
|
||||
logger.info("上下文管理器统计信息已重置")
|
||||
|
||||
def export_context_data(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""导出上下文数据
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 导出的数据
|
||||
"""
|
||||
context = self.get_stream_context(stream_id, update_access=False)
|
||||
if not context:
|
||||
return None
|
||||
|
||||
try:
|
||||
return {
|
||||
"stream_id": stream_id,
|
||||
"context_type": type(context).__name__,
|
||||
"metadata": self.context_metadata.get(stream_id, {}),
|
||||
"statistics": self.get_stream_statistics(stream_id),
|
||||
"export_time": time.time(),
|
||||
"unread_message_count": len(getattr(context, "unread_messages", [])),
|
||||
"history_message_count": len(getattr(context, "history_messages", [])),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"导出上下文数据失败 {stream_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 全局上下文管理器实例
|
||||
context_manager = StreamContextManager()
|
||||
@@ -18,7 +18,7 @@ from src.plugin_system.base.component_types import ChatMode
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.config.config import global_config
|
||||
from . import context_manager
|
||||
from .context_manager import context_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
@@ -46,7 +46,7 @@ class MessageManager:
|
||||
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
||||
|
||||
# 初始化上下文管理器
|
||||
self.context_manager = context_manager.context_manager
|
||||
self.context_manager = context_manager
|
||||
|
||||
async def start(self):
|
||||
"""启动消息管理器"""
|
||||
@@ -84,11 +84,9 @@ class MessageManager:
|
||||
if not context:
|
||||
# 创建新的流上下文
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
new_context = StreamContext(stream_id=stream_id)
|
||||
success = self.context_manager.add_stream_context(stream_id, new_context)
|
||||
if not success:
|
||||
logger.error(f"无法为流 {stream_id} 创建上下文")
|
||||
return
|
||||
context = StreamContext(stream_id=stream_id)
|
||||
# 将创建的上下文添加到管理器
|
||||
self.context_manager.add_stream_context(stream_id, context)
|
||||
|
||||
# 使用 context_manager 添加消息
|
||||
success = self.context_manager.add_message_to_context(stream_id, message)
|
||||
@@ -98,7 +96,7 @@ class MessageManager:
|
||||
else:
|
||||
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||
|
||||
def update_message_and_refresh_energy(
|
||||
def update_message(
|
||||
self,
|
||||
stream_id: str,
|
||||
message_id: str,
|
||||
@@ -112,7 +110,7 @@ class MessageManager:
|
||||
if context:
|
||||
context.update_message_info(message_id, interest_value, actions, should_reply)
|
||||
|
||||
def add_action_and_refresh_energy(self, stream_id: str, message_id: str, action: str):
|
||||
def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
# 使用 context_manager 添加动作到消息
|
||||
context = self.context_manager.get_stream_context(stream_id)
|
||||
|
||||
@@ -310,21 +310,19 @@ class ChatStream:
|
||||
self._focus_energy = max(0.0, min(1.0, value))
|
||||
|
||||
def _get_user_relationship_score(self) -> float:
|
||||
"""从新的兴趣度管理系统获取用户关系分"""
|
||||
"""获取用户关系分"""
|
||||
# 使用插件内部的兴趣度评分系统
|
||||
try:
|
||||
# 使用新的兴趣度管理系统
|
||||
from src.chat.interest_system import interest_manager
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||
user_id = str(self.user_info.user_id)
|
||||
# 获取用户交互历史作为关系分的基础
|
||||
interaction_calc = interest_manager.calculators.get(
|
||||
interest_manager.InterestSourceType.USER_INTERACTION
|
||||
)
|
||||
if interaction_calc:
|
||||
return interaction_calc.calculate({"user_id": user_id})
|
||||
except Exception:
|
||||
pass
|
||||
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
|
||||
return max(0.0, min(1.0, relationship_score))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"ChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")
|
||||
|
||||
# 默认基础分
|
||||
return 0.3
|
||||
|
||||
@@ -298,7 +298,7 @@ class ChatterActionManager:
|
||||
|
||||
# 通过message_manager更新消息的动作记录并刷新focus_energy
|
||||
if chat_stream.stream_id in message_manager.stream_contexts:
|
||||
message_manager.add_action_and_refresh_energy(
|
||||
message_manager.add_action(
|
||||
stream_id=chat_stream.stream_id,
|
||||
message_id=target_message_id,
|
||||
action=action_name
|
||||
|
||||
@@ -12,7 +12,7 @@ from src.common.data_models.info_data_model import InterestScore
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
logger = get_logger("chatter_interest_scoring")
|
||||
|
||||
# 定义颜色
|
||||
@@ -45,7 +45,7 @@ class ChatterInterestScoringSystem:
|
||||
self.probability_boost_per_no_reply = (
|
||||
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
|
||||
) # 每次不回复增加的概率
|
||||
|
||||
|
||||
# 用户关系数据
|
||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||
|
||||
|
||||
@@ -387,22 +387,27 @@ class ChatterPlanFilter:
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from src.chat.interest_system import interest_manager
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 使用新的兴趣度管理系统计算评分
|
||||
# 使用插件内部的兴趣度评分系统计算评分
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
# 构建计算上下文
|
||||
calc_context = {
|
||||
"stream_id": msg_dict.get("chat_id", ""),
|
||||
"user_id": msg_dict.get("user_id"),
|
||||
}
|
||||
# 将字典转换为DatabaseMessages对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
user_info=msg_dict.get("user_info", {}),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
key_words=msg_dict.get("key_words", "[]"),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False)
|
||||
)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_score = interest_manager.calculate_message_interest(
|
||||
message=msg_dict,
|
||||
context=calc_context
|
||||
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=db_message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
interest_score = interest_score_obj.total_score
|
||||
|
||||
# 构建兴趣度字典
|
||||
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
from src.chat.interest_system import interest_manager
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
|
||||
@@ -109,10 +109,7 @@ class ChatterActionPlanner:
|
||||
# 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取
|
||||
user_id = None
|
||||
first_message = unread_messages[0]
|
||||
if hasattr(first_message, 'user_info') and hasattr(first_message.user_info, 'user_id'):
|
||||
user_id = getattr(first_message.user_info, 'user_id', None)
|
||||
elif hasattr(first_message, 'user_id'):
|
||||
user_id = getattr(first_message, 'user_id', None)
|
||||
user_id = first_message.user_info.user_id
|
||||
|
||||
# 构建计算上下文
|
||||
calc_context = {
|
||||
@@ -123,11 +120,12 @@ class ChatterActionPlanner:
|
||||
# 为每条消息计算兴趣度
|
||||
for message in unread_messages:
|
||||
try:
|
||||
# 使用新的兴趣度管理器计算
|
||||
message_interest = interest_manager.calculate_message_interest(
|
||||
message=message.__dict__,
|
||||
context=calc_context
|
||||
# 使用插件内部的兴趣度评分系统计算
|
||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
message_interest = interest_score.total_score
|
||||
|
||||
# 更新消息的兴趣度
|
||||
message.interest_value = message_interest
|
||||
@@ -140,7 +138,7 @@ class ChatterActionPlanner:
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
message_manager.update_message_and_refresh_energy(
|
||||
message_manager.update_message(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_value=message_interest,
|
||||
@@ -154,10 +152,7 @@ class ChatterActionPlanner:
|
||||
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||
|
||||
# 更新话题兴趣度
|
||||
interest_manager.update_topic_interest(message.__dict__, message_interest)
|
||||
|
||||
|
||||
# 记录最高分
|
||||
if message_interest > score:
|
||||
score = message_interest
|
||||
|
||||
@@ -29,9 +29,6 @@ class ChatterRelationshipTracker:
|
||||
self.relationship_history: List[Dict] = []
|
||||
self.interest_scoring_system = interest_scoring_system
|
||||
|
||||
# 数据库访问 - 使用SQLAlchemy
|
||||
pass
|
||||
|
||||
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||
|
||||
Reference in New Issue
Block a user