diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index d8eda9baa..cf9d3b039 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -1,6 +1,7 @@ import time from typing import Any +import asyncio from src.chat.planner_actions.action_manager import ChatterActionManager from src.common.data_models.message_manager_data_model import StreamContext from src.common.logger import get_logger @@ -15,6 +16,7 @@ class ChatterManager: self.action_manager = action_manager self.chatter_classes: dict[ChatType, list[type]] = {} self.instances: dict[str, BaseChatter] = {} + self._processing_tasks: dict[str, asyncio.Task] = {} # 管理器统计 self.stats = { @@ -155,3 +157,11 @@ class ChatterManager: "successful_executions": 0, "failed_executions": 0, } + + def set_processing_task(self, stream_id: str, task: asyncio.Task): + """设置流的处理任务""" + self._processing_tasks[stream_id] = task + + def get_processing_task(self, stream_id: str) -> asyncio.Task | None: + """获取流的处理任务""" + return self._processing_tasks.get(stream_id) diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index fa6bcea0d..516691bae 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -301,6 +301,11 @@ class StreamLoopManager: except asyncio.CancelledError: logger.info(f"流循环被取消: {stream_id}") + if self.chatter_manager: + task = self.chatter_manager.get_processing_task(stream_id) + if task and not task.done(): + task.cancel() + logger.debug(f"已取消 chatter 处理任务: {stream_id}") break except Exception as e: logger.error(f"流循环出错 {stream_id}: {e}", exc_info=True) @@ -388,8 +393,9 @@ class StreamLoopManager: start_time = time.time() # 直接调用chatter_manager处理流上下文 - context.processing_task = asyncio.create_task(self.chatter_manager.process_stream_context(stream_id, context)) - results = await context.processing_task + task = asyncio.create_task(self.chatter_manager.process_stream_context(stream_id, context)) + self.chatter_manager.set_processing_task(stream_id, task) + results = await task success = results.get("success", False) if success: