refactor(chat): 异步化聊天系统并重构兴趣值计算机制

将同步调用改为异步调用以提升性能,重构兴趣值计算流程以支持更灵活的组件化架构。主要改进包括:

- 异步化ChatManager相关方法,避免阻塞主线程
- 重构兴趣值计算系统,从插件内部计算改为通过兴趣管理器统一处理
- 新增should_act字段支持更细粒度的动作决策
- 优化初始化逻辑,避免构造函数中的异步操作
- 扩展插件系统支持兴趣计算器组件注册
- 更新数据库模型以支持新的兴趣值相关字段

这些改进提升了系统的响应性能和可扩展性,同时保持了API的向后兼容性。
This commit is contained in:
Windpicker-owo
2025-10-05 01:25:52 +08:00
parent 80c3ee8524
commit 481252d660
38 changed files with 1493 additions and 262 deletions

View File

@@ -63,7 +63,7 @@ class ExpressionLearner:
model_set=model_config.model_task_config.replyer, request_type="expressor.learner"
)
self.chat_id = chat_id
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self.chat_name = chat_id # 初始化时使用chat_id稍后异步更新
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
@@ -71,6 +71,14 @@ class ExpressionLearner:
# 学习参数
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
self._chat_name_initialized = False
async def _initialize_chat_name(self):
"""异步初始化chat_name"""
if not self._chat_name_initialized:
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
self.chat_name = stream_name or self.chat_id
self._chat_name_initialized = True
def can_learn_for_chat(self) -> bool:
"""
@@ -144,6 +152,9 @@ class ExpressionLearner:
Returns:
bool: 是否成功触发学习
"""
# 初始化chat_name
await self._initialize_chat_name()
if not await self.should_trigger_learning():
return False
@@ -293,7 +304,7 @@ class ExpressionLearner:
return []
learnt_expressions, chat_id = res
chat_stream = get_chat_manager().get_stream(chat_id)
chat_stream = await get_chat_manager().get_stream(chat_id)
if chat_stream is None:
group_name = f"聊天流 {chat_id}"
elif chat_stream.group_info:

View File

@@ -1,16 +1,22 @@
"""
兴趣度系统模块
提供机器人兴趣标签和智能匹配功能
提供机器人兴趣标签和智能匹配功能,以及消息兴趣值计算功能
"""
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from .interest_manager import InterestManager, get_interest_manager
__all__ = [
# 机器人兴趣标签管理
"BotInterestManager",
"BotInterestTag",
"BotPersonalityInterests",
"InterestMatchResult",
"bot_interest_manager",
# 消息兴趣值计算管理
"InterestManager",
"get_interest_manager",
]

View File

@@ -0,0 +1,223 @@
"""兴趣值计算组件管理器
管理兴趣值计算组件的生命周期,确保系统只能有一个兴趣值计算组件实例运行
"""
import asyncio
import time
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("interest_manager")
class InterestManager:
"""兴趣值计算组件管理器"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self._current_calculator: BaseInterestCalculator | None = None
self._calculator_lock = asyncio.Lock()
self._last_calculation_time = 0.0
self._total_calculations = 0
self._failed_calculations = 0
self._calculation_queue = asyncio.Queue()
self._worker_task = None
self._shutdown_event = asyncio.Event()
self._initialized = True
async def initialize(self):
"""初始化管理器"""
if self._worker_task is None:
self._worker_task = asyncio.create_task(self._calculation_worker())
logger.info("兴趣值管理器已启动")
async def shutdown(self):
"""关闭管理器"""
self._shutdown_event.set()
if self._worker_task:
self._worker_task.cancel()
try:
await self._worker_task
except asyncio.CancelledError:
pass
if self._current_calculator:
await self._current_calculator.cleanup()
self._current_calculator = None
logger.info("兴趣值管理器已关闭")
async def register_calculator(self, calculator: BaseInterestCalculator) -> bool:
"""注册兴趣值计算组件(系统只能有一个活跃的兴趣值计算器)
Args:
calculator: 兴趣值计算组件实例
Returns:
bool: 注册是否成功
"""
async with self._calculator_lock:
try:
# 检查是否已有相同的计算器
if self._current_calculator and self._current_calculator.component_name == calculator.component_name:
logger.warning(f"兴趣值计算组件 {calculator.component_name} 已经注册,跳过重复注册")
return True
# 如果已有组件在运行,先清理并替换
if self._current_calculator:
logger.info(f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}")
await self._current_calculator.cleanup()
else:
logger.info(f"注册新的兴趣值计算组件: {calculator.component_name}")
# 初始化新组件
if await calculator.initialize():
self._current_calculator = calculator
logger.info(f"兴趣值计算组件注册成功: {calculator.component_name} v{calculator.component_version}")
logger.info("系统现在只有一个活跃的兴趣值计算器")
return True
else:
logger.error(f"兴趣值计算组件初始化失败: {calculator.component_name}")
return False
except Exception as e:
logger.error(f"注册兴趣值计算组件失败: {e}", exc_info=True)
return False
async def calculate_interest(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""计算消息兴趣值
Args:
message: 数据库消息对象
Returns:
InterestCalculationResult: 计算结果
"""
if not self._current_calculator:
# 返回默认结果
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.3,
error_message="没有可用的兴趣值计算组件"
)
# 异步执行计算,避免阻塞
future = asyncio.create_task(self._async_calculate(message))
return await future
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""异步执行兴趣值计算"""
start_time = time.time()
self._total_calculations += 1
try:
# 使用组件的安全执行方法
result = await self._current_calculator._safe_execute(message)
if result.success:
self._last_calculation_time = time.time()
logger.debug(f"兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)")
else:
self._failed_calculations += 1
logger.warning(f"兴趣值计算失败: {result.error_message}")
return result
except Exception as e:
self._failed_calculations += 1
logger.error(f"兴趣值计算异常: {e}", exc_info=True)
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message=f"计算异常: {str(e)}",
calculation_time=time.time() - start_time
)
async def _calculation_worker(self):
"""计算工作线程(预留用于批量处理)"""
while not self._shutdown_event.is_set():
try:
# 等待计算任务或关闭信号
await asyncio.wait_for(
self._calculation_queue.get(),
timeout=1.0
)
# 处理计算任务
# 这里可以实现批量处理逻辑
except asyncio.TimeoutError:
# 超时继续循环
continue
except asyncio.CancelledError:
# 任务被取消,退出循环
break
except Exception as e:
logger.error(f"计算工作线程异常: {e}", exc_info=True)
def get_current_calculator(self) -> BaseInterestCalculator | None:
"""获取当前活跃的兴趣值计算组件"""
return self._current_calculator
def get_statistics(self) -> dict:
"""获取管理器统计信息"""
success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations))
stats = {
"manager_statistics": {
"total_calculations": self._total_calculations,
"failed_calculations": self._failed_calculations,
"success_rate": success_rate,
"last_calculation_time": self._last_calculation_time,
"current_calculator": self._current_calculator.component_name if self._current_calculator else None
}
}
# 添加当前组件的统计信息
if self._current_calculator:
stats["calculator_statistics"] = self._current_calculator.get_statistics()
return stats
async def health_check(self) -> bool:
"""健康检查"""
if not self._current_calculator:
return False
try:
# 检查组件是否还活跃
return self._current_calculator.is_enabled
except Exception:
return False
def has_calculator(self) -> bool:
"""检查是否有可用的计算组件"""
return self._current_calculator is not None and self._current_calculator.is_enabled
# 全局实例
_interest_manager = None
def get_interest_manager() -> InterestManager:
"""获取兴趣值管理器实例"""
global _interest_manager
if _interest_manager is None:
_interest_manager = InterestManager()
return _interest_manager

View File

@@ -0,0 +1,166 @@
"""消息兴趣值计算组件管理器
管理消息兴趣值计算组件,确保系统只能有一个兴趣值计算组件实例运行
"""
import asyncio
import time
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("message_interest_manager")
class MessageInterestManager:
"""消息兴趣值计算组件管理器"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self._current_calculator: BaseInterestCalculator | None = None
self._calculator_lock = asyncio.Lock()
self._last_calculation_time = 0.0
self._total_calculations = 0
self._failed_calculations = 0
self._initialized = True
async def initialize(self):
"""初始化管理器"""
logger.info("消息兴趣值管理器已初始化")
async def register_calculator(self, calculator: BaseInterestCalculator) -> bool:
"""注册兴趣值计算组件
Args:
calculator: 兴趣值计算组件实例
Returns:
bool: 注册是否成功
"""
async with self._calculator_lock:
try:
# 如果已有组件在运行,先清理
if self._current_calculator:
logger.info(f"替换现有消息兴趣值计算组件: {self._current_calculator.component_name}")
await self._current_calculator.cleanup()
# 初始化新组件
if await calculator.initialize():
self._current_calculator = calculator
logger.info(f"消息兴趣值计算组件注册成功: {calculator.component_name} v{calculator.component_version}")
return True
else:
logger.error(f"消息兴趣值计算组件初始化失败: {calculator.component_name}")
return False
except Exception as e:
logger.error(f"注册消息兴趣值计算组件失败: {e}", exc_info=True)
return False
async def calculate_interest(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""计算消息兴趣值
Args:
message: 数据库消息对象
Returns:
InterestCalculationResult: 计算结果
"""
if not self._current_calculator:
# 返回默认结果
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.3,
error_message="没有可用的消息兴趣值计算组件"
)
start_time = time.time()
self._total_calculations += 1
try:
# 使用组件的安全执行方法
result = await self._current_calculator._safe_execute(message)
if result.success:
self._last_calculation_time = time.time()
logger.debug(f"消息兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)")
else:
self._failed_calculations += 1
logger.warning(f"消息兴趣值计算失败: {result.error_message}")
return result
except Exception as e:
self._failed_calculations += 1
logger.error(f"消息兴趣值计算异常: {e}", exc_info=True)
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
interest_value=0.0,
error_message=f"计算异常: {str(e)}",
calculation_time=time.time() - start_time
)
def get_current_calculator(self) -> BaseInterestCalculator | None:
"""获取当前活跃的兴趣值计算组件"""
return self._current_calculator
def get_statistics(self) -> dict:
"""获取管理器统计信息"""
success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations))
stats = {
"manager_statistics": {
"total_calculations": self._total_calculations,
"failed_calculations": self._failed_calculations,
"success_rate": success_rate,
"last_calculation_time": self._last_calculation_time,
"current_calculator": self._current_calculator.component_name if self._current_calculator else None
}
}
# 添加当前组件的统计信息
if self._current_calculator:
stats["calculator_statistics"] = self._current_calculator.get_statistics()
return stats
async def health_check(self) -> bool:
"""健康检查"""
if not self._current_calculator:
return False
try:
# 检查组件是否还活跃
return self._current_calculator.is_enabled
except Exception:
return False
def has_calculator(self) -> bool:
"""检查是否有可用的计算组件"""
return self._current_calculator is not None and self._current_calculator.is_enabled
# 全局实例
_message_interest_manager = None
def get_message_interest_manager() -> MessageInterestManager:
"""获取消息兴趣值管理器实例"""
global _message_interest_manager
if _message_interest_manager is None:
_message_interest_manager = MessageInterestManager()
return _message_interest_manager

View File

@@ -995,7 +995,7 @@ class MemorySystem:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream or not hasattr(chat_stream, "context_manager"):
logger.debug(f"未找到stream_id={stream_id}的聊天流或上下文管理器")
@@ -1109,7 +1109,7 @@ class MemorySystem:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream, "context_manager"):
history_limit = self._determine_history_limit(context)
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)

View File

@@ -55,8 +55,8 @@ class SingleStreamContextManager:
"""
try:
self.context.add_message(message)
# 推迟兴趣度计算到分发阶段
message.interest_value = getattr(message, "interest_value", None)
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
@@ -228,51 +228,44 @@ class SingleStreamContextManager:
async def _calculate_message_interest(self, message: DatabaseMessages) -> float:
"""
异步计算消息的兴趣度
此方法通过检查当前是否存在正在运行的 asyncio 事件循环来兼容同步和异步调用。
在上下文管理器中计算消息的兴趣度
"""
# 内部异步函数,封装实际的计算逻辑
async def _get_score():
try:
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system,
)
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message, bot_nickname=global_config.bot.nickname
)
interest_value = interest_score.total_score
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
return interest_value
except ImportError as e:
logger.debug(f"兴趣度计算插件加载失败,可能未启用: {e}")
return 0.5
except Exception as e:
# 在某些情况下(例如机器人自己的消息),没有兴趣度是正常的
logger.info(f"插件内部兴趣度计算失败,使用默认值: {e}")
return 0.5
# 检查并获取当前事件循环
try:
loop = asyncio.get_running_loop()
except RuntimeError: # 'RuntimeError: There is no current event loop...'
loop = None
from src.chat.interest_system.interest_manager import get_interest_manager
if loop and loop.is_running():
# 如果事件循环正在运行,直接 await
return await _get_score()
else:
# 否则,使用 asyncio.run() 来安全执行
return asyncio.run(_get_score())
interest_manager = get_interest_manager()
if interest_manager.has_calculator():
# 使用兴趣值计算组件计算
result = await interest_manager.calculate_interest(message)
if result.success:
# 更新消息对象的兴趣值相关字段
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
logger.debug(f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}")
return result.interest_value
else:
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
return 0.5
else:
logger.debug("未找到兴趣值计算器,使用默认兴趣值")
return 0.5
except Exception as e:
logger.error(f"计算消息兴趣度时发生错误: {e}", exc_info=True)
return 0.5
async def add_message_async(self, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
"""异步实现的 add_message将消息添加到 context并 await 能量更新与分发。"""
try:
self.context.add_message(message)
# 推迟兴趣度计算到分发阶段
message.interest_value = getattr(message, "interest_value", None)
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
@@ -280,7 +273,7 @@ class SingleStreamContextManager:
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id}")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)

View File

@@ -181,7 +181,7 @@ class StreamLoopManager:
async def _fallback_acquire_slot(self, stream_id: str, force: bool) -> bool:
"""回退方案:获取槽位(原始方法)"""
# 判断是否需要强制分发
should_force = force or self._should_force_dispatch_for_stream(stream_id)
should_force = force or await self._should_force_dispatch_for_stream(stream_id)
# 检查是否超过最大并发限制
current_streams = len(self.stream_loops)
@@ -410,7 +410,7 @@ class StreamLoopManager:
"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
return chat_stream.context_manager.context
return None
@@ -538,13 +538,13 @@ class StreamLoopManager:
self.chatter_manager = chatter_manager
logger.info(f"设置chatter管理器: {chatter_manager.__class__.__name__}")
def _should_force_dispatch_for_stream(self, stream_id: str) -> bool:
async def _should_force_dispatch_for_stream(self, stream_id: str) -> bool:
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
return False
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
return False
@@ -595,7 +595,7 @@ class StreamLoopManager:
"""分发完成后基于历史消息刷新能量值"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
@@ -622,7 +622,7 @@ class StreamLoopManager:
except Exception as e:
logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}")
def _force_dispatch_stream(self, stream_id: str) -> None:
async def _force_dispatch_stream(self, stream_id: str) -> None:
"""强制分发流处理
当流的未读消息超过阈值时,强制触发分发处理
@@ -657,7 +657,7 @@ class StreamLoopManager:
# 获取聊天管理器和流
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"强制分发时未找到流: {stream_id}")
return

View File

@@ -132,7 +132,7 @@ class MessageManager:
"""添加消息到指定聊天流"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
@@ -153,7 +153,7 @@ class MessageManager:
"""更新消息信息"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.update_message: 聊天流 {stream_id} 不存在")
return
@@ -180,7 +180,7 @@ class MessageManager:
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
return 0
@@ -211,7 +211,7 @@ class MessageManager:
"""添加动作到消息"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.add_action: 聊天流 {stream_id} 不存在")
return
@@ -223,12 +223,12 @@ class MessageManager:
except Exception as e:
logger.error(f"为消息 {message_id} 添加动作时发生错误: {e}")
def deactivate_stream(self, stream_id: str):
async def deactivate_stream(self, stream_id: str):
"""停用聊天流"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
@@ -245,12 +245,12 @@ class MessageManager:
except Exception as e:
logger.error(f"停用聊天流 {stream_id} 时发生错误: {e}")
def activate_stream(self, stream_id: str):
async def activate_stream(self, stream_id: str):
"""激活聊天流"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
@@ -262,12 +262,12 @@ class MessageManager:
except Exception as e:
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
def get_stream_stats(self, stream_id: str) -> StreamStats | None:
async def get_stream_stats(self, stream_id: str) -> StreamStats | None:
"""获取聊天流统计"""
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
return None
@@ -360,7 +360,7 @@ class MessageManager:
pass
# 增加打断计数并应用afc阈值降低
chat_stream.context_manager.context.increment_interruption_count()
await chat_stream.context_manager.context.increment_interruption_count()
chat_stream.context_manager.context.apply_interruption_afc_reduction(
global_config.chat.interruption_afc_reduction
)
@@ -382,7 +382,7 @@ class MessageManager:
try:
# 通过 ChatManager 获取 ChatStream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"清除消息失败: 聊天流 {stream_id} 不存在")
return
@@ -411,7 +411,7 @@ class MessageManager:
"""清除指定聊天流的所有未读消息"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"clear_stream_unread_messages: 聊天流 {stream_id} 不存在")
return

View File

@@ -161,7 +161,7 @@ class ChatStream:
self.last_active_time = time.time()
self.saved = False
def set_context(self, message: "MessageRecv"):
async def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
@@ -234,6 +234,7 @@ class ChatStream:
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
should_act=getattr(message, "should_act", False),
)
self.stream_context.set_current_message(db_message)
@@ -280,6 +281,45 @@ class ChatStream:
logger.warning(f"获取actions字段失败: {e}")
return None
async def _calculate_message_interest(self, db_message):
"""计算消息兴趣值并更新消息对象"""
try:
from src.chat.interest_system.interest_manager import get_interest_manager
interest_manager = get_interest_manager()
if interest_manager.has_calculator():
# 使用兴趣值计算组件计算
result = await interest_manager.calculate_interest(db_message)
if result.success:
# 更新消息对象的兴趣值相关字段
db_message.interest_value = result.interest_value
db_message.should_reply = result.should_reply
db_message.should_act = result.should_act
logger.debug(f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}")
else:
logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
# 使用默认值
db_message.interest_value = 0.3
db_message.should_reply = False
db_message.should_act = False
else:
# 没有兴趣值计算组件,抛出异常
raise RuntimeError("没有可用的兴趣值计算组件")
except Exception as e:
logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
# 异常情况下使用默认值
if hasattr(db_message, 'interest_value'):
db_message.interest_value = 0.3
if hasattr(db_message, 'should_reply'):
db_message.should_reply = False
if hasattr(db_message, 'should_act'):
db_message.should_act = False
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
@@ -497,7 +537,9 @@ class ChatManager:
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
return self._convert_to_original_stream(optimized_stream)
original_stream = self._convert_to_original_stream(optimized_stream)
return original_stream
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
@@ -517,7 +559,7 @@ class ChatManager:
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
return stream
@@ -581,7 +623,7 @@ class ChatManager:
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
stream.set_context(self.last_messages[stream_id])
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
@@ -597,13 +639,13 @@ class ChatManager:
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> ChatStream | None:
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages:
stream.set_context(self.last_messages[stream_id])
await stream.set_context(self.last_messages[stream_id])
return stream
def get_stream_by_info(
@@ -613,9 +655,9 @@ class ChatManager:
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
def get_stream_name(self, stream_id: str) -> str | None:
async def get_stream_name(self, stream_id: str) -> str | None:
"""根据 stream_id 获取聊天流名称"""
stream = self.get_stream(stream_id)
stream = await self.get_stream(stream_id)
if not stream:
return None
@@ -813,8 +855,9 @@ class ChatManager:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
if stream.stream_id in self.last_messages:
stream.set_context(self.last_messages[stream.stream_id])
# 不在异步加载中设置上下文,避免复杂依赖
# if stream.stream_id in self.last_messages:
# await stream.set_context(self.last_messages[stream.stream_id])
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):

View File

@@ -165,7 +165,7 @@ class ChatterActionManager:
logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}")
# 通过chat_id获取chat_stream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
chat_stream = await chat_manager.get_stream(chat_id)
if not chat_stream:
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
@@ -322,13 +322,13 @@ class ChatterActionManager:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
context = chat_stream.context_manager
if context.context.interruption_count > 0:
old_count = context.context.interruption_count
old_afc_adjustment = context.context.get_afc_threshold_adjustment()
context.context.reset_interruption_count()
await context.context.reset_interruption_count()
logger.debug(
f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0"
)

View File

@@ -31,8 +31,9 @@ class ActionModifier:
def __init__(self, action_manager: ChatterActionManager, chat_id: str):
"""初始化动作处理器"""
self.chat_id = chat_id
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
# chat_stream 和 log_prefix 将在异步方法中初始化
self.chat_stream = None # type: ignore
self.log_prefix = f"[{chat_id}]"
self.action_manager = action_manager
@@ -43,6 +44,15 @@ class ActionModifier:
self._llm_judge_cache = {} # 缓存LLM判定结果
self._cache_expiry_time = 30 # 缓存过期时间(秒)
self._last_context_hash = None # 上次上下文的哈希值
self._log_prefix_initialized = False
async def _initialize_log_prefix(self):
"""异步初始化log_prefix和chat_stream"""
if not self._log_prefix_initialized:
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{stream_name or self.chat_id}]"
self._log_prefix_initialized = True
async def modify_actions(
self,
@@ -57,6 +67,9 @@ class ActionModifier:
处理后ActionManager 将包含最终的可用动作集,供规划器直接使用
"""
# 初始化log_prefix
await self._initialize_log_prefix()
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
removals_s1: list[tuple[str, str]] = []
@@ -72,7 +85,7 @@ class ActionModifier:
from src.plugin_system.core.component_registry import component_registry
# 获取聊天类型
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id)
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
chat_type_removals = []

View File

@@ -226,13 +226,21 @@ class DefaultReplyer:
):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat: Optional[bool] = None
self.chat_target_info: Optional[Dict[str, Any]] = None
self._initialized = False
# 这些将在异步初始化中设置
self.is_group_chat = False
self.chat_target_info = None
self._chat_info_initialized = False
self.heart_fc_sender = HeartFCSender()
# 使用新的增强记忆系统
# from src.chat.memory_system.enhanced_memory_activator import EnhancedMemoryActivator
self._chat_info_initialized = False
async def _initialize_chat_info(self):
"""异步初始化聊天信息"""
if not self._chat_info_initialized:
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id)
self._chat_info_initialized = True
# self.memory_activator = EnhancedMemoryActivator()
self.memory_activator = None # 暂时禁用记忆激活器
# 旧的即时记忆系统已被移除,现在使用增强记忆系统
@@ -280,7 +288,9 @@ class DefaultReplyer:
Returns:
Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: (是否成功, 生成的回复, 使用的prompt)
"""
# 初始化聊天信息
await self._initialize_chat_info()
prompt = None
selected_expressions = None
if available_actions is None:
@@ -825,7 +835,7 @@ class DefaultReplyer:
from src.plugin_system.apis.chat_api import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
chat_stream = await chat_manager.get_stream(chat_id)
if chat_stream:
stream_context = chat_stream.context_manager
# 使用真正的已读和未读消息
@@ -1015,47 +1025,24 @@ class DefaultReplyer:
return read_history_prompt, unread_history_prompt
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
"""为消息获取兴趣度评分"""
"""为消息获取兴趣度评分(使用预计算的兴趣值)"""
interest_scores = {}
try:
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system as interest_scoring_system,
)
# 转换消息格式
db_messages = []
# 直接使用消息中的预计算兴趣值
for msg_dict in messages:
try:
db_msg = DatabaseMessages(
message_id=msg_dict.get("message_id", ""),
time=msg_dict.get("time", time.time()),
chat_id=msg_dict.get("chat_id", ""),
processed_plain_text=msg_dict.get("processed_plain_text", ""),
user_id=msg_dict.get("user_id", ""),
user_nickname=msg_dict.get("user_nickname", ""),
user_platform=msg_dict.get("platform", "qq"),
chat_info_group_id=msg_dict.get("group_id", ""),
chat_info_group_name=msg_dict.get("group_name", ""),
chat_info_group_platform=msg_dict.get("platform", "qq"),
)
db_messages.append(db_msg)
except Exception as e:
logger.warning(f"转换消息格式失败: {e}")
continue
message_id = msg_dict.get("message_id", "")
interest_value = msg_dict.get("interest_value")
# 计算兴趣度评分
if db_messages:
bot_nickname = global_config.bot.nickname or "麦麦"
scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname)
# 构建兴趣度字典
for score in scores:
interest_scores[score.message_id] = score.total_score
if interest_value is not None:
interest_scores[message_id] = float(interest_value)
logger.debug(f"使用预计算兴趣度 - 消息 {message_id}: {interest_value:.3f}")
else:
interest_scores[message_id] = 0.5 # 默认值
logger.debug(f"消息 {message_id} 无预计算兴趣值,使用默认值 0.5")
except Exception as e:
logger.warning(f"获取兴趣度评分失败: {e}")
logger.warning(f"处理预计算兴趣值失败: {e}")
return interest_scores

View File

@@ -1047,11 +1047,11 @@ class Prompt:
from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
other_chat_raw_ids = await cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
chat_stream = await get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""

View File

@@ -623,7 +623,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
return time.strftime("%H:%M:%S", time.localtime(timestamp))
def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
async def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
"""
获取聊天类型(是否群聊)和私聊对象信息。
@@ -640,7 +640,7 @@ def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
chat_target_info = None
try:
if chat_stream := get_chat_manager().get_stream(chat_id):
if chat_stream := await get_chat_manager().get_stream(chat_id):
if chat_stream.group_info:
is_group_chat = True
chat_target_info = None # Explicitly None for group chat