refactor(chat): 异步化聊天系统并重构兴趣值计算机制

将同步调用改为异步调用以提升性能,重构兴趣值计算流程以支持更灵活的组件化架构。主要改进包括:

- 异步化ChatManager相关方法,避免阻塞主线程
- 重构兴趣值计算系统,从插件内部计算改为通过兴趣管理器统一处理
- 新增should_act字段支持更细粒度的动作决策
- 优化初始化逻辑,避免构造函数中的异步操作
- 扩展插件系统支持兴趣计算器组件注册
- 更新数据库模型以支持新的兴趣值相关字段

这些改进提升了系统的响应性能和可扩展性,同时保持了API的向后兼容性。
This commit is contained in:
Windpicker-owo
2025-10-05 01:25:52 +08:00
parent 49025a4973
commit 624298e1b8
38 changed files with 1493 additions and 259 deletions

View File

@@ -85,7 +85,7 @@ class ExpressionLearner:
model_set=model_config.model_task_config.replyer, request_type="expressor.learner"
)
self.chat_id = chat_id
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self.chat_name = chat_id # 初始化时使用chat_id稍后异步更新
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
@@ -93,6 +93,14 @@ class ExpressionLearner:
# 学习参数
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
self._chat_name_initialized = False
async def _initialize_chat_name(self):
"""异步初始化chat_name"""
if not self._chat_name_initialized:
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
self.chat_name = stream_name or self.chat_id
self._chat_name_initialized = True
def can_learn_for_chat(self) -> bool:
"""
@@ -166,6 +174,9 @@ class ExpressionLearner:
Returns:
bool: 是否成功触发学习
"""
# 初始化chat_name
await self._initialize_chat_name()
if not await self.should_trigger_learning():
return False
@@ -323,7 +334,7 @@ class ExpressionLearner:
return []
learnt_expressions, chat_id = res
chat_stream = get_chat_manager().get_stream(chat_id)
chat_stream = await get_chat_manager().get_stream(chat_id)
if chat_stream is None:
group_name = f"聊天流 {chat_id}"
elif chat_stream.group_info:

View File

@@ -1,16 +1,22 @@
"""
兴趣度系统模块
提供机器人兴趣标签和智能匹配功能
提供机器人兴趣标签和智能匹配功能,以及消息兴趣值计算功能
"""
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from .interest_manager import InterestManager, get_interest_manager
__all__ = [
# 机器人兴趣标签管理
"BotInterestManager",
"BotInterestTag",
"BotPersonalityInterests",
"InterestMatchResult",
"bot_interest_manager",
# 消息兴趣值计算管理
"InterestManager",
"get_interest_manager",
]

View File

@@ -0,0 +1,223 @@
"""兴趣值计算组件管理器
管理兴趣值计算组件的生命周期,确保系统只能有一个兴趣值计算组件实例运行
"""
import asyncio
import time
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("interest_manager")
class InterestManager:
"""兴趣值计算组件管理器"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self._current_calculator: BaseInterestCalculator | None = None
self._calculator_lock = asyncio.Lock()
self._last_calculation_time = 0.0
self._total_calculations = 0
self._failed_calculations = 0
self._calculation_queue = asyncio.Queue()
self._worker_task = None
self._shutdown_event = asyncio.Event()
self._initialized = True
async def initialize(self):
"""初始化管理器"""
if self._worker_task is None:
self._worker_task = asyncio.create_task(self._calculation_worker())
logger.info("兴趣值管理器已启动")
async def shutdown(self):
"""关闭管理器"""
self._shutdown_event.set()
if self._worker_task:
self._worker_task.cancel()
try:
await self._worker_task
except asyncio.CancelledError:
pass
if self._current_calculator:
await self._current_calculator.cleanup()
self._current_calculator = None
logger.info("兴趣值管理器已关闭")
async def register_calculator(self, calculator: BaseInterestCalculator) -> bool:
"""注册兴趣值计算组件(系统只能有一个活跃的兴趣值计算器)
Args:
calculator: 兴趣值计算组件实例
Returns:
bool: 注册是否成功
"""
async with self._calculator_lock:
try:
# 检查是否已有相同的计算器
if self._current_calculator and self._current_calculator.component_name == calculator.component_name:
logger.warning(f"兴趣值计算组件 {calculator.component_name} 已经注册,跳过重复注册")
return True
# 如果已有组件在运行,先清理并替换
if self._current_calculator:
logger.info(f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}")
await self._current_calculator.cleanup()
else:
logger.info(f"注册新的兴趣值计算组件: {calculator.component_name}")
# 初始化新组件
if await calculator.initialize():
self._current_calculator = calculator
logger.info(f"兴趣值计算组件注册成功: {calculator.component_name} v{calculator.component_version}")
logger.info("系统现在只有一个活跃的兴趣值计算器")
return True
else:
logger.error(f"兴趣值计算组件初始化失败: {calculator.component_name}")
return False
except Exception as e:
logger.error(f"注册兴趣值计算组件失败: {e}", exc_info=True)
return False
async def calculate_interest(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""计算消息兴趣值
Args:
message: 数据库消息对象
Returns:
InterestCalculationResult: 计算结果
"""
if not self._current_calculator:
# 返回默认结果
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.3,
error_message="没有可用的兴趣值计算组件"
)
# 异步执行计算,避免阻塞
future = asyncio.create_task(self._async_calculate(message))
return await future
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""异步执行兴趣值计算"""
start_time = time.time()
self._total_calculations += 1
try:
# 使用组件的安全执行方法
result = await self._current_calculator._safe_execute(message)
if result.success:
self._last_calculation_time = time.time()
logger.debug(f"兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)")
else:
self._failed_calculations += 1
logger.warning(f"兴趣值计算失败: {result.error_message}")
return result
except Exception as e:
self._failed_calculations += 1
logger.error(f"兴趣值计算异常: {e}", exc_info=True)
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message=f"计算异常: {str(e)}",
calculation_time=time.time() - start_time
)
async def _calculation_worker(self):
"""计算工作线程(预留用于批量处理)"""
while not self._shutdown_event.is_set():
try:
# 等待计算任务或关闭信号
await asyncio.wait_for(
self._calculation_queue.get(),
timeout=1.0
)
# 处理计算任务
# 这里可以实现批量处理逻辑
except asyncio.TimeoutError:
# 超时继续循环
continue
except asyncio.CancelledError:
# 任务被取消,退出循环
break
except Exception as e:
logger.error(f"计算工作线程异常: {e}", exc_info=True)
def get_current_calculator(self) -> BaseInterestCalculator | None:
"""获取当前活跃的兴趣值计算组件"""
return self._current_calculator
def get_statistics(self) -> dict:
"""获取管理器统计信息"""
success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations))
stats = {
"manager_statistics": {
"total_calculations": self._total_calculations,
"failed_calculations": self._failed_calculations,
"success_rate": success_rate,
"last_calculation_time": self._last_calculation_time,
"current_calculator": self._current_calculator.component_name if self._current_calculator else None
}
}
# 添加当前组件的统计信息
if self._current_calculator:
stats["calculator_statistics"] = self._current_calculator.get_statistics()
return stats
async def health_check(self) -> bool:
"""健康检查"""
if not self._current_calculator:
return False
try:
# 检查组件是否还活跃
return self._current_calculator.is_enabled
except Exception:
return False
def has_calculator(self) -> bool:
"""检查是否有可用的计算组件"""
return self._current_calculator is not None and self._current_calculator.is_enabled
# 全局实例
_interest_manager = None
def get_interest_manager() -> InterestManager:
"""获取兴趣值管理器实例"""
global _interest_manager
if _interest_manager is None:
_interest_manager = InterestManager()
return _interest_manager

View File

@@ -0,0 +1,166 @@
"""消息兴趣值计算组件管理器
管理消息兴趣值计算组件,确保系统只能有一个兴趣值计算组件实例运行
"""
import asyncio
import time
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("message_interest_manager")
class MessageInterestManager:
"""消息兴趣值计算组件管理器"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self._current_calculator: BaseInterestCalculator | None = None
self._calculator_lock = asyncio.Lock()
self._last_calculation_time = 0.0
self._total_calculations = 0
self._failed_calculations = 0
self._initialized = True
async def initialize(self):
"""初始化管理器"""
logger.info("消息兴趣值管理器已初始化")
async def register_calculator(self, calculator: BaseInterestCalculator) -> bool:
"""注册兴趣值计算组件
Args:
calculator: 兴趣值计算组件实例
Returns:
bool: 注册是否成功
"""
async with self._calculator_lock:
try:
# 如果已有组件在运行,先清理
if self._current_calculator:
logger.info(f"替换现有消息兴趣值计算组件: {self._current_calculator.component_name}")
await self._current_calculator.cleanup()
# 初始化新组件
if await calculator.initialize():
self._current_calculator = calculator
logger.info(f"消息兴趣值计算组件注册成功: {calculator.component_name} v{calculator.component_version}")
return True
else:
logger.error(f"消息兴趣值计算组件初始化失败: {calculator.component_name}")
return False
except Exception as e:
logger.error(f"注册消息兴趣值计算组件失败: {e}", exc_info=True)
return False
async def calculate_interest(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""计算消息兴趣值
Args:
message: 数据库消息对象
Returns:
InterestCalculationResult: 计算结果
"""
if not self._current_calculator:
# 返回默认结果
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.3,
error_message="没有可用的消息兴趣值计算组件"
)
start_time = time.time()
self._total_calculations += 1
try:
# 使用组件的安全执行方法
result = await self._current_calculator._safe_execute(message)
if result.success:
self._last_calculation_time = time.time()
logger.debug(f"消息兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)")
else:
self._failed_calculations += 1
logger.warning(f"消息兴趣值计算失败: {result.error_message}")
return result
except Exception as e:
self._failed_calculations += 1
logger.error(f"消息兴趣值计算异常: {e}", exc_info=True)
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message=f"计算异常: {str(e)}",
calculation_time=time.time() - start_time
)
def get_current_calculator(self) -> BaseInterestCalculator | None:
"""获取当前活跃的兴趣值计算组件"""
return self._current_calculator
def get_statistics(self) -> dict:
"""获取管理器统计信息"""
success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations))
stats = {
"manager_statistics": {
"total_calculations": self._total_calculations,
"failed_calculations": self._failed_calculations,
"success_rate": success_rate,
"last_calculation_time": self._last_calculation_time,
"current_calculator": self._current_calculator.component_name if self._current_calculator else None
}
}
# 添加当前组件的统计信息
if self._current_calculator:
stats["calculator_statistics"] = self._current_calculator.get_statistics()
return stats
async def health_check(self) -> bool:
"""健康检查"""
if not self._current_calculator:
return False
try:
# 检查组件是否还活跃
return self._current_calculator.is_enabled
except Exception:
return False
def has_calculator(self) -> bool:
"""检查是否有可用的计算组件"""
return self._current_calculator is not None and self._current_calculator.is_enabled
# 全局实例
_message_interest_manager = None
def get_message_interest_manager() -> MessageInterestManager:
"""获取消息兴趣值管理器实例"""
global _message_interest_manager
if _message_interest_manager is None:
_message_interest_manager = MessageInterestManager()
return _message_interest_manager

View File

@@ -995,7 +995,7 @@ class MemorySystem:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream or not hasattr(chat_stream, "context_manager"):
logger.debug(f"未找到stream_id={stream_id}的聊天流或上下文管理器")
@@ -1109,7 +1109,7 @@ class MemorySystem:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream, "context_manager"):
history_limit = self._determine_history_limit(context)
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)

View File

@@ -55,8 +55,8 @@ class SingleStreamContextManager:
"""
try:
self.context.add_message(message)
# 推迟兴趣度计算到分发阶段
message.interest_value = getattr(message, "interest_value", None)
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
@@ -228,51 +228,44 @@ class SingleStreamContextManager:
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""
异步计算消息的兴趣度
此方法通过检查当前是否存在正在运行的 asyncio 事件循环来兼容同步和异步调用。
在上下文管理器中计算消息的兴趣度
"""
# 内部异步函数,封装实际的计算逻辑
async def _get_score():
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname
)
interest_value = interest_score.total_score
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
return interest_value
except ImportError as e:
logger.debug(f"兴趣度计算插件加载失败,可能未启用: {e}")
return 0.5
except Exception as e:
# 在某些情况下(例如机器人自己的消息),没有兴趣度是正常的
logger.info(f"插件内部兴趣度计算失败,使用默认值: {e}")
return 0.5
# 检查并获取当前事件循环
try:
loop = asyncio.get_running_loop()
except RuntimeError: # 'RuntimeError: There is no current event loop...'
loop = None
from src.chat.interest_system.interest_manager import get_interest_manager
if loop and loop.is_running():
# 如果事件循环正在运行,直接 await
return await _get_score()
else:
# 否则,使用 asyncio.run() 来安全执行
return asyncio.run(_get_score())
interest_manager = get_interest_manager()
if interest_manager.has_calculator():
# 使用兴趣值计算组件计算
result = await interest_manager.calculate_interest(message)
if result.success:
# 更新消息对象的兴趣值相关字段
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
logger.debug(f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}")
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
return 0.5
else:
logger.debug("未找到兴趣值计算器,使用默认兴趣值")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
return 0.5
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""异步实现的 add_message将消息添加到 context并 await 能量更新与分发。"""
try:
self.context.add_message(message)
# 推迟兴趣度计算到分发阶段
message.interest_value = getattr(message, "interest_value", None)
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
@@ -280,7 +273,7 @@ class SingleStreamContextManager:
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id}")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)

View File

@@ -181,7 +181,7 @@ class StreamLoopManager:
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
"""回退方案:获取槽位(原始方法)"""
# 判断是否需要强制分发
should_force = force or self._should_force_dispatch_for_stream(stream_id)
should_force = force or await self._should_force_dispatch_for_stream(stream_id)
# 检查是否超过最大并发限制
current_streams = len(self.stream_loops)
@@ -410,7 +410,7 @@ class StreamLoopManager:
"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
return chat_stream.context_manager.context
return None
@@ -538,13 +538,13 @@ class StreamLoopManager:
self.chatter_manager = chatter_manager
logger.info(f"设置chatter管理器: {chatter_manager.__class__.__name__}")
def _should_force_dispatch_for_stream(self, stream_id: str) -> bool:
async def _should_force_dispatch_for_stream(self, stream_id: str) -> bool:
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
return False
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
return False
@@ -595,7 +595,7 @@ class StreamLoopManager:
"""分发完成后基于历史消息刷新能量值"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
@@ -622,7 +622,7 @@ class StreamLoopManager:
except Exception as e:
logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
def _force_dispatch_stream(self, stream_id: str) -> None:
async def _force_dispatch_stream(self, stream_id: str) -> None:
"""强制分发流处理
当流的未读消息超过阈值时,强制触发分发处理
@@ -657,7 +657,7 @@ class StreamLoopManager:
# 获取聊天管理器和流
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"强制分发时未找到流: {stream_id}")
return

View File

@@ -132,7 +132,7 @@ class MessageManager:
"""添加消息到指定聊天流"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
@@ -153,7 +153,7 @@ class MessageManager:
"""更新消息信息"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return
@@ -180,7 +180,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
return 0
@@ -211,7 +211,7 @@ class MessageManager:
"""添加动作到消息"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
@@ -223,12 +223,12 @@ class MessageManager:
except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
def deactivate_stream(self, stream_id: str):
async def deactivate_stream(self, stream_id: str):
"""停用聊天流"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
@@ -245,12 +245,12 @@ class MessageManager:
except Exception as e:
logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}")
def activate_stream(self, stream_id: str):
async def activate_stream(self, stream_id: str):
"""激活聊天流"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
@@ -262,12 +262,12 @@ class MessageManager:
except Exception as e:
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
def get_stream_stats(self, stream_id: str) -> StreamStats | None:
async def get_stream_stats(self, stream_id: str) -> StreamStats | None:
"""获取聊天流统计"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
return None
@@ -360,7 +360,7 @@ class MessageManager:
pass
# 增加打断计数并应用afc阈值降低
chat_stream.context_manager.context.increment_interruption_count()
await chat_stream.context_manager.context.increment_interruption_count()
chat_stream.context_manager.context.apply_interruption_afc_reduction(
global_config.chat.interruption_afc_reduction
)
@@ -382,7 +382,7 @@ class MessageManager:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在")
return
@@ -411,7 +411,7 @@ class MessageManager:
"""清除指定聊天流的所有未读消息"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"clear_stream_unread_messages: 聊天流 {stream_id} 不存在")
return

View File

@@ -161,7 +161,7 @@ class ChatStream:
self.last_active_time = time.time()
self.saved = False
def set_context(self, message: "MessageRecv"):
async def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
@@ -234,6 +234,7 @@ class ChatStream:
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
should_act=getattr(message, "should_act", False),
)
self.stream_context.set_current_message(db_message)
@@ -280,6 +281,45 @@ class ChatStream:
logger.warning(f"获取actions字段失败: {e}")
return None
async def _calculate_message_interest(self, db_message):
"""计算消息兴趣值并更新消息对象"""
try:
from src.chat.interest_system.interest_manager import get_interest_manager
interest_manager = get_interest_manager()
if interest_manager.has_calculator():
# 使用兴趣值计算组件计算
result = await interest_manager.calculate_interest(db_message)
if result.success:
# 更新消息对象的兴趣值相关字段
db_message.interest_value = result.interest_value
db_message.should_reply = result.should_reply
db_message.should_act = result.should_act
logger.debug(f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}")
else:
logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
# 使用默认值
db_message.interest_value = 0.3
db_message.should_reply = False
db_message.should_act = False
else:
# 没有兴趣值计算组件,抛出异常
raise RuntimeError("没有可用的兴趣值计算组件")
except Exception as e:
logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
# 异常情况下使用默认值
if hasattr(db_message, 'interest_value'):
db_message.interest_value = 0.3
if hasattr(db_message, 'should_reply'):
db_message.should_reply = False
if hasattr(db_message, 'should_act'):
db_message.should_act = False
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
@@ -497,7 +537,9 @@ class ChatManager:
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
return self._convert_to_original_stream(optimized_stream)
original_stream = self._convert_to_original_stream(optimized_stream)
return original_stream
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
@@ -517,7 +559,7 @@ class ChatManager:
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream
@@ -581,7 +623,7 @@ class ChatManager:
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
@@ -597,13 +639,13 @@ class ChatManager:
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> ChatStream | None:
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
await stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info(
@@ -613,9 +655,9 @@ class ChatManager:
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
def get_stream_name(self, stream_id: str) -> str | None:
async def get_stream_name(self, stream_id: str) -> str | None:
"""根据 stream_id 获取聊天流名称"""
stream = self.get_stream(stream_id)
stream = await self.get_stream(stream_id)
if not stream:
return None
@@ -813,8 +855,9 @@ class ChatManager:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
# 不在异步加载中设置上下文,避免复杂依赖
# if stream.stream_id in self.last_messages:
# await stream.set_context(self.last_messages[stream.stream_id])
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):

View File

@@ -165,7 +165,7 @@ class ChatterActionManager:
logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}")
# 通过chat_id获取chat_stream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
chat_stream = await chat_manager.get_stream(chat_id)
if not chat_stream:
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
@@ -322,13 +322,13 @@ class ChatterActionManager:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
context = chat_stream.context_manager
if context.context.interruption_count > 0:
old_count = context.context.interruption_count
old_afc_adjustment = context.context.get_afc_threshold_adjustment()
context.context.reset_interruption_count()
await context.context.reset_interruption_count()
logger.debug(
f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0"
)

View File

@@ -31,8 +31,9 @@ class ActionModifier:
def __init__(self, action_manager: ChatterActionManager, chat_id: str):
"""初始化动作处理器"""
self.chat_id = chat_id
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
# chat_stream 和 log_prefix 将在异步方法中初始化
self.chat_stream = None # type: ignore
self.log_prefix = f"[{chat_id}]"
self.action_manager = action_manager
@@ -43,6 +44,15 @@ class ActionModifier:
self._llm_judge_cache = {} # 缓存LLM判定结果
self._cache_expiry_time = 30 # 缓存过期时间(秒)
self._last_context_hash = None # 上次上下文的哈希值
self._log_prefix_initialized = False
async def _initialize_log_prefix(self):
"""异步初始化log_prefix和chat_stream"""
if not self._log_prefix_initialized:
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{stream_name or self.chat_id}]"
self._log_prefix_initialized = True
async def modify_actions(
self,
@@ -57,6 +67,9 @@ class ActionModifier:
处理后ActionManager 将包含最终的可用动作集,供规划器直接使用
"""
# 初始化log_prefix
await self._initialize_log_prefix()
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
removals_s1: list[tuple[str, str]] = []
@@ -72,7 +85,7 @@ class ActionModifier:
from src.plugin_system.core.component_registry import component_registry
# 获取聊天类型
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id)
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
chat_type_removals = []

View File

@@ -229,11 +229,21 @@ class DefaultReplyer:
):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
# 这些将在异步初始化中设置
self.is_group_chat = False
self.chat_target_info = None
self._chat_info_initialized = False
self.heart_fc_sender = HeartFCSender()
# 使用新的增强记忆系统
# from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator
self._chat_info_initialized = False
async def _initialize_chat_info(self):
"""异步初始化聊天信息"""
if not self._chat_info_initialized:
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id)
self._chat_info_initialized = True
# self.memory_activator = EnhancedMemoryActivator()
self.memory_activator = None # 暂时禁用记忆激活器
# 旧的即时记忆系统已被移除,现在使用增强记忆系统
@@ -267,6 +277,9 @@ class DefaultReplyer:
Returns:
Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
"""
# 初始化聊天信息
await self._initialize_chat_info()
prompt = None
if available_actions is None:
available_actions = {}
@@ -810,7 +823,7 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
chat_stream = await chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
# 使用真正的已读和未读消息
@@ -1000,47 +1013,24 @@ class DefaultReplyer:
return read_history_prompt, unread_history_prompt
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
"""为消息获取兴趣度评分"""
"""为消息获取兴趣度评分(使用预计算的兴趣值)"""
interest_scores = {}
try:
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system as interest_scoring_system,
)
# 转换消息格式
db_messages = []
# 直接使用消息中的预计算兴趣值
for msg_dict in messages:
try:
db_msg = DatabaseMessages(
message_id=msg_dict.get("message_id", ""),
time=msg_dict.get("time", time.time()),
chat_id=msg_dict.get("chat_id", ""),
processed_plain_text=msg_dict.get("processed_plain_text", ""),
user_id=msg_dict.get("user_id", ""),
user_nickname=msg_dict.get("user_nickname", ""),
user_platform=msg_dict.get("platform", "qq"),
chat_info_group_id=msg_dict.get("group_id", ""),
chat_info_group_name=msg_dict.get("group_name", ""),
chat_info_group_platform=msg_dict.get("platform", "qq"),
)
db_messages.append(db_msg)
except Exception as e:
logger.warning(f"转换消息格式失败: {e}")
continue
message_id = msg_dict.get("message_id", "")
interest_value = msg_dict.get("interest_value")
# 计算兴趣度评分
if db_messages:
bot_nickname = global_config.bot.nickname or "麦麦"
scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname)
# 构建兴趣度字典
for score in scores:
interest_scores[score.message_id] = score.total_score
if interest_value is not None:
interest_scores[message_id] = float(interest_value)
logger.debug(f"使用预计算兴趣度 - 消息 {message_id}: {interest_value:.3f}")
else:
interest_scores[message_id] = 0.5 # 默认值
logger.debug(f"消息 {message_id} 无预计算兴趣值,使用默认值 0.5")
except Exception as e:
logger.warning(f"获取兴趣度评分失败: {e}")
logger.warning(f"处理预计算兴趣值失败: {e}")
return interest_scores

View File

@@ -1043,11 +1043,11 @@ class Prompt:
from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
other_chat_raw_ids = await cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
chat_stream = await get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""

View File

@@ -622,7 +622,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
return time.strftime("%H:%M:%S", time.localtime(timestamp))
def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
async def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
"""
获取聊天类型(是否群聊)和私聊对象信息。
@@ -639,7 +639,7 @@ def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
chat_target_info = None
try:
if chat_stream := get_chat_manager().get_stream(chat_id):
if chat_stream := await get_chat_manager().get_stream(chat_id):
if chat_stream.group_info:
is_group_chat = True
chat_target_info = None # Explicitly None for group chat

View File

@@ -98,6 +98,7 @@ class DatabaseMessages(BaseDataModel):
# 新增字段
actions: list | None = None,
should_reply: bool = False,
should_act: bool = False,
**kwargs: Any,
):
self.message_id = message_id
@@ -109,6 +110,7 @@ class DatabaseMessages(BaseDataModel):
# 新增字段
self.actions = actions
self.should_reply = should_reply
self.should_act = should_act
self.key_words = key_words
self.key_words_lite = key_words_lite

View File

@@ -184,22 +184,22 @@ class StreamContext(BaseDataModel):
return max(0.0, min(1.0, probability))
def increment_interruption_count(self):
async def increment_interruption_count(self):
"""增加打断计数"""
self.interruption_count += 1
self.last_interruption_time = time.time()
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
await self._sync_interruption_count_to_stream()
def reset_interruption_count(self):
async def reset_interruption_count(self):
"""重置打断计数和afc阈值调整"""
self.interruption_count = 0
self.last_interruption_time = 0.0
self.afc_threshold_adjustment = 0.0
# 同步打断计数到ChatStream
self._sync_interruption_count_to_stream()
await self._sync_interruption_count_to_stream()
def apply_interruption_afc_reduction(self, reduction_value: float):
"""应用打断导致的afc阈值降低"""
@@ -210,14 +210,14 @@ class StreamContext(BaseDataModel):
"""获取当前的afc阈值调整量"""
return self.afc_threshold_adjustment
def _sync_interruption_count_to_stream(self):
async def _sync_interruption_count_to_stream(self):
"""同步打断计数到ChatStream"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
if chat_manager:
chat_stream = chat_manager.get_stream(self.stream_id)
chat_stream = await chat_manager.get_stream(self.stream_id)
if chat_stream and hasattr(chat_stream, "interruption_count"):
# 在这里我们只是标记需要保存实际的保存会在下次save时进行
chat_stream.saved = False

View File

@@ -240,6 +240,7 @@ class Messages(Base):
# 兴趣度系统字段
actions = Column(Text, nullable=True) # JSON格式存储动作列表
should_reply = Column(Boolean, nullable=True, default=False)
should_act = Column(Boolean, nullable=True, default=False)
__table_args__ = (
Index("idx_messages_message_id", "message_id"),
@@ -247,6 +248,7 @@ class Messages(Base):
Index("idx_messages_time", "time"),
Index("idx_messages_user_id", "user_id"),
Index("idx_messages_should_reply", "should_reply"),
Index("idx_messages_should_act", "should_act"),
)

View File

@@ -339,8 +339,8 @@ class MemoryConfig(ValidatedConfigBase):
# === 混合记忆系统配置 ===
# 采样模式配置
memory_sampling_mode: Literal["adaptive", "hippocampus", "precision"] = Field(
default="adaptive", description="记忆采样模式:adaptive(自适应)hippocampus(海马体双峰采样)precision(精准记忆)"
memory_sampling_mode: Literal["all", "hippocampus", "immediate"] = Field(
default="all", description="记忆采样模式hippocampus(海马体定时采样)immediate(即时采样)all(所有模式)"
)
# 海马体双峰采样配置

View File

@@ -101,6 +101,85 @@ class MainSystem:
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async def _initialize_interest_calculator(self):
"""初始化兴趣值计算组件 - 通过插件系统自动发现和加载"""
try:
logger.info("开始自动发现兴趣值计算组件...")
# 使用组件注册表自动发现兴趣计算器组件
interest_calculators = {}
try:
from src.plugin_system.apis.component_manage_api import get_components_info_by_type
from src.plugin_system.base.component_types import ComponentType
interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR)
logger.info(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件")
except Exception as e:
logger.error(f"从组件注册表获取兴趣计算器失败: {e}")
if not interest_calculators:
logger.warning("未发现任何兴趣计算器组件")
return
logger.info(f"发现的兴趣计算器组件:")
for calc_name, calc_info in interest_calculators.items():
enabled = getattr(calc_info, 'enabled', True)
default_enabled = getattr(calc_info, 'enabled_by_default', True)
logger.info(f" - {calc_name}: 启用: {enabled}, 默认启用: {default_enabled}")
# 初始化兴趣度管理器
from src.chat.interest_system.interest_manager import get_interest_manager
interest_manager = get_interest_manager()
await interest_manager.initialize()
# 尝试注册计算器(单例模式,只注册第一个可用的)
registered_calculator = None
# 使用组件注册表获取组件类并注册
for calc_name, calc_info in interest_calculators.items():
enabled = getattr(calc_info, 'enabled', True)
default_enabled = getattr(calc_info, 'enabled_by_default', True)
if not enabled or not default_enabled:
logger.info(f"兴趣计算器 {calc_name} 未启用,跳过")
continue
try:
from src.plugin_system.core.component_registry import component_registry
component_class = component_registry.get_component_class(calc_name, ComponentType.INTEREST_CALCULATOR)
if component_class:
logger.info(f"成功获取 {calc_name} 的组件类: {component_class.__name__}")
# 创建组件实例
calculator_instance = component_class()
logger.info(f"成功创建兴趣计算器实例: {calc_name}")
# 初始化组件
if await calculator_instance.initialize():
# 注册到兴趣管理器
success = await interest_manager.register_calculator(calculator_instance)
if success:
registered_calculator = calculator_instance
logger.info(f"成功注册兴趣计算器: {calc_name}")
break # 只注册一个成功的计算器
else:
logger.error(f"兴趣计算器 {calc_name} 注册失败")
else:
logger.error(f"兴趣计算器 {calc_name} 初始化失败")
else:
logger.warning(f"无法找到 {calc_name} 的组件类")
except Exception as e:
logger.error(f"处理兴趣计算器 {calc_name} 时出错: {e}", exc_info=True)
if registered_calculator:
logger.info(f"当前活跃的兴趣度计算器: {registered_calculator.component_name} v{registered_calculator.component_version}")
else:
logger.error("未能成功注册任何兴趣计算器")
except Exception as e:
logger.error(f"初始化兴趣度计算器失败: {e}", exc_info=True)
async def _async_cleanup(self):
"""异步清理资源"""
try:
@@ -264,7 +343,8 @@ MoFox_Bot(第三方修改版)
# 初始化表情管理器
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
'''
# 初始化回复后关系追踪系统
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
@@ -276,7 +356,7 @@ MoFox_Bot(第三方修改版)
except Exception as e:
logger.error(f"回复后关系追踪系统初始化失败: {e}")
relationship_tracker = None
'''
# 启动情绪管理器
await mood_manager.start()
@@ -293,6 +373,9 @@ MoFox_Bot(第三方修改版)
# 老记忆系统已完全删除
# 初始化消息兴趣值计算组件
await self._initialize_interest_calculator()
# 初始化LPMM知识库
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
@@ -350,17 +433,36 @@ MoFox_Bot(第三方修改版)
"""调度定时任务"""
try:
while True:
tasks = [
get_emoji_manager().start_periodic_check_register(),
self.app.run(),
self.server.run(),
]
try:
tasks = [
get_emoji_manager().start_periodic_check_register(),
self.app.run(),
self.server.run(),
]
# 增强记忆系统不需要定时任务,已禁用原有记忆系统的定时任务
# 增强记忆系统不需要定时任务,已禁用原有记忆系统的定时任务
# 使用 return_exceptions=True 防止单个任务失败导致整个程序崩溃
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.gather(*tasks)
except (ConnectionResetError, OSError) as e:
logger.warning(f"网络连接发生错误,尝试重新启动任务: {e}")
await asyncio.sleep(1) # 短暂等待后重新开始
continue
except asyncio.InvalidStateError as e:
logger.error(f"异步任务状态无效,重新初始化: {e}")
await asyncio.sleep(2) # 等待更长时间让系统稳定
continue
except Exception as e:
logger.error(f"调度任务发生未预期异常: {e}")
logger.error(traceback.format_exc())
await asyncio.sleep(5) # 发生其他错误时等待更长时间
continue
except asyncio.CancelledError:
logger.info("调度任务被取消,正在退出...")
except Exception as e:
logger.error(f"调度任务发生异常: {e}")
logger.error(f"调度任务发生致命异常: {e}")
logger.error(traceback.format_exc())
raise
async def shutdown(self):

View File

@@ -38,13 +38,11 @@ def init_prompt():
class MaiThinking:
def __init__(self, chat_id):
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform
if self.chat_stream.group_info:
self.is_group = True
else:
self.is_group = False
# 这些将在异步初始化中设置
self.chat_stream = None # type: ignore
self.platform = None
self.is_group = False
self._initialized = False
self.s4u_message_processor = S4UMessageProcessor()
@@ -63,6 +61,15 @@ class MaiThinking:
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
async def _initialize(self):
"""异步初始化方法"""
if not self._initialized:
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
if self.chat_stream:
self.platform = self.chat_stream.platform
self.is_group = bool(self.chat_stream.group_info)
self._initialized = True
async def do_think_before_response(self):
pass
@@ -98,6 +105,9 @@ class MaiThinking:
pass
async def build_internal_message_recv(self, message_text: str):
# 初始化
await self._initialize()
msg_id = f"internal_{time.time()}"
message_dict = {

View File

@@ -162,9 +162,9 @@ class S4UChatManager:
def __init__(self):
self.s4u_chats: dict[str, "S4UChat"] = {}
def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
async def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
if chat_stream.stream_id not in self.s4u_chats:
stream_name = get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
stream_name = await get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
logger.info(f"Creating new S4UChat for stream: {stream_name}")
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
return self.s4u_chats[chat_stream.stream_id]
@@ -187,7 +187,7 @@ class S4UChat:
self.last_msg_id = self.msg_id
self.chat_stream = chat_stream
self.stream_id = chat_stream.stream_id
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
self.stream_name = self.stream_id # 初始化时使用stream_id,稍后异步更新
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
# 两个消息队列
@@ -213,6 +213,13 @@ class S4UChat:
self.voice_done = ""
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
self._stream_name_initialized = False
async def _initialize_stream_name(self):
"""异步初始化stream_name"""
if not self._stream_name_initialized:
self.stream_name = await get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
self._stream_name_initialized = True
@staticmethod
def _get_priority_info(message: MessageRecv) -> dict:
@@ -263,6 +270,9 @@ class S4UChat:
self.interest_dict[person_id] = 0
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
# 初始化stream_name
await self._initialize_stream_name()
self.decay_interest_score()
"""根据VIP状态和中断逻辑将消息放入相应队列。"""

View File

@@ -157,7 +157,7 @@ class S4UMessageProcessor:
await self.storage.store_message(message, chat)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
await s4u_chat.add_message(message)
@@ -191,7 +191,7 @@ class S4UMessageProcessor:
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
message.message_info.group_info = s4u_chat.chat_stream.group_info
message.message_info.platform = s4u_chat.chat_stream.platform
@@ -215,7 +215,7 @@ class S4UMessageProcessor:
@staticmethod
async def hadle_if_voice_done(message: MessageRecvS4U):
if message.voice_done:
s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
s4u_chat.voice_done = message.voice_done
return True
return False

View File

@@ -48,17 +48,27 @@ class ChatMood:
def __init__(self, chat_id: str):
self.chat_id: str = chat_id
chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
# 这些将在异步初始化中设置
self.chat_stream = None # type: ignore
self.log_prefix = f"[{chat_id}]"
self._initialized = False
self.mood_state: str = "感觉很平静"
self.is_angry_from_wakeup: bool = False # 是否因被吵醒而愤怒
async def _initialize(self):
"""异步初始化方法"""
if not self._initialized:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
self.chat_stream = await chat_manager.get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {self.chat_id} not found")
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
self._initialized = True
self.regression_count: int = 0
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")

View File

@@ -58,16 +58,24 @@ class RelationshipBuilder:
# 最后清理时间,用于定期清理老消息段
self.last_cleanup_time = 0.0
# 获取聊天名称用于日志
try:
chat_name = get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{chat_name}]"
except Exception:
self.log_prefix = f"[{self.chat_id}]"
# log_prefix 将在异步方法中初始化
self.log_prefix = f"[{self.chat_id}]"
self._log_prefix_initialized = False
# 加载持久化的缓存
self._load_cache()
async def _initialize_log_prefix(self):
"""异步初始化log_prefix"""
if not self._log_prefix_initialized:
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{chat_name}]"
except Exception:
self.log_prefix = f"[{self.chat_id}]"
self._log_prefix_initialized = True
# ================================
# 缓存管理模块
# 负责持久化存储、状态管理、缓存读写
@@ -339,6 +347,9 @@ class RelationshipBuilder:
"""构建关系
immediate_build: 立即构建关系,可选值为"all"或person_id
"""
# 初始化log_prefix
await self._initialize_log_prefix()
self._cleanup_old_segments()
current_time = time.time()

View File

@@ -79,8 +79,16 @@ class RelationshipFetcher:
model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
)
name = get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{name}] 实时信息"
self.log_prefix = f"[{self.chat_id}] 实时信息" # 初始化时使用chat_id稍后异步更新
self._log_prefix_initialized = False
async def _initialize_log_prefix(self):
"""异步初始化log_prefix"""
if not self._log_prefix_initialized:
from src.chat.message_receive.chat_stream import get_chat_manager
name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{name}] 实时信息"
self._log_prefix_initialized = True
def _cleanup_expired_cache(self):
"""清理过期的信息缓存"""
@@ -94,6 +102,9 @@ class RelationshipFetcher:
async def build_relation_info(self, person_id, points_num=5):
"""构建详细的人物关系信息,包含从数据库中查询的丰富关系描述"""
# 初始化log_prefix
await self._initialize_log_prefix()
# 清理过期的信息缓存
self._cleanup_expired_cache()

View File

@@ -16,11 +16,11 @@ from src.config.config import global_config
logger = get_logger("cross_context_api")
def get_context_groups(chat_id: str) -> list[list[str]] | None:
async def get_context_groups(chat_id: str) -> list[list[str]] | None:
"""
获取当前聊天所在的共享组的其他聊天ID
"""
current_stream = get_chat_manager().get_stream(chat_id)
current_stream = await get_chat_manager().get_stream(chat_id)
if not current_stream:
return None
@@ -59,7 +59,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos:
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
@@ -100,7 +100,7 @@ async def build_cross_context_s4u(
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-5:]
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
user_name = (
target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
)
@@ -167,7 +167,7 @@ async def get_chat_history_by_group_name(group_name: str) -> str:
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:

View File

@@ -166,7 +166,7 @@ async def _send_to_target(
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
# 查找目标聊天流
target_stream = get_chat_manager().get_stream(stream_id)
target_stream = await get_chat_manager().get_stream(stream_id)
if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
return False
@@ -416,7 +416,7 @@ async def adapter_command_to_stream(
logger.debug(f"[SendAPI] 自动生成临时stream_id: {stream_id}")
# 查找目标聊天流
target_stream = get_chat_manager().get_stream(stream_id)
target_stream = await get_chat_manager().get_stream(stream_id)
if not target_stream:
# 如果是自动生成的stream_id且找不到聊天流创建一个临时的虚拟流
if stream_id.startswith("adapter_temp_"):

View File

@@ -0,0 +1,220 @@
"""兴趣值计算组件基类
提供兴趣值计算的标准接口,确保只能有一个兴趣值计算组件实例运行
"""
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, InterestCalculatorInfo
logger = get_logger("base_interest_calculator")
class InterestCalculationResult:
"""兴趣值计算结果"""
def __init__(
self,
success: bool,
message_id: str,
interest_value: float,
should_take_action: bool = False,
should_reply: bool = False,
should_act: bool = False,
error_message: str | None = None,
calculation_time: float = 0.0
):
self.success = success
self.message_id = message_id
self.interest_value = max(0.0, min(1.0, interest_value)) # 确保在0-1范围内
self.should_take_action = should_take_action
self.should_reply = should_reply
self.should_act = should_act
self.error_message = error_message
self.calculation_time = calculation_time
self.timestamp = time.time()
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
"success": self.success,
"message_id": self.message_id,
"interest_value": self.interest_value,
"should_take_action": self.should_take_action,
"should_reply": self.should_reply,
"should_act": self.should_act,
"error_message": self.error_message,
"calculation_time": self.calculation_time,
"timestamp": self.timestamp
}
def __repr__(self) -> str:
return (f"InterestCalculationResult("
f"success={self.success}, "
f"message_id={self.message_id}, "
f"interest_value={self.interest_value:.3f}, "
f"should_take_action={self.should_take_action}, "
f"should_reply={self.should_reply}, "
f"should_act={self.should_act})")
class BaseInterestCalculator(ABC):
"""兴趣值计算组件基类
所有兴趣值计算组件都必须继承此类,并实现 execute 方法
系统确保只能有一个兴趣值计算组件实例运行
"""
# 子类必须定义这些属性
component_name: str = ""
component_version: str = ""
component_description: str = ""
enabled_by_default: bool = True # 是否默认启用
def __init__(self):
self._enabled = False
self._last_calculation_time = 0.0
self._total_calculations = 0
self._failed_calculations = 0
self._average_calculation_time = 0.0
# 验证必须定义的属性
if not self.component_name:
raise ValueError("子类必须定义 component_name 属性")
if not self.component_version:
raise ValueError("子类必须定义 component_version 属性")
if not self.component_description:
raise ValueError("子类必须定义 component_description 属性")
@abstractmethod
async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""执行兴趣值计算
Args:
message: 数据库消息对象
Returns:
InterestCalculationResult: 计算结果
"""
pass
async def initialize(self) -> bool:
"""初始化组件
Returns:
bool: 初始化是否成功
"""
try:
self._enabled = True
return True
except Exception as e:
self._enabled = False
return False
async def cleanup(self) -> bool:
"""清理组件资源
Returns:
bool: 清理是否成功
"""
try:
self._enabled = False
return True
except Exception:
return False
@property
def is_enabled(self) -> bool:
"""组件是否已启用"""
return self._enabled
def get_statistics(self) -> dict:
"""获取组件统计信息"""
return {
"component_name": self.component_name,
"component_version": self.component_version,
"enabled": self._enabled,
"total_calculations": self._total_calculations,
"failed_calculations": self._failed_calculations,
"success_rate": 1.0 - (self._failed_calculations / max(1, self._total_calculations)),
"average_calculation_time": self._average_calculation_time,
"last_calculation_time": self._last_calculation_time
}
def _update_statistics(self, result: InterestCalculationResult):
"""更新统计信息"""
self._total_calculations += 1
if not result.success:
self._failed_calculations += 1
# 更新平均计算时间
if self._total_calculations == 1:
self._average_calculation_time = result.calculation_time
else:
alpha = 0.1 # 指数移动平均的平滑因子
self._average_calculation_time = (
alpha * result.calculation_time +
(1 - alpha) * self._average_calculation_time
)
self._last_calculation_time = result.timestamp
async def _safe_execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""安全执行计算,包含统计和错误处理"""
if not self._enabled:
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message="组件未启用"
)
start_time = time.time()
try:
result = await self.execute(message)
result.calculation_time = time.time() - start_time
self._update_statistics(result)
return result
except Exception as e:
result = InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message=f"计算执行失败: {str(e)}",
calculation_time=time.time() - start_time
)
self._update_statistics(result)
return result
@classmethod
def get_interest_calculator_info(cls) -> "InterestCalculatorInfo":
"""从类属性生成InterestCalculatorInfo
遵循BaseCommand和BaseAction的设计模式从类属性自动生成组件信息
Returns:
InterestCalculatorInfo: 生成的兴趣计算器信息对象
"""
name = getattr(cls, 'component_name', cls.__name__.lower().replace('calculator', ''))
if "." in name:
logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
return InterestCalculatorInfo(
name=name,
component_type=ComponentType.INTEREST_CALCULATOR,
description=getattr(cls, 'component_description', cls.__doc__ or "兴趣度计算器"),
enabled_by_default=getattr(cls, 'enabled_by_default', True),
)
def __repr__(self) -> str:
return (f"{self.__class__.__name__}("
f"name={self.component_name}, "
f"version={self.component_version}, "
f"enabled={self._enabled})")

View File

@@ -1,11 +1,20 @@
from abc import abstractmethod
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, PlusCommandInfo, ToolInfo
from src.plugin_system.base.component_types import (
ActionInfo,
CommandInfo,
ComponentType,
EventHandlerInfo,
InterestCalculatorInfo,
PlusCommandInfo,
ToolInfo,
)
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_interest_calculator import BaseInterestCalculator
from .base_tool import BaseTool
from .plugin_base import PluginBase
from .plus_command import PlusCommand
@@ -21,6 +30,72 @@ class BasePlugin(PluginBase):
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
@classmethod
def _get_component_info_from_class(cls, component_class: type, component_type: ComponentType):
"""从组件类自动生成组件信息
Args:
component_class: 组件类
component_type: 组件类型
Returns:
对应类型的ComponentInfo对象
"""
if component_type == ComponentType.COMMAND:
if hasattr(component_class, 'get_command_info'):
return component_class.get_command_info()
else:
logger.warning(f"Command类 {component_class.__name__} 缺少 get_command_info 方法")
return None
elif component_type == ComponentType.ACTION:
if hasattr(component_class, 'get_action_info'):
return component_class.get_action_info()
else:
logger.warning(f"Action类 {component_class.__name__} 缺少 get_action_info 方法")
return None
elif component_type == ComponentType.INTEREST_CALCULATOR:
if hasattr(component_class, 'get_interest_calculator_info'):
return component_class.get_interest_calculator_info()
else:
logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")
return None
elif component_type == ComponentType.PLUS_COMMAND:
# PlusCommand的get_info逻辑可以在这里实现
logger.warning("PlusCommand的get_info逻辑尚未实现")
return None
elif component_type == ComponentType.TOOL:
# Tool的get_info逻辑可以在这里实现
logger.warning("Tool的get_info逻辑尚未实现")
return None
elif component_type == ComponentType.EVENT_HANDLER:
# EventHandler的get_info逻辑可以在这里实现
logger.warning("EventHandler的get_info逻辑尚未实现")
return None
else:
logger.error(f"不支持的组件类型: {component_type}")
return None
@classmethod
def get_component_info(cls, component_class: type, component_type: ComponentType):
"""获取组件信息的通用方法
这是一个便捷方法内部调用_get_component_info_from_class
Args:
component_class: 组件类
component_type: 组件类型
Returns:
对应类型的ComponentInfo对象
"""
return cls._get_component_info_from_class(component_class, component_type)
@abstractmethod
def get_plugin_components(
self,
@@ -30,6 +105,7 @@ class BasePlugin(PluginBase):
| tuple[PlusCommandInfo, type[PlusCommand]]
| tuple[EventHandlerInfo, type[BaseEventHandler]]
| tuple[ToolInfo, type[BaseTool]]
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
]:
"""获取插件包含的组件列表

View File

@@ -19,6 +19,7 @@ class ComponentType(Enum):
SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件
CHATTER = "chatter" # 聊天处理器组件
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
def __str__(self) -> str:
return self.value
@@ -229,6 +230,17 @@ class ChatterInfo(ComponentInfo):
self.component_type = ComponentType.CHATTER
@dataclass
class InterestCalculatorInfo(ComponentInfo):
"""兴趣度计算组件信息(单例模式)"""
enabled_by_default: bool = True # 是否默认启用
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.INTEREST_CALCULATOR
@dataclass
class EventInfo(ComponentInfo):
"""事件组件信息"""

View File

@@ -8,6 +8,7 @@ from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import (
ActionInfo,
@@ -16,6 +17,7 @@ from src.plugin_system.base.component_types import (
ComponentInfo,
ComponentType,
EventHandlerInfo,
InterestCalculatorInfo,
PluginInfo,
PlusCommandInfo,
ToolInfo,
@@ -162,8 +164,13 @@ class ComponentRegistry:
assert isinstance(component_info, ChatterInfo)
assert issubclass(component_class, BaseChatter)
ret = self._register_chatter_component(component_info, component_class)
case ComponentType.INTEREST_CALCULATOR:
assert isinstance(component_info, InterestCalculatorInfo)
assert issubclass(component_class, BaseInterestCalculator)
ret = self._register_interest_calculator_component(component_info, component_class)
case _:
logger.warning(f"未知组件类型: {component_type}")
ret = False
if not ret:
return False
@@ -311,6 +318,38 @@ class ComponentRegistry:
logger.debug(f"已注册Chatter组件: {chatter_name}")
return True
def _register_interest_calculator_component(
self, interest_calculator_info: "InterestCalculatorInfo", interest_calculator_class: type["BaseInterestCalculator"]
) -> bool:
"""注册InterestCalculator组件到特定注册表"""
calculator_name = interest_calculator_info.name
if not calculator_name:
logger.error(f"InterestCalculator组件 {interest_calculator_class.__name__} 必须指定名称")
return False
if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass(interest_calculator_class, BaseInterestCalculator):
logger.error(f"注册失败: {calculator_name} 不是有效的InterestCalculator")
return False
# 创建专门的InterestCalculator注册表如果还没有
if not hasattr(self, "_interest_calculator_registry"):
self._interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {}
if not hasattr(self, "_enabled_interest_calculator_registry"):
self._enabled_interest_calculator_registry: dict[str, type["BaseInterestCalculator"]] = {}
interest_calculator_class.plugin_name = interest_calculator_info.plugin_name
# 设置插件配置
interest_calculator_class.plugin_config = self.get_plugin_config(interest_calculator_info.plugin_name) or {}
self._interest_calculator_registry[calculator_name] = interest_calculator_class
if not interest_calculator_info.enabled:
logger.warning(f"InterestCalculator组件 {calculator_name} 未启用")
return True # 未启用,但是也是注册成功
self._enabled_interest_calculator_registry[calculator_name] = interest_calculator_class
logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
return True
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: "ComponentType", plugin_name: str) -> bool:

View File

@@ -51,16 +51,28 @@ class ToolExecutor:
chat_id: 聊天标识符,用于日志记录
"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
# chat_stream 和 log_prefix 将在异步方法中初始化
self.chat_stream = None # type: ignore
self.log_prefix = f"[{chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 二步工具调用状态管理
self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
self._log_prefix_initialized = False
logger.info(f"{self.log_prefix}工具执行器初始化完成")
# logger.info(f"{self.log_prefix}工具执行器初始化完成") # 移到异步初始化中
async def _initialize_log_prefix(self):
"""异步初始化log_prefix和chat_stream"""
if not self._log_prefix_initialized:
from src.chat.message_receive.chat_stream import get_chat_manager
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{stream_name or self.chat_id}]"
self._log_prefix_initialized = True
logger.info(f"{self.log_prefix}工具执行器初始化完成")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
@@ -77,6 +89,8 @@ class ToolExecutor:
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
# 初始化log_prefix
await self._initialize_log_prefix()
# 获取可用工具
tools = self._get_tool_definitions()

View File

@@ -1,7 +0,0 @@
"""
亲和力聊天处理器插件
"""
from .plugin import AffinityChatterPlugin
__all__ = ["AffinityChatterPlugin"]

View File

@@ -0,0 +1,301 @@
"""AffinityFlow 风格兴趣值计算组件
基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能
"""
import time
from typing import TYPE_CHECKING
from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("affinity_interest_calculator")
class AffinityInterestCalculator(BaseInterestCalculator):
"""AffinityFlow 风格兴趣值计算组件"""
# 直接定义类属性
component_name = "affinity_interest_calculator"
component_version = "1.0.0"
component_description = "基于AffinityFlow逻辑的兴趣值计算组件使用智能兴趣匹配和用户关系评分"
def __init__(self):
super().__init__()
# 智能兴趣匹配配置(已在类属性中定义)
# 从配置加载评分权重
affinity_config = global_config.affinity_flow
self.score_weights = {
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
"relationship": affinity_config.relationship_weight, # 关系分权重
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
}
# 评分阈值
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
# 连续不回复概率提升
self.no_reply_count = 0
self.max_no_reply_count = affinity_config.max_no_reply_count
self.probability_boost_per_no_reply = (
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
) # 每次不回复增加的概率
# 用户关系数据缓存
self.user_relationships: dict[str, float] = {} # user_id -> relationship_score
logger.info(f"[Affinity兴趣计算器] 初始化完成:")
logger.info(f" - 权重配置: {self.score_weights}")
logger.info(f" - 回复阈值: {self.reply_threshold}")
logger.info(f" - 智能匹配: {self.use_smart_matching}")
# 检查 bot_interest_manager 状态
try:
logger.info(f" - bot_interest_manager 初始化状态: {bot_interest_manager.is_initialized}")
if not bot_interest_manager.is_initialized:
logger.warning(" - bot_interest_manager 未初始化这将导致兴趣匹配返回默认值0.3")
except Exception as e:
logger.error(f" - 检查 bot_interest_manager 时出错: {e}")
async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""执行AffinityFlow风格的兴趣值计算"""
try:
start_time = time.time()
message_id = getattr(message, 'message_id', '')
content = getattr(message, 'processed_plain_text', '')
user_id = getattr(message, 'user_info', {}).user_id if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id') else ''
logger.debug(f"[Affinity兴趣计算] 开始处理消息 {message_id}")
logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
logger.debug(f"[Affinity兴趣计算] 用户ID: {user_id}")
# 1. 计算兴趣匹配分
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(content, keywords)
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")
# 2. 计算关系分
relationship_score = await self._calculate_relationship_score(user_id)
logger.debug(f"[Affinity兴趣计算] 关系分: {relationship_score}")
# 3. 计算提及分
mentioned_score = self._calculate_mentioned_score(message, global_config.bot.nickname)
logger.debug(f"[Affinity兴趣计算] 提及分: {mentioned_score}")
# 4. 综合评分
# 确保所有分数都是有效的 float 值
interest_match_score = float(interest_match_score) if interest_match_score is not None else 0.0
relationship_score = float(relationship_score) if relationship_score is not None else 0.0
mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0
total_score = (
interest_match_score * self.score_weights["interest_match"]
+ relationship_score * self.score_weights["relationship"]
+ mentioned_score * self.score_weights["mentioned"]
)
logger.debug(f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}")
# 5. 考虑连续不回复的概率提升
adjusted_score = self._apply_no_reply_boost(total_score)
logger.debug(f"[Affinity兴趣计算] 应用不回复提升后: {total_score:.3f}{adjusted_score:.3f}")
# 6. 决定是否回复和执行动作
should_reply = adjusted_score > self.reply_threshold
should_take_action = adjusted_score > (self.reply_threshold + 0.1)
logger.debug(f"[Affinity兴趣计算] 阈值判断: {adjusted_score:.3f} > 回复阈值:{self.reply_threshold:.3f}? = {should_reply}")
logger.debug(f"[Affinity兴趣计算] 阈值判断: {adjusted_score:.3f} > 动作阈值:{self.reply_threshold + 0.1:.3f}? = {should_take_action}")
calculation_time = time.time() - start_time
logger.debug(f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})")
return InterestCalculationResult(
success=True,
message_id=message_id,
interest_value=adjusted_score,
should_take_action=should_take_action,
should_reply=should_reply,
should_act=should_take_action,
calculation_time=calculation_time
)
except Exception as e:
logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message=str(e)
)
async def _calculate_interest_match_score(self, content: str, keywords: list[str] = None) -> float:
"""计算兴趣匹配度(使用智能兴趣匹配系统)"""
# 调试日志:检查各个条件
if not content:
logger.debug("兴趣匹配返回0.0: 内容为空")
return 0.0
if not self.use_smart_matching:
logger.debug("兴趣匹配返回0.0: 智能匹配未启用")
return 0.0
if not bot_interest_manager.is_initialized:
logger.debug("兴趣匹配返回0.0: bot_interest_manager未初始化")
return 0.0
logger.debug(f"开始兴趣匹配计算,内容: {content[:50]}...")
try:
# 使用机器人的兴趣标签系统进行智能匹配
match_result = await bot_interest_manager.calculate_interest_match(content, keywords)
logger.debug(f"兴趣匹配结果: {match_result}")
if match_result:
# 返回匹配分数,考虑置信度和匹配标签数量
affinity_config = global_config.affinity_flow
match_count_bonus = min(
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
)
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
logger.debug(f"兴趣匹配最终得分: {final_score}")
return final_score
else:
logger.debug("兴趣匹配返回0.0: match_result为None")
return 0.0
except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}")
return 0.0
async def _calculate_relationship_score(self, user_id: str) -> float:
"""计算用户关系分"""
if not user_id:
return global_config.affinity_flow.base_relationship_score
# 优先使用内存中的关系分
if user_id in self.user_relationships:
relationship_value = self.user_relationships[user_id]
return min(relationship_value, 1.0)
# 如果内存中没有,尝试从关系追踪器获取
try:
from .relationship_tracker import ChatterRelationshipTracker
global_tracker = ChatterRelationshipTracker()
if global_tracker:
relationship_score = await global_tracker.get_user_relationship_score(user_id)
# 同时更新内存缓存
self.user_relationships[user_id] = relationship_score
return relationship_score
except Exception as e:
logger.debug(f"获取用户关系分失败: {e}")
# 默认新用户的基础分
return global_config.affinity_flow.base_relationship_score
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
"""计算提及分"""
is_mentioned = getattr(message, 'is_mentioned', False)
is_at = getattr(message, 'is_at', False)
processed_plain_text = getattr(message, 'processed_plain_text', '')
if is_mentioned:
if is_at:
return 1.0 # 直接@机器人,最高分
else:
return 0.8 # 提及机器人名字,高分
else:
# 检查是否被提及(文本匹配)
bot_aliases = [bot_nickname] + global_config.bot.alias_names
is_text_mentioned = any(alias in processed_plain_text for alias in bot_aliases if alias)
# 如果被提及或是私聊都视为提及了bot
if is_text_mentioned or not hasattr(message, "chat_info_group_id"):
return global_config.affinity_flow.mention_bot_interest_score
else:
return 0.0 # 未提及机器人
def _apply_no_reply_boost(self, base_score: float) -> float:
"""应用连续不回复的概率提升"""
if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count:
boost = self.no_reply_count * self.probability_boost_per_no_reply
return min(1.0, base_score + boost)
return base_score
def _extract_keywords_from_database(self, message: "DatabaseMessages") -> list[str]:
"""从数据库消息中提取关键词"""
keywords = []
# 尝试从 key_words 字段提取存储的是JSON字符串
key_words = getattr(message, 'key_words', '')
if key_words:
try:
import orjson
extracted = orjson.loads(key_words)
if isinstance(extracted, list):
keywords = extracted
except (orjson.JSONDecodeError, TypeError):
keywords = []
# 如果没有 keywords尝试从 key_words_lite 提取
if not keywords:
key_words_lite = getattr(message, 'key_words_lite', '')
if key_words_lite:
try:
import orjson
extracted = orjson.loads(key_words_lite)
if isinstance(extracted, list):
keywords = extracted
except (orjson.JSONDecodeError, TypeError):
keywords = []
# 如果还是没有,从消息内容中提取(降级方案)
if not keywords:
content = getattr(message, 'processed_plain_text', '') or ''
keywords = self._extract_keywords_from_content(content)
return keywords[:15] # 返回前15个关键词
def _extract_keywords_from_content(self, content: str) -> list[str]:
"""从内容中提取关键词(降级方案)"""
import re
# 清理文本
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
words = content.split()
# 过滤和关键词提取
keywords = []
for word in words:
word = word.strip()
if (
len(word) >= 2 # 至少2个字符
and word.isalnum() # 字母数字
and not word.isdigit()
): # 不是纯数字
keywords.append(word.lower())
# 去重并限制数量
unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词
def update_no_reply_count(self, replied: bool):
"""更新连续不回复计数"""
if replied:
self.no_reply_count = 0
else:
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
# 是否使用智能兴趣匹配(作为类属性)
use_smart_matching = True

View File

@@ -323,7 +323,7 @@ class ChatterPlanFilter:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(plan.chat_id)
chat_stream = await chat_manager.get_stream(plan.chat_id)
if not chat_stream:
logger.warning(f"[plan_filter] 聊天流 {plan.chat_id} 不存在")
return "最近没有聊天内容。", "没有未读消息。", []
@@ -397,52 +397,22 @@ class ChatterPlanFilter:
interest_scores = {}
try:
from src.common.data_models.database_data_model import DatabaseMessages
from .interest_scoring import chatter_interest_scoring_system
# 使用插件内部的兴趣度评分系统计算评分
# 直接使用消息中已预计算的兴趣值,无需重新计算
for msg_dict in messages:
try:
# 将字典转换为DatabaseMessages对象
# 处理两种可能的数据格式flatten()返回的平铺字段 或 包含user_info字段的字典
user_info_dict = msg_dict.get("user_info", {})
if isinstance(user_info_dict, dict) and user_info_dict:
# 如果有user_info字段使用它
db_message = DatabaseMessages(
message_id=msg_dict.get("message_id", ""),
user_id=user_info_dict.get("user_id", ""),
user_nickname=user_info_dict.get("user_nickname", ""),
user_platform=user_info_dict.get("platform", ""),
processed_plain_text=msg_dict.get("processed_plain_text", ""),
key_words=msg_dict.get("key_words", "[]"),
is_mentioned=msg_dict.get("is_mentioned", False),
**{"user_info": user_info_dict}, # 通过kwargs传入user_info
)
else:
# 如果没有user_info字段使用平铺的字段flatten()方法返回的格式)
db_message = DatabaseMessages(
message_id=msg_dict.get("message_id", ""),
user_id=msg_dict.get("user_id", ""),
user_nickname=msg_dict.get("user_nickname", ""),
user_platform=msg_dict.get("user_platform", ""),
processed_plain_text=msg_dict.get("processed_plain_text", ""),
key_words=msg_dict.get("key_words", "[]"),
is_mentioned=msg_dict.get("is_mentioned", False),
)
# 计算消息兴趣度
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
message=db_message, bot_nickname=global_config.bot.nickname
)
interest_score = interest_score_obj.total_score
# 直接使用消息中已预计算的兴趣值
interest_score = msg_dict.get("interest_value", 0.3)
should_reply = msg_dict.get("should_reply", False)
# 构建兴趣度字典
interest_scores[msg_dict.get("message_id", "")] = interest_score
logger.debug(f"使用消息预计算兴趣值: {interest_score:.3f}, should_reply: {should_reply}")
except Exception as e:
logger.warning(f"计算消息兴趣失败: {e}")
continue
logger.warning(f"获取消息预计算兴趣失败: {e}")
# 使用默认值
interest_scores[msg_dict.get("message_id", "")] = 0.3
except Exception as e:
logger.warning(f"获取兴趣度评分失败: {e}")

View File

@@ -54,7 +54,7 @@ class ChatterPlanGenerator:
"""
try:
# 获取聊天类型和目标信息
chat_type, target_info = get_chat_type_and_target_info(self.chat_id)
chat_type, target_info = await get_chat_type_and_target_info(self.chat_id)
# 获取可用动作列表
available_actions = await self._get_available_actions(chat_type, mode)

View File

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
@@ -105,43 +104,39 @@ class ChatterActionPlanner:
interest_updates: list[dict[str, Any]] = []
if unread_messages:
# 为每条消息计算兴趣度,并延迟提交数据库更新
# 直接使用消息中已计算的标志,无需重复计算兴趣值
for message in unread_messages:
try:
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message,
bot_nickname=global_config.bot.nickname,
)
message_interest = interest_score.total_score
message_interest = getattr(message, 'interest_value', 0.3)
message_should_reply = getattr(message, 'should_reply', False)
message_should_act = getattr(message, 'should_act', False)
message.interest_value = message_interest
message.should_reply = (
message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
)
interest_updates.append(
{
"message_id": message.message_id,
"interest_value": message_interest,
"should_reply": message.should_reply,
}
)
logger.info(
f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
)
# 确保interest_value不是None
if message_interest is None:
message_interest = 0.3
# 更新最高兴趣度消息
if message_interest > score:
score = message_interest
if message.should_reply:
if message_should_reply:
should_reply = True
else:
reply_not_available = True
# 如果should_act为false强制设为no_action
if not message_should_act:
reply_not_available = True
logger.debug(
f"消息 {message.message_id} 预计算标志: interest={message_interest:.3f}, "
f"should_reply={message_should_reply}, should_act={message_should_act}"
)
except Exception as e:
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
logger.warning(f"处理消息 {message.message_id} 失败: {e}")
message.interest_value = 0.0
message.should_reply = False
message.should_act = False
interest_updates.append(
{
"message_id": message.message_id,

View File

@@ -1,11 +1,11 @@
"""
亲和力聊天处理器插件
亲和力聊天处理器插件(包含兴趣计算器功能)
"""
from src.common.logger import get_logger
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.plugin_system.base.component_types import ComponentInfo, ComponentType, InterestCalculatorInfo
logger = get_logger("affinity_chatter_plugin")
@@ -15,6 +15,7 @@ class AffinityChatterPlugin(BasePlugin):
"""亲和力聊天处理器插件
- 延迟导入 `AffinityChatter` 并通过组件注册器注册为聊天处理器
- 延迟导入 `AffinityInterestCalculator` 并通过组件注册器注册为兴趣计算器
- 提供 `get_plugin_components` 以兼容插件注册机制
"""
@@ -28,17 +29,27 @@ class AffinityChatterPlugin(BasePlugin):
config_schema = {}
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""返回插件包含的组件列表ChatterInfo, AffinityChatter
"""返回插件包含的组件列表
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。
这里采用延迟导入避免循环依赖和启动顺序问题。
如果导入失败则返回空列表以让注册过程继续而不崩溃。
"""
components = []
try:
# 延迟导入以避免循环导入
# 延迟导入 AffinityChatter
from .affinity_chatter import AffinityChatter
return [(AffinityChatter.get_chatter_info(), AffinityChatter)]
components.append((AffinityChatter.get_chatter_info(), AffinityChatter))
except Exception as e:
logger.error(f"加载 AffinityChatter 时出错: {e}")
return []
try:
# 延迟导入 AffinityInterestCalculator
from .affinity_interest_calculator import AffinityInterestCalculator
components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
except Exception as e:
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
return components