增加了组件的局部禁用方法
This commit is contained in:
@@ -28,6 +28,7 @@ from .core import (
|
||||
component_registry,
|
||||
dependency_manager,
|
||||
events_manager,
|
||||
global_announcement_manager,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
@@ -67,6 +68,7 @@ __all__ = [
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
|
||||
@@ -65,21 +65,28 @@ class BaseAction(ABC):
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
# 保存插件配置
|
||||
self.plugin_config = plugin_config or {}
|
||||
"""对应的插件配置"""
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
"""Action的名字"""
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
"""Action的描述"""
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
"""FOCUS模式下的激活类型"""
|
||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
"""NORMAL模式下的激活类型"""
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
"""当激活类型为RANDOM时的概率"""
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
"""协助LLM进行判断的Prompt"""
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
|
||||
@@ -21,13 +21,18 @@ class BaseCommand(ABC):
|
||||
"""
|
||||
|
||||
command_name: str = ""
|
||||
"""Command组件的名称"""
|
||||
command_description: str = ""
|
||||
"""Command组件的描述"""
|
||||
|
||||
# 默认命令设置(子类可以覆盖)
|
||||
command_pattern: str = ""
|
||||
"""命令匹配的正则表达式"""
|
||||
command_help: str = ""
|
||||
"""命令帮助信息"""
|
||||
command_examples: List[str] = []
|
||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
||||
intercept_message: bool = True
|
||||
"""是否拦截信息,默认拦截,不进行后续处理"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化Command组件
|
||||
|
||||
@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
|
||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||
"""
|
||||
|
||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
||||
handler_name: str = "" # 处理器名称
|
||||
event_type: EventType = EventType.UNKNOWN
|
||||
"""事件类型,默认为未知"""
|
||||
handler_name: str = ""
|
||||
"""处理器名称"""
|
||||
handler_description: str = ""
|
||||
weight: int = 0 # 权重,数值越大优先级越高
|
||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
||||
"""处理器描述"""
|
||||
weight: int = 0
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = "" # 对应插件名
|
||||
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
||||
self.plugin_name = ""
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
|
||||
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
]
|
||||
|
||||
@@ -418,7 +418,7 @@ class ComponentRegistry:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -439,8 +439,7 @@ class ComponentRegistry:
|
||||
return (
|
||||
self._command_registry[command_name],
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info.intercept_message,
|
||||
command_info.plugin_name,
|
||||
command_info,
|
||||
)
|
||||
|
||||
# === 事件处理器特定查询方法 ===
|
||||
|
||||
@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
@@ -53,6 +54,10 @@ class EventsManager:
|
||||
continue_flag = True
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self._events_subscribers.get(event_type, []):
|
||||
if message.chat_stream and message.chat_stream.stream_id:
|
||||
stream_id = message.chat_stream.stream_id
|
||||
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||
continue
|
||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||
if handler.intercept_message:
|
||||
try:
|
||||
|
||||
90
src/plugin_system/core/global_announcement_manager.py
Normal file
90
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
|
||||
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
if chat_id not in self._user_disabled_actions:
|
||||
self._user_disabled_actions[chat_id] = []
|
||||
if action_name in self._user_disabled_actions[chat_id]:
|
||||
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_actions[chat_id].append(action_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""启用特定聊天的某个动作"""
|
||||
if chat_id in self._user_disabled_actions:
|
||||
try:
|
||||
self._user_disabled_actions[chat_id].remove(action_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""禁用特定聊天的某个命令"""
|
||||
if chat_id not in self._user_disabled_commands:
|
||||
self._user_disabled_commands[chat_id] = []
|
||||
if command_name in self._user_disabled_commands[chat_id]:
|
||||
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_commands[chat_id].append(command_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""启用特定聊天的某个命令"""
|
||||
if chat_id in self._user_disabled_commands:
|
||||
try:
|
||||
self._user_disabled_commands[chat_id].remove(command_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""禁用特定聊天的某个事件处理器"""
|
||||
if chat_id not in self._user_disabled_event_handlers:
|
||||
self._user_disabled_event_handlers[chat_id] = []
|
||||
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""启用特定聊天的某个事件处理器"""
|
||||
if chat_id in self._user_disabled_event_handlers:
|
||||
try:
|
||||
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
Reference in New Issue
Block a user