refactor(chat): 优化任务管理机制支持多重回复
重构聊天管理器的任务处理系统,将单一任务追踪改为支持多重回复的任务列表管理。 主要变更: - 将 `_processing_tasks` 从单任务字典改为任务列表字典 - 新增 `add_processing_task` 和 `get_all_processing_tasks` 方法 - 增强 `cancel_all_stream_tasks` 方法支持批量取消 - 修复消息打断机制,确保取消所有相关任务 - 优化任务清理逻辑,自动移除已完成任务 这些改进使系统能够更好地处理并发回复场景,提高任务管理的灵活性和可靠性。
This commit is contained in:
@@ -16,7 +16,8 @@ class ChatterManager:
|
|||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
self.chatter_classes: dict[ChatType, list[type]] = {}
|
self.chatter_classes: dict[ChatType, list[type]] = {}
|
||||||
self.instances: dict[str, BaseChatter] = {}
|
self.instances: dict[str, BaseChatter] = {}
|
||||||
self._processing_tasks: dict[str, asyncio.Task] = {}
|
# 🌟 优化:统一任务追踪,支持多重回复
|
||||||
|
self._processing_tasks: dict[str, list[asyncio.Task]] = {}
|
||||||
|
|
||||||
# 管理器统计
|
# 管理器统计
|
||||||
self.stats = {
|
self.stats = {
|
||||||
@@ -174,15 +175,71 @@ class ChatterManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def set_processing_task(self, stream_id: str, task: asyncio.Task):
|
def set_processing_task(self, stream_id: str, task: asyncio.Task):
|
||||||
"""设置流的处理任务"""
|
"""设置流的主要处理任务"""
|
||||||
self._processing_tasks[stream_id] = task
|
if stream_id not in self._processing_tasks:
|
||||||
|
self._processing_tasks[stream_id] = []
|
||||||
|
self._processing_tasks[stream_id].insert(0, task) # 主要任务放在第一位
|
||||||
|
logger.debug(f"设置流 {stream_id} 的主要处理任务")
|
||||||
|
|
||||||
def get_processing_task(self, stream_id: str) -> asyncio.Task | None:
|
def get_processing_task(self, stream_id: str) -> asyncio.Task | None:
|
||||||
"""获取流的处理任务"""
|
"""获取流的主要处理任务"""
|
||||||
return self._processing_tasks.get(stream_id)
|
tasks = self._processing_tasks.get(stream_id, [])
|
||||||
|
return tasks[0] if tasks and not tasks[0].done() else None
|
||||||
|
|
||||||
|
def add_processing_task(self, stream_id: str, task: asyncio.Task):
|
||||||
|
"""添加处理任务到流(支持多重回复)"""
|
||||||
|
if stream_id not in self._processing_tasks:
|
||||||
|
self._processing_tasks[stream_id] = []
|
||||||
|
self._processing_tasks[stream_id].append(task)
|
||||||
|
logger.debug(f"添加处理任务到流 {stream_id},当前任务数: {len(self._processing_tasks[stream_id])}")
|
||||||
|
|
||||||
|
def get_all_processing_tasks(self, stream_id: str) -> list[asyncio.Task]:
|
||||||
|
"""获取流的所有活跃处理任务"""
|
||||||
|
if stream_id not in self._processing_tasks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 清理已完成的任务并返回活跃任务
|
||||||
|
active_tasks = [task for task in self._processing_tasks[stream_id] if not task.done()]
|
||||||
|
self._processing_tasks[stream_id] = active_tasks
|
||||||
|
|
||||||
|
if len(active_tasks) == 0:
|
||||||
|
del self._processing_tasks[stream_id]
|
||||||
|
|
||||||
|
return active_tasks
|
||||||
|
|
||||||
|
def cancel_all_stream_tasks(self, stream_id: str) -> int:
|
||||||
|
"""取消指定流的所有处理任务(包括多重回复)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 成功取消的任务数量
|
||||||
|
"""
|
||||||
|
if stream_id not in self._processing_tasks:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
tasks = self._processing_tasks[stream_id]
|
||||||
|
cancelled_count = 0
|
||||||
|
|
||||||
|
logger.info(f"开始取消流 {stream_id} 的所有处理任务,共 {len(tasks)} 个")
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
try:
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
cancelled_count += 1
|
||||||
|
logger.debug(f"成功取消任务 {task.get_name() if hasattr(task, 'get_name') else 'unnamed'}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"取消任务时出错: {e}")
|
||||||
|
|
||||||
|
# 清理任务记录
|
||||||
|
del self._processing_tasks[stream_id]
|
||||||
|
logger.info(f"流 {stream_id} 的任务取消完成,成功取消 {cancelled_count} 个任务")
|
||||||
|
return cancelled_count
|
||||||
|
|
||||||
def cancel_processing_task(self, stream_id: str) -> bool:
|
def cancel_processing_task(self, stream_id: str) -> bool:
|
||||||
"""取消流的处理任务
|
"""取消流的主要处理任务
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_id: 流ID
|
stream_id: 流ID
|
||||||
@@ -190,14 +247,14 @@ class ChatterManager:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否成功取消了任务
|
bool: 是否成功取消了任务
|
||||||
"""
|
"""
|
||||||
task = self._processing_tasks.get(stream_id)
|
main_task = self.get_processing_task(stream_id)
|
||||||
if task and not task.done():
|
if main_task and not main_task.done():
|
||||||
try:
|
try:
|
||||||
task.cancel()
|
main_task.cancel()
|
||||||
logger.info(f"已取消流 {stream_id} 的处理任务")
|
logger.info(f"已取消流 {stream_id} 的主要处理任务")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"取消流 {stream_id} 的处理任务时出错: {e}")
|
logger.warning(f"取消流 {stream_id} 的主要处理任务时出错: {e}")
|
||||||
return False
|
return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -209,22 +266,30 @@ class ChatterManager:
|
|||||||
"""
|
"""
|
||||||
if stream_id in self._processing_tasks:
|
if stream_id in self._processing_tasks:
|
||||||
del self._processing_tasks[stream_id]
|
del self._processing_tasks[stream_id]
|
||||||
logger.debug(f"已移除流 {stream_id} 的处理任务记录")
|
logger.debug(f"已移除流 {stream_id} 的所有处理任务记录")
|
||||||
|
|
||||||
def get_active_processing_tasks(self) -> dict[str, asyncio.Task]:
|
def get_active_processing_tasks(self) -> dict[str, asyncio.Task]:
|
||||||
"""获取所有活跃的处理任务
|
"""获取所有活跃的主要处理任务
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, asyncio.Task]: 流ID到处理任务的映射
|
Dict[str, asyncio.Task]: 流ID到主要处理任务的映射
|
||||||
"""
|
"""
|
||||||
# 过滤掉已完成的任务
|
# 过滤掉已完成的任务,只返回主要任务
|
||||||
active_tasks = {}
|
active_tasks = {}
|
||||||
for stream_id, task in self._processing_tasks.items():
|
for stream_id, task_list in list(self._processing_tasks.items()):
|
||||||
if not task.done():
|
if task_list:
|
||||||
active_tasks[stream_id] = task
|
main_task = task_list[0] # 获取主要任务
|
||||||
else:
|
if not main_task.done():
|
||||||
logger.debug(f"清理已完成的处理任务: {stream_id}")
|
active_tasks[stream_id] = main_task
|
||||||
del self._processing_tasks[stream_id]
|
else:
|
||||||
|
# 清理已完成的主要任务
|
||||||
|
task_list = [t for t in task_list if not t.done()]
|
||||||
|
if task_list:
|
||||||
|
self._processing_tasks[stream_id] = task_list
|
||||||
|
active_tasks[stream_id] = task_list[0] # 新的主要任务
|
||||||
|
else:
|
||||||
|
del self._processing_tasks[stream_id]
|
||||||
|
logger.debug(f"清理已完成的处理任务: {stream_id}")
|
||||||
|
|
||||||
return active_tasks
|
return active_tasks
|
||||||
|
|
||||||
|
|||||||
@@ -348,14 +348,14 @@ class MessageManager:
|
|||||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||||
|
|
||||||
async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None):
|
async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None):
|
||||||
"""检查并处理消息打断"""
|
"""检查并处理消息打断 - 支持多重回复任务取消"""
|
||||||
if not global_config.chat.interruption_enabled or not chat_stream:
|
if not global_config.chat.interruption_enabled or not chat_stream:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 从 chatter_manager 检查是否有正在进行的处理任务
|
# 🌟 修复:获取所有处理任务(包括多重回复)
|
||||||
processing_task = self.chatter_manager.get_processing_task(chat_stream.stream_id)
|
all_processing_tasks = self.chatter_manager.get_all_processing_tasks(chat_stream.stream_id)
|
||||||
|
|
||||||
if processing_task and not processing_task.done():
|
if all_processing_tasks:
|
||||||
# 计算打断概率 - 使用新的线性概率模型
|
# 计算打断概率 - 使用新的线性概率模型
|
||||||
interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability(
|
interruption_probability = chat_stream.context_manager.context.calculate_interruption_probability(
|
||||||
global_config.chat.interruption_max_limit
|
global_config.chat.interruption_max_limit
|
||||||
@@ -370,14 +370,15 @@ class MessageManager:
|
|||||||
|
|
||||||
# 根据概率决定是否打断
|
# 根据概率决定是否打断
|
||||||
if random.random() < interruption_probability:
|
if random.random() < interruption_probability:
|
||||||
logger.info(f"聊天流 {chat_stream.stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
|
logger.info(f"聊天流 {chat_stream.stream_id} 触发消息打断,打断概率: {interruption_probability:.2f},检测到 {len(all_processing_tasks)} 个任务")
|
||||||
|
|
||||||
# 取消现有任务
|
# 🌟 修复:取消所有任务(包括多重回复)
|
||||||
processing_task.cancel()
|
cancelled_count = self.chatter_manager.cancel_all_stream_tasks(chat_stream.stream_id)
|
||||||
try:
|
|
||||||
await processing_task
|
if cancelled_count > 0:
|
||||||
except asyncio.CancelledError:
|
logger.info(f"消息打断成功取消 {cancelled_count} 个任务: {chat_stream.stream_id}")
|
||||||
logger.debug(f"消息打断成功取消任务: {chat_stream.stream_id}")
|
else:
|
||||||
|
logger.warning(f"消息打断未能取消任何任务: {chat_stream.stream_id}")
|
||||||
|
|
||||||
# 增加打断计数并应用afc阈值降低
|
# 增加打断计数并应用afc阈值降低
|
||||||
await chat_stream.context_manager.context.increment_interruption_count()
|
await chat_stream.context_manager.context.increment_interruption_count()
|
||||||
@@ -395,7 +396,7 @@ class MessageManager:
|
|||||||
f"聊天流 {chat_stream.stream_id} 已打断,当前打断次数: {chat_stream.context_manager.context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {chat_stream.context_manager.context.get_afc_threshold_adjustment()}"
|
f"聊天流 {chat_stream.stream_id} 已打断,当前打断次数: {chat_stream.context_manager.context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {chat_stream.context_manager.context.get_afc_threshold_adjustment()}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"聊天流 {chat_stream.stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
logger.debug(f"聊天流 {chat_stream.stream_id} 未触发打断,打断概率: {interruption_probability:.2f},检测到 {len(all_processing_tasks)} 个任务")
|
||||||
|
|
||||||
async def clear_all_unread_messages(self, stream_id: str):
|
async def clear_all_unread_messages(self, stream_id: str):
|
||||||
"""清除指定上下文中的所有未读消息,在消息处理完成后调用"""
|
"""清除指定上下文中的所有未读消息,在消息处理完成后调用"""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -51,7 +51,7 @@ class ChatterPlanExecutor:
|
|||||||
"""设置关系追踪器"""
|
"""设置关系追踪器"""
|
||||||
self.relationship_tracker = relationship_tracker
|
self.relationship_tracker = relationship_tracker
|
||||||
|
|
||||||
async def execute(self, plan: Plan) -> dict[str, any]:
|
async def execute(self, plan: Plan) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||||
|
|
||||||
@@ -119,7 +119,7 @@ class ChatterPlanExecutor:
|
|||||||
"results": execution_results,
|
"results": execution_results,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
|
async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, Any]:
|
||||||
"""串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复"""
|
"""串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复"""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -171,7 +171,7 @@ class ChatterPlanExecutor:
|
|||||||
|
|
||||||
async def _execute_single_reply_action(
|
async def _execute_single_reply_action(
|
||||||
self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True
|
self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True
|
||||||
) -> dict[str, any]:
|
) -> dict[str, Any]:
|
||||||
"""执行单个回复动作"""
|
"""执行单个回复动作"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
success = False
|
success = False
|
||||||
@@ -249,7 +249,7 @@ class ChatterPlanExecutor:
|
|||||||
else reply_content,
|
else reply_content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
|
async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, Any]:
|
||||||
"""执行其他动作"""
|
"""执行其他动作"""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -278,7 +278,7 @@ class ChatterPlanExecutor:
|
|||||||
|
|
||||||
return {"results": results}
|
return {"results": results}
|
||||||
|
|
||||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, any]:
|
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, Any]:
|
||||||
"""执行单个其他动作"""
|
"""执行单个其他动作"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
success = False
|
success = False
|
||||||
@@ -387,7 +387,7 @@ class ChatterPlanExecutor:
|
|||||||
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||||
logger.debug(f"action_message内容: {action_info.action_message}")
|
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||||
|
|
||||||
def get_execution_stats(self) -> dict[str, any]:
|
def get_execution_stats(self) -> dict[str, Any]:
|
||||||
"""获取执行统计信息"""
|
"""获取执行统计信息"""
|
||||||
stats = self.execution_stats.copy()
|
stats = self.execution_stats.copy()
|
||||||
|
|
||||||
@@ -418,7 +418,7 @@ class ChatterPlanExecutor:
|
|||||||
"execution_times": [],
|
"execution_times": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_recent_performance(self, limit: int = 10) -> list[dict[str, any]]:
|
def get_recent_performance(self, limit: int = 10) -> list[dict[str, Any]]:
|
||||||
"""获取最近的执行性能"""
|
"""获取最近的执行性能"""
|
||||||
recent_times = self.execution_stats["execution_times"][-limit:]
|
recent_times = self.execution_stats["execution_times"][-limit:]
|
||||||
if not recent_times:
|
if not recent_times:
|
||||||
|
|||||||
Reference in New Issue
Block a user