refactor(interest-system): 移除旧兴趣度管理系统,迁移到插件内部实现

移除旧的集中式兴趣度管理系统(interest_manager.py),将兴趣度计算功能迁移到affinity_flow_chatter插件内部实现。主要包括:

- 删除interest_manager.py及其相关导入引用
- 修改RelationshipEnergyCalculator使用插件内部的关系分计算
- 重构StreamContextManager使用插件内部的兴趣度评分系统
- 更新ChatStream、PlanFilter、Planner等组件使用新的插件接口
- 简化上下文管理器,移除事件系统和验证器相关代码

此次重构提高了模块独立性,减少了核心代码对插件功能的直接依赖,符合"高内聚低耦合"的设计原则。
This commit is contained in:
Windpicker-owo
2025-09-27 19:07:24 +08:00
parent 0fe052dd37
commit 80d34f3130
11 changed files with 92 additions and 997 deletions

View File

@@ -5,103 +5,18 @@
import asyncio
import time
from typing import Dict, List, Optional, Any, Callable, Union, Tuple
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any, Union, Tuple
from abc import ABC, abstractmethod
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.interest_system import interest_manager
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.energy_system import energy_manager
from .distribution_manager import distribution_manager
logger = get_logger("context_manager")
class ContextEventType(Enum):
"""上下文事件类型"""
MESSAGE_ADDED = "message_added"
MESSAGE_UPDATED = "message_updated"
ENERGY_CHANGED = "energy_changed"
STREAM_ACTIVATED = "stream_activated"
STREAM_DEACTIVATED = "stream_deactivated"
CONTEXT_CLEARED = "context_cleared"
VALIDATION_FAILED = "validation_failed"
CLEANUP_COMPLETED = "cleanup_completed"
INTEGRITY_CHECK = "integrity_check"
def __str__(self) -> str:
return self.value
def __repr__(self) -> str:
return f"ContextEventType.{self.name}"
@dataclass
class ContextEvent:
"""上下文事件"""
event_type: ContextEventType
stream_id: str
data: Dict[str, Any] = field(default_factory=dict)
timestamp: float = field(default_factory=time.time)
event_id: str = field(default_factory=lambda: f"event_{time.time()}_{id(object())}")
priority: int = 0 # 事件优先级,数字越大优先级越高
source: str = "system" # 事件来源
def __str__(self) -> str:
return f"ContextEvent({self.event_type}, {self.stream_id}, ts={self.timestamp:.3f})"
def __repr__(self) -> str:
return f"ContextEvent(event_type={self.event_type}, stream_id={self.stream_id}, timestamp={self.timestamp}, event_id={self.event_id})"
def get_age(self) -> float:
"""获取事件年龄(秒)"""
return time.time() - self.timestamp
def is_expired(self, max_age: float = 3600.0) -> bool:
"""检查事件是否已过期
Args:
max_age: 最大年龄(秒)
Returns:
bool: 是否已过期
"""
return self.get_age() > max_age
class ContextValidator(ABC):
"""上下文验证器抽象基类"""
@abstractmethod
def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
"""验证上下文
Args:
stream_id: 流ID
context: 上下文对象
Returns:
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
"""
pass
class DefaultContextValidator(ContextValidator):
"""默认上下文验证器"""
def validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
"""验证上下文基本完整性"""
if not hasattr(context, 'stream_id'):
return False, "缺少 stream_id 属性"
if not hasattr(context, 'unread_messages'):
return False, "缺少 unread_messages 属性"
if not hasattr(context, 'history_messages'):
return False, "缺少 history_messages 属性"
return True, None
class StreamContextManager:
"""流上下文管理器 - 统一管理所有聊天流上下文"""
@@ -110,14 +25,6 @@ class StreamContextManager:
self.stream_contexts: Dict[str, Any] = {}
self.context_metadata: Dict[str, Dict[str, Any]] = {}
# 事件监听器
self.event_listeners: Dict[ContextEventType, List[Callable]] = {}
self.event_history: List[ContextEvent] = []
self.max_event_history = 1000
# 验证器
self.validators: List[ContextValidator] = [DefaultContextValidator()]
# 统计信息
self.stats: Dict[str, Union[int, float, str, Dict]] = {
"total_messages": 0,
@@ -126,16 +33,6 @@ class StreamContextManager:
"inactive_streams": 0,
"last_activity": time.time(),
"creation_time": time.time(),
"validation_stats": {
"total_validations": 0,
"validation_failures": 0,
"last_validation_time": 0.0,
},
"event_stats": {
"total_events": 0,
"events_by_type": {},
"last_event_time": 0.0,
},
}
# 配置参数
@@ -166,17 +63,6 @@ class StreamContextManager:
logger.warning(f"流上下文已存在: {stream_id}")
return False
# 验证上下文
if self.enable_validation:
is_valid, error_msg = self._validate_context(stream_id, context)
if not is_valid:
logger.error(f"上下文验证失败: {stream_id} - {error_msg}")
self._emit_event(ContextEventType.VALIDATION_FAILED, stream_id, {
"error": error_msg,
"context_type": type(context).__name__
})
return False
# 添加上下文
self.stream_contexts[stream_id] = context
@@ -185,7 +71,6 @@ class StreamContextManager:
"created_time": time.time(),
"last_access_time": time.time(),
"access_count": 0,
"validation_errors": 0,
"last_validation_time": 0.0,
"custom_metadata": metadata or {},
}
@@ -195,13 +80,6 @@ class StreamContextManager:
self.stats["active_streams"] += 1
self.stats["last_activity"] = time.time()
# 触发事件
self._emit_event(ContextEventType.STREAM_ACTIVATED, stream_id, {
"context": context,
"context_type": type(context).__name__,
"metadata": metadata
})
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
return True
@@ -226,19 +104,11 @@ class StreamContextManager:
self.stats["inactive_streams"] += 1
self.stats["last_activity"] = time.time()
# 触发事件
self._emit_event(ContextEventType.STREAM_DEACTIVATED, stream_id, {
"context": context,
"context_type": type(context).__name__,
"metadata": metadata,
"uptime": time.time() - metadata.get("created_time", time.time())
})
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
return True
return False
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[Any]:
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
"""获取流上下文
Args:
@@ -284,7 +154,7 @@ class StreamContextManager:
self.context_metadata[stream_id].update(updates)
return True
def add_message_to_context(self, stream_id: str, message: Any, skip_energy_update: bool = False) -> bool:
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""添加消息到上下文
Args:
@@ -302,30 +172,16 @@ class StreamContextManager:
try:
# 添加消息到上下文
if hasattr(context, 'add_message'):
context.add_message(message)
else:
logger.error(f"上下文对象缺少 add_message 方法: {stream_id}")
return False
context.add_message(message)
# 计算消息兴趣度
interest_value = self._calculate_message_interest(message)
if hasattr(message, 'interest_value'):
message.interest_value = interest_value
message.interest_value = interest_value
# 更新统计
self.stats["total_messages"] += 1
self.stats["last_activity"] = time.time()
# 触发事件
event_data = {
"message": message,
"interest_value": interest_value,
"message_type": type(message).__name__,
"message_id": getattr(message, "message_id", None),
}
self._emit_event(ContextEventType.MESSAGE_ADDED, stream_id, event_data)
# 更新能量和分发
if not skip_energy_update:
self._update_stream_energy(stream_id)
@@ -356,18 +212,7 @@ class StreamContextManager:
try:
# 更新消息信息
if hasattr(context, 'update_message_info'):
context.update_message_info(message_id, **updates)
else:
logger.error(f"上下文对象缺少 update_message_info 方法: {stream_id}")
return False
# 触发事件
self._emit_event(ContextEventType.MESSAGE_UPDATED, stream_id, {
"message_id": message_id,
"updates": updates,
"update_time": time.time(),
})
context.update_message_info(message_id, **updates)
# 如果更新了兴趣度,重新计算能量
if "interest_value" in updates:
@@ -380,7 +225,7 @@ class StreamContextManager:
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
return False
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[Any]:
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
"""获取上下文消息
Args:
@@ -397,14 +242,13 @@ class StreamContextManager:
try:
messages = []
if include_unread and hasattr(context, 'get_unread_messages'):
if include_unread:
messages.extend(context.get_unread_messages())
if hasattr(context, 'get_history_messages'):
if limit:
messages.extend(context.get_history_messages(limit=limit))
else:
messages.extend(context.get_history_messages())
if limit:
messages.extend(context.get_history_messages(limit=limit))
else:
messages.extend(context.get_history_messages())
# 按时间排序
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
@@ -419,7 +263,7 @@ class StreamContextManager:
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self, stream_id: str) -> List[Any]:
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
"""获取未读消息
Args:
@@ -433,11 +277,7 @@ class StreamContextManager:
return []
try:
if hasattr(context, 'get_unread_messages'):
return context.get_unread_messages()
else:
logger.warning(f"上下文对象缺少 get_unread_messages 方法: {stream_id}")
return []
return context.get_unread_messages()
except Exception as e:
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
return []
@@ -507,12 +347,6 @@ class StreamContextManager:
else:
setattr(context, attr, time.time())
# 触发事件
self._emit_event(ContextEventType.CONTEXT_CLEARED, stream_id, {
"clear_time": time.time(),
"reset_attributes": reset_attrs,
})
# 重新计算能量
self._update_stream_energy(stream_id)
@@ -523,22 +357,33 @@ class StreamContextManager:
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
return False
def _calculate_message_interest(self, message: Any) -> float:
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""计算消息兴趣度"""
try:
# 将消息转换为字典格式
message_dict = self._message_to_dict(message)
# 使用插件内部的兴趣度评分系统
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
# 使用兴趣度管理器计算
context = {
"stream_id": getattr(message, 'chat_info_stream_id', ''),
"user_id": getattr(message, 'user_id', ''),
}
# 使用插件内部的兴趣度评分系统计算(同步方式)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
interest_value = interest_manager.calculate_message_interest(message_dict, context)
interest_score = loop.run_until_complete(
chatter_interest_scoring_system._calculate_single_message_score(
message=message,
bot_nickname=global_config.bot.nickname
)
)
interest_value = interest_score.total_score
# 更新话题兴趣度
interest_manager.update_topic_interest(message_dict, interest_value)
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
except Exception as e:
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
interest_value = 0.5 # 默认中等兴趣度
return interest_value
@@ -546,31 +391,6 @@ class StreamContextManager:
logger.error(f"计算消息兴趣度失败: {e}")
return 0.5
def _message_to_dict(self, message: Any) -> Dict[str, Any]:
"""将消息对象转换为字典"""
try:
# 获取user_id优先从user_info.user_id获取其次从user_id属性获取
user_id = ""
if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id'):
user_id = getattr(message.user_info, 'user_id', "")
else:
user_id = getattr(message, 'user_id', "")
return {
"message_id": getattr(message, "message_id", ""),
"processed_plain_text": getattr(message, "processed_plain_text", ""),
"is_emoji": getattr(message, "is_emoji", False),
"is_picid": getattr(message, "is_picid", False),
"is_mentioned": getattr(message, "is_mentioned", False),
"is_command": getattr(message, "is_command", False),
"key_words": getattr(message, "key_words", "[]"),
"user_id": user_id,
"time": getattr(message, "time", time.time()),
}
except Exception as e:
logger.error(f"转换消息为字典失败: {e}")
return {}
def _update_stream_energy(self, stream_id: str):
"""更新流能量"""
try:
@@ -583,7 +403,7 @@ class StreamContextManager:
user_id = None
if combined_messages:
last_message = combined_messages[-1]
user_id = getattr(last_message, "user_id", None)
user_id = last_message.user_info.user_id
# 计算能量
energy = energy_manager.calculate_focus_energy(
@@ -595,91 +415,9 @@ class StreamContextManager:
# 更新分发管理器
distribution_manager.update_stream_energy(stream_id, energy)
# 触发事件
self._emit_event(ContextEventType.ENERGY_CHANGED, stream_id, {
"energy": energy,
"message_count": len(combined_messages),
})
except Exception as e:
logger.error(f"更新流能量失败 {stream_id}: {e}")
def add_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool:
"""添加事件监听器
Args:
event_type: 事件类型
listener: 监听器函数
Returns:
bool: 是否成功添加
"""
if not callable(listener):
logger.error(f"监听器必须是可调用对象: {type(listener)}")
return False
if event_type not in self.event_listeners:
self.event_listeners[event_type] = []
if listener not in self.event_listeners[event_type]:
self.event_listeners[event_type].append(listener)
logger.debug(f"添加事件监听器: {event_type} -> {getattr(listener, '__name__', 'anonymous')}")
return True
return False
def remove_event_listener(self, event_type: ContextEventType, listener: Callable[[ContextEvent], None]) -> bool:
"""移除事件监听器
Args:
event_type: 事件类型
listener: 监听器函数
Returns:
bool: 是否成功移除
"""
if event_type in self.event_listeners:
try:
self.event_listeners[event_type].remove(listener)
logger.debug(f"移除事件监听器: {event_type}")
return True
except ValueError:
pass
return False
def _emit_event(self, event_type: ContextEventType, stream_id: str, data: Optional[Dict] = None, priority: int = 0) -> None:
"""触发事件
Args:
event_type: 事件类型
stream_id: 流ID
data: 事件数据
priority: 事件优先级
"""
if data is None:
data = {}
event = ContextEvent(event_type, stream_id, data, priority=priority)
# 添加到事件历史
self.event_history.append(event)
if len(self.event_history) > self.max_event_history:
self.event_history = self.event_history[-self.max_event_history:]
# 更新事件统计
event_stats = self.stats["event_stats"]
event_stats["total_events"] += 1
event_stats["last_event_time"] = time.time()
event_type_str = str(event_type)
event_stats["events_by_type"][event_type_str] = event_stats["events_by_type"].get(event_type_str, 0) + 1
# 通知监听器
if event_type in self.event_listeners:
for listener in self.event_listeners[event_type]:
try:
listener(event)
except Exception as e:
logger.error(f"事件监听器执行失败: {e}", exc_info=True)
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""获取流统计信息
@@ -718,7 +456,6 @@ class StreamContextManager:
"access_count": access_count,
"uptime_seconds": current_time - created_time,
"idle_seconds": current_time - last_access_time,
"validation_errors": metadata.get("validation_errors", 0),
}
except Exception as e:
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
@@ -733,31 +470,11 @@ class StreamContextManager:
current_time = time.time()
uptime = current_time - self.stats.get("creation_time", current_time)
# 计算验证统计
validation_stats = self.stats["validation_stats"]
validation_success_rate = (
(validation_stats.get("total_validations", 0) - validation_stats.get("validation_failures", 0)) /
max(1, validation_stats.get("total_validations", 1))
)
# 计算事件统计
event_stats = self.stats["event_stats"]
events_by_type = event_stats.get("events_by_type", {})
return {
**self.stats,
"uptime_hours": uptime / 3600,
"stream_count": len(self.stream_contexts),
"metadata_count": len(self.context_metadata),
"event_history_size": len(self.event_history),
"validators_count": len(self.validators),
"event_listeners": {
str(event_type): len(listeners)
for event_type, listeners in self.event_listeners.items()
},
"validation_success_rate": validation_success_rate,
"event_distribution": events_by_type,
"max_event_history": self.max_event_history,
"auto_cleanup_enabled": self.auto_cleanup,
"cleanup_interval": self.cleanup_interval,
}
@@ -840,31 +557,6 @@ class StreamContextManager:
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
return False
def _validate_context(self, stream_id: str, context: Any) -> Tuple[bool, Optional[str]]:
"""验证上下文完整性
Args:
stream_id: 流ID
context: 上下文对象
Returns:
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
"""
validation_stats = self.stats["validation_stats"]
validation_stats["total_validations"] += 1
validation_stats["last_validation_time"] = time.time()
for validator in self.validators:
try:
is_valid, error_msg = validator.validate_context(stream_id, context)
if not is_valid:
validation_stats["validation_failures"] += 1
return False, error_msg
except Exception as e:
validation_stats["validation_failures"] += 1
return False, f"验证器执行失败: {e}"
return True, None
async def start(self) -> None:
"""启动上下文管理器"""
if self.is_running:
@@ -924,7 +616,6 @@ class StreamContextManager:
try:
await asyncio.sleep(interval)
self.cleanup_inactive_contexts()
self._cleanup_event_history()
self._cleanup_expired_contexts()
logger.debug("自动清理完成")
except asyncio.CancelledError:
@@ -933,20 +624,6 @@ class StreamContextManager:
logger.error(f"清理循环出错: {e}", exc_info=True)
await asyncio.sleep(interval)
def _cleanup_event_history(self) -> None:
"""清理事件历史"""
max_age = 24 * 3600 # 24小时
# 清理过期事件
self.event_history = [
event for event in self.event_history
if not event.is_expired(max_age)
]
# 保持历史大小限制
if len(self.event_history) > self.max_event_history:
self.event_history = self.event_history[-self.max_event_history:]
def _cleanup_expired_contexts(self) -> None:
"""清理过期上下文"""
current_time = time.time()
@@ -963,21 +640,6 @@ class StreamContextManager:
if expired_contexts:
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
def get_event_history(self, limit: int = 100, event_type: Optional[ContextEventType] = None) -> List[ContextEvent]:
"""获取事件历史
Args:
limit: 返回数量限制
event_type: 过滤事件类型
Returns:
List[ContextEvent]: 事件列表
"""
events = self.event_history
if event_type:
events = [event for event in events if event.event_type == event_type]
return events[-limit:]
def get_active_streams(self) -> List[str]:
"""获取活跃流列表
@@ -986,111 +648,6 @@ class StreamContextManager:
"""
return list(self.stream_contexts.keys())
def get_context_summary(self) -> Dict[str, Any]:
"""获取上下文摘要
Returns:
Dict[str, Any]: 上下文摘要信息
"""
current_time = time.time()
uptime = current_time - self.stats.get("creation_time", current_time)
# 计算平均访问次数
total_access = sum(meta.get("access_count", 0) for meta in self.context_metadata.values())
avg_access = total_access / max(1, len(self.context_metadata))
# 计算验证成功率
validation_stats = self.stats["validation_stats"]
total_validations = validation_stats.get("total_validations", 0)
validation_success_rate = (
(total_validations - validation_stats.get("validation_failures", 0)) /
max(1, total_validations)
) if total_validations > 0 else 1.0
return {
"total_streams": len(self.stream_contexts),
"active_streams": len(self.stream_contexts),
"total_messages": self.stats.get("total_messages", 0),
"uptime_hours": uptime / 3600,
"average_access_count": avg_access,
"validation_success_rate": validation_success_rate,
"event_history_size": len(self.event_history),
"validators_count": len(self.validators),
"auto_cleanup_enabled": self.auto_cleanup,
"cleanup_interval": self.cleanup_interval,
"last_activity": self.stats.get("last_activity", 0),
}
def force_validation(self, stream_id: str) -> Tuple[bool, Optional[str]]:
"""强制验证上下文
Args:
stream_id: 流ID
Returns:
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
"""
context = self.get_stream_context(stream_id)
if not context:
return False, "上下文不存在"
return self._validate_context(stream_id, context)
def reset_statistics(self) -> None:
"""重置统计信息"""
# 重置基本统计
self.stats.update({
"total_messages": 0,
"total_streams": len(self.stream_contexts),
"active_streams": len(self.stream_contexts),
"inactive_streams": 0,
"last_activity": time.time(),
"creation_time": time.time(),
})
# 重置验证统计
self.stats["validation_stats"].update({
"total_validations": 0,
"validation_failures": 0,
"last_validation_time": 0.0,
})
# 重置事件统计
self.stats["event_stats"].update({
"total_events": 0,
"events_by_type": {},
"last_event_time": 0.0,
})
logger.info("上下文管理器统计信息已重置")
def export_context_data(self, stream_id: str) -> Optional[Dict[str, Any]]:
"""导出上下文数据
Args:
stream_id: 流ID
Returns:
Optional[Dict[str, Any]]: 导出的数据
"""
context = self.get_stream_context(stream_id, update_access=False)
if not context:
return None
try:
return {
"stream_id": stream_id,
"context_type": type(context).__name__,
"metadata": self.context_metadata.get(stream_id, {}),
"statistics": self.get_stream_statistics(stream_id),
"export_time": time.time(),
"unread_message_count": len(getattr(context, "unread_messages", [])),
"history_message_count": len(getattr(context, "history_messages", [])),
}
except Exception as e:
logger.error(f"导出上下文数据失败 {stream_id}: {e}")
return None
# 全局上下文管理器实例
context_manager = StreamContextManager()