添加聊天类型限制功能,支持根据聊天类型过滤命令和动作,新增私聊和群聊专用命令及动作,优化相关日志记录。
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple, Type, Any
|
||||
from typing import List, Tuple, Type, Any, Optional
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
@@ -15,6 +15,7 @@ from src.plugin_system import (
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -145,10 +146,11 @@ class TimeCommand(BaseCommand):
|
||||
"""时间查询Command - 响应/time命令"""
|
||||
|
||||
command_name = "time"
|
||||
command_description = "查询当前时间"
|
||||
command_description = "获取当前时间"
|
||||
|
||||
# === 命令设置(必须填写)===
|
||||
command_pattern = r"^/time$" # 精确匹配 "/time" 命令
|
||||
chat_type_allow = ChatType.GROUP # 仅在群聊中可用
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
"""执行时间查询"""
|
||||
@@ -221,8 +223,10 @@ class HelloWorldPlugin(BasePlugin):
|
||||
(HelloAction.get_action_info(), HelloAction),
|
||||
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
|
||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||
(TimeCommand.get_command_info(), TimeCommand),
|
||||
(TimeCommand.get_command_info(), TimeCommand), # 现在只能在群聊中使用
|
||||
(GetGroupListCommand.get_command_info(), GetGroupListCommand), # 添加获取群列表命令
|
||||
(PrivateInfoCommand.get_command_info(), PrivateInfoCommand), # 私聊专用命令
|
||||
(GroupOnlyAction.get_action_info(), GroupOnlyAction), # 群聊专用动作
|
||||
# (PrintMessage.get_handler_info(), PrintMessage),
|
||||
]
|
||||
|
||||
@@ -247,3 +251,34 @@ class HelloWorldPlugin(BasePlugin):
|
||||
|
||||
# def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
# return [(PrintMessage.get_handler_info(), PrintMessage)]
|
||||
|
||||
# 添加一个新的私聊专用命令
|
||||
class PrivateInfoCommand(BaseCommand):
|
||||
command_name = "private_info"
|
||||
command_description = "获取私聊信息"
|
||||
command_pattern = r"^/私聊信息$"
|
||||
chat_type_allow = ChatType.PRIVATE # 仅在私聊中可用
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行私聊信息命令"""
|
||||
try:
|
||||
await self.send_text("这是一个只能在私聊中使用的命令!")
|
||||
return True, "私聊信息命令执行成功", False
|
||||
except Exception as e:
|
||||
logger.error(f"私聊信息命令执行失败: {e}")
|
||||
return False, f"命令执行失败: {e}", False
|
||||
|
||||
# 添加一个新的仅群聊可用的Action
|
||||
class GroupOnlyAction(BaseAction):
|
||||
action_name = "group_only_test"
|
||||
action_description = "群聊专用测试动作"
|
||||
chat_type_allow = ChatType.GROUP # 仅在群聊中可用
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行群聊专用测试动作"""
|
||||
try:
|
||||
await self.send_text("这是一个只能在群聊中执行的动作!")
|
||||
return True, "群聊专用动作执行成功"
|
||||
except Exception as e:
|
||||
logger.error(f"群聊专用动作执行失败: {e}")
|
||||
return False, f"动作执行失败: {e}"
|
||||
|
||||
@@ -113,6 +113,12 @@ class ChatBot:
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
try:
|
||||
# 检查聊天类型限制
|
||||
if not command_instance.is_chat_type_allowed():
|
||||
is_group = hasattr(message, 'is_group_message') and message.is_group_message
|
||||
logger.info(f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}")
|
||||
return False, None, True # 跳过此命令,继续处理其他消息
|
||||
|
||||
# 执行命令
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
|
||||
|
||||
@@ -65,6 +65,36 @@ class ActionModifier:
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# === 第0阶段:根据聊天类型过滤动作 ===
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
|
||||
# 获取聊天类型
|
||||
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
|
||||
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
|
||||
chat_type_removals = []
|
||||
for action_name in list(all_actions.keys()):
|
||||
if action_name in all_registered_actions:
|
||||
action_info = all_registered_actions[action_name]
|
||||
chat_type_allow = getattr(action_info, 'chat_type_allow', ChatType.ALL)
|
||||
|
||||
# 检查是否符合聊天类型限制
|
||||
should_keep = (chat_type_allow == ChatType.ALL or
|
||||
(chat_type_allow == ChatType.GROUP and is_group_chat) or
|
||||
(chat_type_allow == ChatType.PRIVATE and not is_group_chat))
|
||||
|
||||
if not should_keep:
|
||||
chat_type_removals.append((action_name, f"不支持{'群聊' if is_group_chat else '私聊'}"))
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
|
||||
if chat_type_removals:
|
||||
logger.info(f"{self.log_prefix} 第0阶段:根据聊天类型过滤 - 移除了 {len(chat_type_removals)} 个动作")
|
||||
for action_name, reason in chat_type_removals:
|
||||
logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
@@ -122,7 +152,7 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2 + removals_s3
|
||||
all_removals = chat_type_removals + removals_s1 + removals_s2 + removals_s3
|
||||
removals_summary: str = ""
|
||||
if all_removals:
|
||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||
|
||||
@@ -269,9 +269,6 @@ class ActionPlanner:
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
|
||||
# 过滤掉bot自己的消息,避免planner把bot消息当作新消息处理
|
||||
message_list_before_now = filter_mai_messages(message_list_before_now)
|
||||
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal",
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType, ChatType
|
||||
from src.plugin_system.apis import send_api, database_api, message_api
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class BaseAction(ABC):
|
||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
@@ -147,6 +148,38 @@ class BaseAction(ABC):
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
logger.warning(
|
||||
f"{self.log_prefix} Action '{self.action_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if self.is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
)
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此Action
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.GROUP and self.is_group:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.PRIVATE and not self.is_group:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此Action
|
||||
|
||||
这是一个公开的方法,供外部调用检查聊天类型限制
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
return self._validate_chat_type()
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
@@ -389,6 +422,7 @@ class BaseAction(ABC):
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
action_require=getattr(cls, "action_require", []).copy(),
|
||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||
chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType, ChatType
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
@@ -26,6 +26,8 @@ class BaseCommand(ABC):
|
||||
# 默认命令设置
|
||||
command_pattern: str = r""
|
||||
"""命令匹配的正则表达式"""
|
||||
chat_type_allow: ChatType = ChatType.ALL
|
||||
"""允许的聊天类型,默认为所有类型"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化Command组件
|
||||
@@ -40,8 +42,19 @@ class BaseCommand(ABC):
|
||||
|
||||
self.log_prefix = "[Command]"
|
||||
|
||||
# 从类属性获取chat_type_allow设置
|
||||
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
logger.debug(f"{self.log_prefix} Command组件初始化完成")
|
||||
|
||||
# 验证聊天类型限制
|
||||
if not self._validate_chat_type():
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
logger.warning(
|
||||
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
|
||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||
)
|
||||
|
||||
def set_matched_groups(self, groups: Dict[str, str]) -> None:
|
||||
"""设置正则表达式匹配的命名组
|
||||
|
||||
@@ -50,6 +63,35 @@ class BaseCommand(ABC):
|
||||
"""
|
||||
self.matched_groups = groups
|
||||
|
||||
def _validate_chat_type(self) -> bool:
|
||||
"""验证当前聊天类型是否允许执行此Command
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
if self.chat_type_allow == ChatType.ALL:
|
||||
return True
|
||||
|
||||
# 检查是否为群聊消息
|
||||
is_group = hasattr(self.message, 'is_group_message') and self.message.is_group_message
|
||||
|
||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||
return True
|
||||
elif self.chat_type_allow == ChatType.PRIVATE and not is_group:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_chat_type_allowed(self) -> bool:
|
||||
"""检查当前聊天类型是否允许执行此Command
|
||||
|
||||
这是一个公开的方法,供外部调用检查聊天类型限制
|
||||
|
||||
Returns:
|
||||
bool: 如果允许执行返回True,否则返回False
|
||||
"""
|
||||
return self._validate_chat_type()
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行Command的抽象方法,子类必须实现
|
||||
@@ -225,4 +267,5 @@ class BaseCommand(ABC):
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=cls.command_description,
|
||||
command_pattern=cls.command_pattern,
|
||||
chat_type_allow=getattr(cls, "chat_type_allow", ChatType.ALL),
|
||||
)
|
||||
|
||||
@@ -47,6 +47,18 @@ class ChatMode(Enum):
|
||||
return self.value
|
||||
|
||||
|
||||
# 聊天类型枚举
|
||||
class ChatType(Enum):
|
||||
"""聊天类型枚举,用于限制插件在不同聊天环境中的使用"""
|
||||
|
||||
GROUP = "group" # 仅群聊可用
|
||||
PRIVATE = "private" # 仅私聊可用
|
||||
ALL = "all" # 群聊和私聊都可用
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
@@ -124,6 +136,7 @@ class ActionInfo(ComponentInfo):
|
||||
# 模式和并行设置
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
parallel_action: bool = False
|
||||
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -143,6 +156,7 @@ class CommandInfo(ComponentInfo):
|
||||
"""命令组件信息"""
|
||||
|
||||
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
||||
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.plugin_system import (
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import database_api
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class AtAction(BaseAction):
|
||||
@@ -21,6 +22,7 @@ class AtAction(BaseAction):
|
||||
action_description = "发送艾特消息"
|
||||
activation_type = ActionActivationType.LLM_JUDGE # 消息接收时激活(?)
|
||||
parallel_action = False
|
||||
chat_type_allow = ChatType.GROUP
|
||||
|
||||
# === 功能描述(必须填写)===
|
||||
action_parameters = {
|
||||
|
||||
@@ -15,6 +15,7 @@ from src.common.database.sqlalchemy_models import Messages, PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
from .qq_emoji_list import qq_face
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger("set_emoji_like_plugin")
|
||||
|
||||
@@ -50,6 +51,7 @@ class SetEmojiLikeAction(BaseAction):
|
||||
action_name = "set_emoji_like"
|
||||
action_description = "为消息设置表情回应/贴表情"
|
||||
activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?)
|
||||
chat_type_allow = ChatType.GROUP
|
||||
parallel_action = True
|
||||
|
||||
# === 功能描述(必须填写)===
|
||||
|
||||
Reference in New Issue
Block a user