增加了组件的局部禁用方法

This commit is contained in:
UnCLAS-Prommer
2025-07-23 00:41:31 +08:00
parent 87dd9a3756
commit 10bf424540
14 changed files with 195 additions and 162 deletions

View File

@@ -28,6 +28,7 @@ from .core import (
component_registry,
dependency_manager,
events_manager,
global_announcement_manager,
)
# 导入工具模块
@@ -67,6 +68,7 @@ __all__ = [
"component_registry",
"dependency_manager",
"events_manager",
"global_announcement_manager",
# 装饰器
"register_plugin",
"ConfigField",

View File

@@ -65,21 +65,28 @@ class BaseAction(ABC):
self.thinking_id = thinking_id
self.log_prefix = log_prefix
# 保存插件配置
self.plugin_config = plugin_config or {}
"""对应的插件配置"""
# 设置动作基本信息实例属性
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
"""Action的名字"""
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
"""Action的描述"""
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
# 设置激活类型实例属性(从类属性复制,提供默认值)
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
"""FOCUS模式下的激活类型"""
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
"""NORMAL模式下的激活类型"""
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
"""当激活类型为RANDOM时的概率"""
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
"""协助LLM进行判断的Prompt"""
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
"""激活类型为KEYWORD时的KEYWORDS列表"""
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)

View File

@@ -21,13 +21,18 @@ class BaseCommand(ABC):
"""
command_name: str = ""
"""Command组件的名称"""
command_description: str = ""
"""Command组件的描述"""
# 默认命令设置(子类可以覆盖)
command_pattern: str = ""
"""命令匹配的正则表达式"""
command_help: str = ""
"""命令帮助信息"""
command_examples: List[str] = []
intercept_message: bool = True # 默认拦截消息,不继续处理
intercept_message: bool = True
"""是否拦截信息,默认拦截,不进行后续处理"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
"""初始化Command组件

View File

@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
"""
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
handler_name: str = "" # 处理器名称
event_type: EventType = EventType.UNKNOWN
"""事件类型,默认为未知"""
handler_name: str = ""
"""处理器名称"""
handler_description: str = ""
weight: int = 0 # 权重,数值越大优先级越高
intercept_message: bool = False # 是否拦截消息,默认为否
"""处理器描述"""
weight: int = 0
"""处理器权重,越大权重越高"""
intercept_message: bool = False
"""是否拦截消息,默认为否"""
def __init__(self):
self.log_prefix = "[EventHandler]"
self.plugin_name = "" # 对应插件名
self.plugin_config: Optional[Dict] = None # 插件配置字典
self.plugin_name = ""
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
if self.event_type == EventType.UNKNOWN:
raise NotImplementedError("事件处理器必须指定 event_type")

View File

@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
__all__ = [
"plugin_manager",
"component_registry",
"dependency_manager",
"events_manager",
"global_announcement_manager",
]

View File

@@ -418,7 +418,7 @@ class ComponentRegistry:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
@@ -439,8 +439,7 @@ class ComponentRegistry:
return (
self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore
command_info.intercept_message,
command_info.plugin_name,
command_info,
)
# === 事件处理器特定查询方法 ===

View File

@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
from src.plugin_system.base.base_events_handler import BaseEventHandler
from .global_announcement_manager import global_announcement_manager
logger = get_logger("events_manager")
@@ -53,6 +54,10 @@ class EventsManager:
continue_flag = True
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self._events_subscribers.get(event_type, []):
if message.chat_stream and message.chat_stream.stream_id:
stream_id = message.chat_stream.stream_id
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
continue
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
if handler.intercept_message:
try:

View File

@@ -0,0 +1,90 @@
from typing import List, Dict
from src.common.logger import get_logger
logger = get_logger("global_announcement_manager")
class GlobalAnnouncementManager:
def __init__(self) -> None:
# 用户禁用的动作chat_id -> [action_name]
self._user_disabled_actions: Dict[str, List[str]] = {}
# 用户禁用的命令chat_id -> [command_name]
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
if chat_id not in self._user_disabled_actions:
self._user_disabled_actions[chat_id] = []
if action_name in self._user_disabled_actions[chat_id]:
logger.warning(f"动作 {action_name} 已经被禁用")
return False
self._user_disabled_actions[chat_id].append(action_name)
return True
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""启用特定聊天的某个动作"""
if chat_id in self._user_disabled_actions:
try:
self._user_disabled_actions[chat_id].remove(action_name)
return True
except ValueError:
return False
return False
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""禁用特定聊天的某个命令"""
if chat_id not in self._user_disabled_commands:
self._user_disabled_commands[chat_id] = []
if command_name in self._user_disabled_commands[chat_id]:
logger.warning(f"命令 {command_name} 已经被禁用")
return False
self._user_disabled_commands[chat_id].append(command_name)
return True
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""启用特定聊天的某个命令"""
if chat_id in self._user_disabled_commands:
try:
self._user_disabled_commands[chat_id].remove(command_name)
return True
except ValueError:
return False
return False
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""禁用特定聊天的某个事件处理器"""
if chat_id not in self._user_disabled_event_handlers:
self._user_disabled_event_handlers[chat_id] = []
if handler_name in self._user_disabled_event_handlers[chat_id]:
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
return False
self._user_disabled_event_handlers[chat_id].append(handler_name)
return True
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""启用特定聊天的某个事件处理器"""
if chat_id in self._user_disabled_event_handlers:
try:
self._user_disabled_event_handlers[chat_id].remove(handler_name)
return True
except ValueError:
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有命令"""
return self._user_disabled_commands.get(chat_id, []).copy()
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()