From e19106b5b0c9da935de450b7f363a79b1c0e0118 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 16 Aug 2025 13:21:13 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=81=8A=E5=A4=A9=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E9=99=90=E5=88=B6=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=A0=B9=E6=8D=AE=E8=81=8A=E5=A4=A9=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E8=BF=87=E6=BB=A4=E5=91=BD=E4=BB=A4=E5=92=8C=E5=8A=A8=E4=BD=9C?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9E=E7=A7=81=E8=81=8A=E5=92=8C=E7=BE=A4?= =?UTF-8?q?=E8=81=8A=E4=B8=93=E7=94=A8=E5=91=BD=E4=BB=A4=E5=8F=8A=E5=8A=A8?= =?UTF-8?q?=E4=BD=9C=EF=BC=8C=E4=BC=98=E5=8C=96=E7=9B=B8=E5=85=B3=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E8=AE=B0=E5=BD=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/hello_world_plugin/plugin.py | 41 +++++++++++++++-- src/chat/message_receive/bot.py | 6 +++ src/chat/planner_actions/action_modifier.py | 32 ++++++++++++- src/chat/planner_actions/planner.py | 3 -- src/plugin_system/base/base_action.py | 36 ++++++++++++++- src/plugin_system/base/base_command.py | 45 ++++++++++++++++++- src/plugin_system/base/component_types.py | 14 ++++++ src/plugins/built_in/at_user_plugin/plugin.py | 2 + src/plugins/built_in/set_emoji_like/plugin.py | 2 + 9 files changed, 172 insertions(+), 9 deletions(-) diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index efb98a939..8b5f04950 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Type, Any +from typing import List, Tuple, Type, Any, Optional from src.plugin_system import ( BasePlugin, register_plugin, @@ -15,6 +15,7 @@ from src.plugin_system import ( from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.apis import send_api from src.common.logger import get_logger +from src.plugin_system.base.component_types import ChatType logger = get_logger(__name__) @@ -145,10 +146,11 @@ class TimeCommand(BaseCommand): """时间查询Command - 响应/time命令""" command_name = "time" - command_description = "查询当前时间" + command_description = "获取当前时间" # === 命令设置(必须填写)=== command_pattern = r"^/time$" # 精确匹配 "/time" 命令 + chat_type_allow = ChatType.GROUP # 仅在群聊中可用 async def execute(self) -> Tuple[bool, str, bool]: """执行时间查询""" @@ -221,8 +223,10 @@ class HelloWorldPlugin(BasePlugin): (HelloAction.get_action_info(), HelloAction), (CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具 (ByeAction.get_action_info(), ByeAction), # 添加告别Action - (TimeCommand.get_command_info(), TimeCommand), + (TimeCommand.get_command_info(), TimeCommand), # 现在只能在群聊中使用 (GetGroupListCommand.get_command_info(), GetGroupListCommand), # 添加获取群列表命令 + (PrivateInfoCommand.get_command_info(), PrivateInfoCommand), # 私聊专用命令 + (GroupOnlyAction.get_action_info(), GroupOnlyAction), # 群聊专用动作 # (PrintMessage.get_handler_info(), PrintMessage), ] @@ -247,3 +251,34 @@ class HelloWorldPlugin(BasePlugin): # def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: # return [(PrintMessage.get_handler_info(), PrintMessage)] + +# 添加一个新的私聊专用命令 +class PrivateInfoCommand(BaseCommand): + command_name = "private_info" + command_description = "获取私聊信息" + command_pattern = r"^/私聊信息$" + chat_type_allow = ChatType.PRIVATE # 仅在私聊中可用 + + async def execute(self) -> Tuple[bool, Optional[str], bool]: + """执行私聊信息命令""" + try: + await self.send_text("这是一个只能在私聊中使用的命令!") + return True, "私聊信息命令执行成功", False + except Exception as e: + logger.error(f"私聊信息命令执行失败: {e}") + return False, f"命令执行失败: {e}", False + +# 添加一个新的仅群聊可用的Action +class GroupOnlyAction(BaseAction): + action_name = "group_only_test" + action_description = "群聊专用测试动作" + chat_type_allow = ChatType.GROUP # 仅在群聊中可用 + + async def execute(self) -> Tuple[bool, str]: + """执行群聊专用测试动作""" + try: + await self.send_text("这是一个只能在群聊中执行的动作!") + return True, "群聊专用动作执行成功" + except Exception as e: + logger.error(f"群聊专用动作执行失败: {e}") + return False, f"动作执行失败: {e}" diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index fb3a84d53..dcc616bfb 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -113,6 +113,12 @@ class ChatBot: command_instance.set_matched_groups(matched_groups) try: + # 检查聊天类型限制 + if not command_instance.is_chat_type_allowed(): + is_group = hasattr(message, 'is_group_message') and message.is_group_message + logger.info(f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}") + return False, None, True # 跳过此命令,继续处理其他消息 + # 执行命令 success, response, intercept_message = await command_instance.execute() diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index dfa4c79c1..59b0d4f66 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -65,6 +65,36 @@ class ActionModifier: self.action_manager.restore_actions() all_actions = self.action_manager.get_using_actions() + # === 第0阶段:根据聊天类型过滤动作 === + from src.plugin_system.base.component_types import ChatType + from src.plugin_system.core.component_registry import component_registry + from src.plugin_system.base.component_types import ComponentType + from src.chat.utils.utils import get_chat_type_and_target_info + + # 获取聊天类型 + is_group_chat, _ = get_chat_type_and_target_info(self.chat_id) + all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION) + + chat_type_removals = [] + for action_name in list(all_actions.keys()): + if action_name in all_registered_actions: + action_info = all_registered_actions[action_name] + chat_type_allow = getattr(action_info, 'chat_type_allow', ChatType.ALL) + + # 检查是否符合聊天类型限制 + should_keep = (chat_type_allow == ChatType.ALL or + (chat_type_allow == ChatType.GROUP and is_group_chat) or + (chat_type_allow == ChatType.PRIVATE and not is_group_chat)) + + if not should_keep: + chat_type_removals.append((action_name, f"不支持{'群聊' if is_group_chat else '私聊'}")) + self.action_manager.remove_action_from_using(action_name) + + if chat_type_removals: + logger.info(f"{self.log_prefix} 第0阶段:根据聊天类型过滤 - 移除了 {len(chat_type_removals)} 个动作") + for action_name, reason in chat_type_removals: + logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}") + message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( chat_id=self.chat_stream.stream_id, timestamp=time.time(), @@ -122,7 +152,7 @@ class ActionModifier: logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") # === 统一日志记录 === - all_removals = removals_s1 + removals_s2 + removals_s3 + all_removals = chat_type_removals + removals_s1 + removals_s2 + removals_s3 removals_summary: str = "" if all_removals: removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals]) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 92e543461..5a90863d5 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -268,9 +268,6 @@ class ActionPlanner: timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.6), ) - - # 过滤掉bot自己的消息,避免planner把bot消息当作新消息处理 - message_list_before_now = filter_mai_messages(message_list_before_now) chat_content_block, message_id_list = build_readable_messages_with_id( messages=message_list_before_now, diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 66d723f5e..6021c61f4 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -6,7 +6,7 @@ from typing import Tuple, Optional from src.common.logger import get_logger from src.chat.message_receive.chat_stream import ChatStream -from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType +from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType, ChatType from src.plugin_system.apis import send_api, database_api, message_api @@ -91,6 +91,7 @@ class BaseAction(ABC): self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL) self.parallel_action: bool = getattr(self.__class__, "parallel_action", True) self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() + self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL) # ============================================================================= # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) @@ -146,6 +147,38 @@ class BaseAction(ABC): logger.debug( f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" ) + + # 验证聊天类型限制 + if not self._validate_chat_type(): + logger.warning( + f"{self.log_prefix} Action '{self.action_name}' 不支持当前聊天类型: " + f"{'群聊' if self.is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" + ) + + def _validate_chat_type(self) -> bool: + """验证当前聊天类型是否允许执行此Action + + Returns: + bool: 如果允许执行返回True,否则返回False + """ + if self.chat_type_allow == ChatType.ALL: + return True + elif self.chat_type_allow == ChatType.GROUP and self.is_group: + return True + elif self.chat_type_allow == ChatType.PRIVATE and not self.is_group: + return True + else: + return False + + def is_chat_type_allowed(self) -> bool: + """检查当前聊天类型是否允许执行此Action + + 这是一个公开的方法,供外部调用检查聊天类型限制 + + Returns: + bool: 如果允许执行返回True,否则返回False + """ + return self._validate_chat_type() async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]: """等待新消息或超时 @@ -389,6 +422,7 @@ class BaseAction(ABC): action_parameters=getattr(cls, "action_parameters", {}).copy(), action_require=getattr(cls, "action_require", []).copy(), associated_types=getattr(cls, "associated_types", []).copy(), + chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL), ) @abstractmethod diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 652acb4c4..a693cbd85 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, Tuple, Optional from src.common.logger import get_logger -from src.plugin_system.base.component_types import CommandInfo, ComponentType +from src.plugin_system.base.component_types import CommandInfo, ComponentType, ChatType from src.chat.message_receive.message import MessageRecv from src.plugin_system.apis import send_api @@ -26,6 +26,8 @@ class BaseCommand(ABC): # 默认命令设置 command_pattern: str = r"" """命令匹配的正则表达式""" + chat_type_allow: ChatType = ChatType.ALL + """允许的聊天类型,默认为所有类型""" def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): """初始化Command组件 @@ -40,7 +42,18 @@ class BaseCommand(ABC): self.log_prefix = "[Command]" + # 从类属性获取chat_type_allow设置 + self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL) + logger.debug(f"{self.log_prefix} Command组件初始化完成") + + # 验证聊天类型限制 + if not self._validate_chat_type(): + is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message + logger.warning( + f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: " + f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" + ) def set_matched_groups(self, groups: Dict[str, str]) -> None: """设置正则表达式匹配的命名组 @@ -50,6 +63,35 @@ class BaseCommand(ABC): """ self.matched_groups = groups + def _validate_chat_type(self) -> bool: + """验证当前聊天类型是否允许执行此Command + + Returns: + bool: 如果允许执行返回True,否则返回False + """ + if self.chat_type_allow == ChatType.ALL: + return True + + # 检查是否为群聊消息 + is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message + + if self.chat_type_allow == ChatType.GROUP and is_group: + return True + elif self.chat_type_allow == ChatType.PRIVATE and not is_group: + return True + else: + return False + + def is_chat_type_allowed(self) -> bool: + """检查当前聊天类型是否允许执行此Command + + 这是一个公开的方法,供外部调用检查聊天类型限制 + + Returns: + bool: 如果允许执行返回True,否则返回False + """ + return self._validate_chat_type() + @abstractmethod async def execute(self) -> Tuple[bool, Optional[str], bool]: """执行Command的抽象方法,子类必须实现 @@ -225,4 +267,5 @@ class BaseCommand(ABC): component_type=ComponentType.COMMAND, description=cls.command_description, command_pattern=cls.command_pattern, + chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL), ) diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 661a88ec4..5134b6a36 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -47,6 +47,18 @@ class ChatMode(Enum): return self.value +# 聊天类型枚举 +class ChatType(Enum): + """聊天类型枚举,用于限制插件在不同聊天环境中的使用""" + + GROUP = "group" # 仅群聊可用 + PRIVATE = "private" # 仅私聊可用 + ALL = "all" # 群聊和私聊都可用 + + def __str__(self): + return self.value + + # 事件类型枚举 class EventType(Enum): """ @@ -124,6 +136,7 @@ class ActionInfo(ComponentInfo): # 模式和并行设置 mode_enable: ChatMode = ChatMode.ALL parallel_action: bool = False + chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型 def __post_init__(self): super().__post_init__() @@ -143,6 +156,7 @@ class CommandInfo(ComponentInfo): """命令组件信息""" command_pattern: str = "" # 命令匹配模式(正则表达式) + chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型 def __post_init__(self): super().__post_init__() diff --git a/src/plugins/built_in/at_user_plugin/plugin.py b/src/plugins/built_in/at_user_plugin/plugin.py index 4a44ed035..9c2d40f31 100644 --- a/src/plugins/built_in/at_user_plugin/plugin.py +++ b/src/plugins/built_in/at_user_plugin/plugin.py @@ -11,6 +11,7 @@ from src.plugin_system import ( from src.person_info.person_info import get_person_info_manager from src.common.logger import get_logger from src.plugin_system import database_api +from src.plugin_system.base.component_types import ChatType logger = get_logger(__name__) class AtAction(BaseAction): @@ -21,6 +22,7 @@ class AtAction(BaseAction): action_description = "发送艾特消息" activation_type = ActionActivationType.LLM_JUDGE # 消息接收时激活(?) parallel_action = False + chat_type_allow = ChatType.GROUP # === 功能描述(必须填写)=== action_parameters = { diff --git a/src/plugins/built_in/set_emoji_like/plugin.py b/src/plugins/built_in/set_emoji_like/plugin.py index 0bb9e2ca2..3f1a76d0f 100644 --- a/src/plugins/built_in/set_emoji_like/plugin.py +++ b/src/plugins/built_in/set_emoji_like/plugin.py @@ -15,6 +15,7 @@ from src.common.database.sqlalchemy_models import Messages, PersonInfo from src.common.logger import get_logger from src.plugin_system.apis import send_api from .qq_emoji_list import qq_face +from src.plugin_system.base.component_types import ChatType logger = get_logger("set_emoji_like_plugin") @@ -50,6 +51,7 @@ class SetEmojiLikeAction(BaseAction): action_name = "set_emoji_like" action_description = "为消息设置表情回应/贴表情" activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?) + chat_type_allow = ChatType.GROUP parallel_action = True # === 功能描述(必须填写)===