diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index a1a014949..7c18cad7d 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -99,6 +99,21 @@ class MessageStorage: # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段 priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None + # 准备additional_config,包含format_info和其他配置 + additional_config_data = {} + + # 保存format_info到additional_config中 + if hasattr(message.message_info, 'format_info') and message.message_info.format_info: + format_info_dict = message.message_info.format_info.to_dict() + additional_config_data["format_info"] = format_info_dict + + # 合并adapter传递的其他additional_config + if hasattr(message.message_info, 'additional_config') and message.message_info.additional_config: + additional_config_data.update(message.message_info.additional_config) + + # 序列化为JSON字符串以便存储 + additional_config_json = orjson.dumps(additional_config_data).decode("utf-8") if additional_config_data else None + # 获取数据库会话 new_message = Messages( @@ -127,6 +142,11 @@ class MessageStorage: priority_info=priority_info_json, is_emoji=is_emoji, is_picid=is_picid, + is_notify=is_notify, + is_command=is_command, + key_words=key_words, + key_words_lite=key_words_lite, + additional_config=additional_config_json, ) async with get_db_session() as session: session.add(new_message) diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 35a17d675..69fc902de 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -4,6 +4,8 @@ import random import time from typing import TYPE_CHECKING, Any +import orjson + from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat @@ -182,13 +184,98 @@ class ActionModifier: def _check_action_associated_types(self, all_actions: dict[str, ActionInfo], chat_context: StreamContext): type_mismatched_actions: list[tuple[str, str]] = [] for action_name, action_info in all_actions.items(): - if action_info.associated_types and not chat_context.check_types(action_info.associated_types): + if action_info.associated_types and not self._check_action_output_types(action_info.associated_types, chat_context): associated_types_str = ", ".join(action_info.associated_types) reason = f"适配器不支持(需要: {associated_types_str})" type_mismatched_actions.append((action_name, reason)) logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}") return type_mismatched_actions + def _check_action_output_types(self, output_types: list[str], chat_context: StreamContext) -> bool: + """ + 检查Action的输出类型是否被当前适配器支持 + + Args: + output_types: Action需要输出的消息类型列表 + chat_context: 聊天上下文 + + Returns: + bool: 如果所有输出类型都支持则返回True + """ + # 获取当前适配器支持的输出类型 + adapter_supported_types = self._get_adapter_supported_output_types(chat_context) + + # 检查所有需要的输出类型是否都被支持 + for output_type in output_types: + if output_type not in adapter_supported_types: + logger.debug(f"适配器不支持输出类型 '{output_type}',支持的类型: {adapter_supported_types}") + return False + return True + + def _get_adapter_supported_output_types(self, chat_context: StreamContext) -> list[str]: + """ + 获取当前适配器支持的输出类型列表 + + Args: + chat_context: 聊天上下文 + + Returns: + list[str]: 支持的输出类型列表 + """ + # 检查additional_config是否存在且不为空 + if (chat_context.current_message + and hasattr(chat_context.current_message, "additional_config") + and chat_context.current_message.additional_config): + + try: + additional_config = chat_context.current_message.additional_config + format_info = None + + # 处理additional_config可能是字符串或字典的情况 + if isinstance(additional_config, str): + # 如果是字符串,尝试解析为JSON + try: + config = orjson.loads(additional_config) + format_info = config.get("format_info") + except (orjson.JSONDecodeError, AttributeError, TypeError): + logger.debug("无法解析additional_config JSON字符串") + format_info = None + + elif isinstance(additional_config, dict): + # 如果是字典,直接获取format_info + format_info = additional_config.get("format_info") + + # 如果找到了format_info,从中提取支持的类型 + if format_info: + # 优先检查accept_format字段 + if "accept_format" in format_info: + accept_format = format_info["accept_format"] + if isinstance(accept_format, str): + accept_format = [accept_format] + elif isinstance(accept_format, list): + pass + else: + accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else [] + + # 合并基础类型和适配器特定类型 + return list(set(accept_format)) + + # 备用检查content_format字段 + elif "content_format" in format_info: + content_format = format_info["content_format"] + if isinstance(content_format, str): + content_format = [content_format] + elif isinstance(content_format, list): + pass + else: + content_format = list(content_format) if hasattr(content_format, "__iter__") else [] + + return list(set(content_format)) + + except Exception as e: + logger.debug(f"解析适配器格式信息失败,使用默认支持类型: {e}") + + async def _get_deactivated_actions_by_type( self, actions_with_info: dict[str, ActionInfo],