diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 0c80ef83a..69453a3bc 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -17,7 +17,8 @@ from src.chat.planner_actions.action_manager import ActionManager from src.chat.chat_loop.hfc_utils import CycleDetail from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.person_info import get_person_info_manager -from src.plugin_system.base.component_types import ActionInfo, ChatMode +from src.plugin_system.base.component_types import ActionInfo, ChatMode, EventType +from src.plugin_system.core import events_manager from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.chat.willing.willing_manager import get_willing_manager from src.mais4u.mai_think import mai_thinking_manager @@ -304,7 +305,7 @@ class HeartFChatting: return loop_info, reply_text, cycle_timers - async def _observe(self, message_data: Optional[Dict[str, Any]] = None): + async def _observe(self, message_data: Optional[Dict[str, Any]] = None) -> bool: # sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else if not message_data: message_data = {} @@ -379,6 +380,13 @@ class HeartFChatting: ) if not skip_planner: + planner_info = self.action_planner.get_necessary_info() + prompt_info = await self.action_planner.build_planner_prompt( + is_group_chat=planner_info[0], + chat_target_info=planner_info[1], + current_available_actions=planner_info[2], + ) + await events_manager.handle_mai_events(EventType.ON_PLAN, None, prompt_info[0], None) with Timer("规划器", cycle_timers): plan_result, target_message = await self.action_planner.plan(mode=self.loop_mode) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index b01bb824c..e1bb42ec7 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -71,7 +71,9 @@ class ActionPlanner: self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" self.action_manager = action_manager # LLM规划器配置 - self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") # 用于动作规划 + self.planner_llm = LLMRequest( + model_set=model_config.model_task_config.planner, request_type="planner" + ) # 用于动作规划 self.last_obs_time_mark = 0.0 # 添加重试计数器 @@ -126,22 +128,7 @@ class ActionPlanner: message_id_list: list = [] try: - is_group_chat = True - is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id) - logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}") - - current_available_actions_dict = self.action_manager.get_using_actions() - - # 获取完整的动作信息 - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore - ComponentType.ACTION - ) - current_available_actions = {} - for action_name in current_available_actions_dict: - if action_name in all_registered_actions: - current_available_actions[action_name] = all_registered_actions[action_name] - else: - logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到") + is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info() # --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- prompt, message_id_list = await self.build_planner_prompt( @@ -396,5 +383,28 @@ class ActionPlanner: logger.error(traceback.format_exc()) return "构建 Planner Prompt 时出错", [] + def get_necessary_info(self) -> Tuple[bool, Optional[dict], Dict[str, ActionInfo]]: + """ + 获取 Planner 需要的必要信息 + """ + is_group_chat = True + is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id) + logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}") + + current_available_actions_dict = self.action_manager.get_using_actions() + + # 获取完整的动作信息 + all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + ComponentType.ACTION + ) + current_available_actions = {} + for action_name in current_available_actions_dict: + if action_name in all_registered_actions: + current_available_actions[action_name] = all_registered_actions[action_name] + else: + logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到") + + return is_group_chat, chat_target_info, current_available_actions + init_prompt() diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index 3c215a7ff..da1d81c28 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -3,6 +3,7 @@ import contextlib from typing import List, Dict, Optional, Type, Tuple from src.chat.message_receive.message import MessageRecv +from src.chat.message_receive.chat_stream import chat_manager from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages from src.plugin_system.base.base_events_handler import BaseEventHandler @@ -44,18 +45,24 @@ class EventsManager: async def handle_mai_events( self, event_type: EventType, - message: MessageRecv, + message: Optional[MessageRecv] = None, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None, + stream_id: Optional[str] = None, ) -> bool: """处理 events""" from src.plugin_system.core import component_registry continue_flag = True - transformed_message = self._transform_event_message(message, llm_prompt, llm_response) + transformed_message: Optional[MaiMessages] = None + if not message: + assert stream_id, "如果没有消息,必须提供流ID" + transformed_message = self._build_message_from_stream(stream_id, llm_prompt, llm_response) + else: + transformed_message = self._transform_event_message(message, llm_prompt, llm_response) for handler in self._events_subscribers.get(event_type, []): - if message.chat_stream and message.chat_stream.stream_id: - stream_id = message.chat_stream.stream_id + if transformed_message.stream_id: + stream_id = transformed_message.stream_id if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id): continue handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {}) @@ -163,6 +170,15 @@ class EventsManager: return transformed_message + def _build_message_from_stream( + self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[str] = None + ) -> MaiMessages: + """从流ID构建消息""" + chat_stream = chat_manager.get_stream(stream_id) + assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流" + message = chat_stream.context.get_last_message() + return self._transform_event_message(message, llm_prompt, llm_response) + def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]): """任务完成回调""" task_name = task.get_name() or "Unknown Task"