增加了组件的局部禁用方法
This commit is contained in:
@@ -52,6 +52,8 @@ NO_ACTION = {
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
IS_MAI4U = False
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
|
||||
@@ -263,7 +265,7 @@ class HeartFChatting:
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None, # type: ignore
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
@@ -276,7 +278,7 @@ class HeartFChatting:
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None, # type: ignore
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
@@ -294,7 +296,8 @@ class HeartFChatting:
|
||||
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
||||
|
||||
await self.send_typing()
|
||||
if IS_MAI4U:
|
||||
await self.send_typing()
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
loop_start_time = time.time()
|
||||
@@ -364,7 +367,8 @@ class HeartFChatting:
|
||||
# 发送回复 (不再需要传入 chat)
|
||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time, message_data)
|
||||
|
||||
await self.stop_typing()
|
||||
if IS_MAI4U:
|
||||
await self.stop_typing()
|
||||
|
||||
if ENABLE_THINKING:
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
|
||||
@@ -13,10 +13,9 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager # 导入新插件系统
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -92,8 +91,20 @@ class ChatBot:
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
intercept_message = command_info.intercept_message
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if (
|
||||
message.chat_stream
|
||||
and message.chat_stream.stream_id
|
||||
and command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
command_class, matched_groups, intercept_message, plugin_name = command_result
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
@@ -135,13 +146,12 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
|
||||
async def hanle_notice_message(self, message: MessageRecv):
|
||||
if message.message_info.message_id == "notice":
|
||||
logger.info("收到notice消息,暂时不支持处理")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
@@ -163,7 +173,6 @@ class ChatBot:
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
@@ -179,8 +188,6 @@ class ChatBot:
|
||||
- 性能计时
|
||||
"""
|
||||
try:
|
||||
|
||||
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
@@ -201,11 +208,10 @@ class ChatBot:
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
|
||||
|
||||
if await self.hanle_notice_message(message):
|
||||
return
|
||||
|
||||
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
if message.message_info.additional_config:
|
||||
@@ -229,11 +235,10 @@ class ChatBot:
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
|
||||
# if await self.check_ban_content(message):
|
||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||
# return
|
||||
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import traceback
|
||||
|
||||
from typing import Dict, Optional, Type
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -24,38 +22,13 @@ class ActionManager:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
# 所有注册的动作集合
|
||||
self._registered_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化管理器注册表
|
||||
self._load_plugin_system_actions()
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
|
||||
def _load_plugin_system_actions(self) -> None:
|
||||
"""从插件系统的component_registry加载Action组件"""
|
||||
try:
|
||||
# 获取所有Action组件
|
||||
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
||||
|
||||
for action_name, action_info in action_components.items():
|
||||
if action_name in self._registered_actions:
|
||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
||||
continue
|
||||
|
||||
self._registered_actions[action_name] = action_info
|
||||
|
||||
logger.debug(f"从插件系统加载Action组件: {action_name} (插件: {action_info.plugin_name})")
|
||||
|
||||
logger.info(f"加载了 {len(action_components)} 个Action动作")
|
||||
logger.debug("从插件系统加载Action组件成功")
|
||||
except Exception as e:
|
||||
logger.error(f"从插件系统加载Action组件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
def create_action(
|
||||
@@ -127,44 +100,10 @@ class ActionManager:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_registered_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取所有已注册的动作集"""
|
||||
return self._registered_actions.copy()
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
# === 增删Action方法 ===
|
||||
def add_action(self, action_name: str) -> bool:
|
||||
"""增加一个Action到管理器
|
||||
|
||||
Parameters:
|
||||
action_name: 动作名称
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
if action_name in self._registered_actions:
|
||||
return True
|
||||
component_info: ActionInfo = component_registry.get_component_info(action_name, ComponentType.ACTION) # type: ignore
|
||||
if not component_info:
|
||||
logger.warning(f"添加失败: 动作 {action_name} 未注册")
|
||||
return False
|
||||
self._registered_actions[action_name] = component_info
|
||||
return True
|
||||
|
||||
def remove_action(self, action_name: str) -> bool:
|
||||
"""从注册集移除指定动作
|
||||
Parameters:
|
||||
action_name: 动作名称
|
||||
Returns:
|
||||
bool: 移除是否成功
|
||||
"""
|
||||
if action_name not in self._registered_actions:
|
||||
return False
|
||||
del self._registered_actions[action_name]
|
||||
return True
|
||||
|
||||
# === Modify相关方法 ===
|
||||
def remove_action_from_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
@@ -189,47 +128,3 @@ class ActionManager:
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
# def add_action_to_using(self, action_name: str) -> bool:
|
||||
|
||||
# """
|
||||
# 添加已注册的动作到当前使用的动作集
|
||||
|
||||
# Args:
|
||||
# action_name: 动作名称
|
||||
|
||||
# Returns:
|
||||
# bool: 添加是否成功
|
||||
# """
|
||||
# if action_name not in self._registered_actions:
|
||||
# logger.warning(f"添加失败: 动作 {action_name} 未注册")
|
||||
# return False
|
||||
|
||||
# if action_name in self._using_actions:
|
||||
# logger.info(f"动作 {action_name} 已经在使用中")
|
||||
# return True
|
||||
|
||||
# self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
# logger.info(f"添加动作 {action_name} 到使用集")
|
||||
# return True
|
||||
|
||||
# def temporarily_remove_actions(self, actions_to_remove: List[str]) -> None:
|
||||
# """临时移除使用集中的指定动作"""
|
||||
# for name in actions_to_remove:
|
||||
# self._using_actions.pop(name, None)
|
||||
|
||||
# def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
# """
|
||||
# 根据需要添加系统动作到使用集
|
||||
|
||||
# Args:
|
||||
# action_name: 动作名称
|
||||
|
||||
# Returns:
|
||||
# bool: 是否成功添加
|
||||
# """
|
||||
# if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
# self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
# logger.info(f"临时添加系统动作到使用集: {action_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
@@ -2,7 +2,7 @@ import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from typing import List, Any, Dict, TYPE_CHECKING
|
||||
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageCo
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -60,8 +61,9 @@ class ActionModifier:
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
removals_s1 = []
|
||||
removals_s2 = []
|
||||
removals_s1: List[Tuple[str, str]] = []
|
||||
removals_s2: List[Tuple[str, str]] = []
|
||||
removals_s3: List[Tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
@@ -83,25 +85,28 @@ class ActionModifier:
|
||||
if message_content:
|
||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||
|
||||
# === 第一阶段:传统观察处理 ===
|
||||
# if history_loop:
|
||||
# removals_from_loop = await self.analyze_loop_actions(history_loop)
|
||||
# if removals_from_loop:
|
||||
# removals_s1.extend(removals_from_loop)
|
||||
# === 第一阶段:去除用户自行禁用的 ===
|
||||
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||
if disabled_actions:
|
||||
for disabled_action_name in disabled_actions:
|
||||
if disabled_action_name in all_actions:
|
||||
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# 检查动作的关联类型
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
chat_context = self.chat_stream.context
|
||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||
|
||||
if type_mismatched_actions:
|
||||
removals_s1.extend(type_mismatched_actions)
|
||||
removals_s2.extend(type_mismatched_actions)
|
||||
|
||||
# 应用第一阶段的移除
|
||||
for action_name, reason in removals_s1:
|
||||
# 应用第二阶段的移除
|
||||
for action_name, reason in removals_s2:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {action_name},原因: {reason}")
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 第二阶段:激活类型判定 ===
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
if chat_content is not None:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
@@ -109,18 +114,18 @@ class ActionModifier:
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
removals_s2 = await self._get_deactivated_actions_by_type(
|
||||
removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
current_using_actions,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
# 应用第二阶段的移除
|
||||
for action_name, reason in removals_s2:
|
||||
# 应用第三阶段的移除
|
||||
for action_name, reason in removals_s3:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
all_removals = removals_s1 + removals_s2 + removals_s3
|
||||
removals_summary: str = ""
|
||||
if all_removals:
|
||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||
@@ -130,7 +135,7 @@ class ActionModifier:
|
||||
)
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions = []
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
@@ -19,8 +19,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
@@ -99,7 +99,7 @@ class ActionPlanner:
|
||||
|
||||
async def plan(
|
||||
self, mode: ChatMode = ChatMode.FOCUS
|
||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension
|
||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
@@ -119,9 +119,11 @@ class ActionPlanner:
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
for action_name in current_available_actions_dict.keys():
|
||||
all_registered_actions: List[ActionInfo] = list(
|
||||
component_registry.get_components_by_type(ComponentType.ACTION).values() # type: ignore
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
|
||||
@@ -28,6 +28,7 @@ from .core import (
|
||||
component_registry,
|
||||
dependency_manager,
|
||||
events_manager,
|
||||
global_announcement_manager,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
@@ -67,6 +68,7 @@ __all__ = [
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
|
||||
@@ -65,21 +65,28 @@ class BaseAction(ABC):
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
# 保存插件配置
|
||||
self.plugin_config = plugin_config or {}
|
||||
"""对应的插件配置"""
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
"""Action的名字"""
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
"""Action的描述"""
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
"""FOCUS模式下的激活类型"""
|
||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
"""NORMAL模式下的激活类型"""
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
"""当激活类型为RANDOM时的概率"""
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
"""协助LLM进行判断的Prompt"""
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
|
||||
@@ -21,13 +21,18 @@ class BaseCommand(ABC):
|
||||
"""
|
||||
|
||||
command_name: str = ""
|
||||
"""Command组件的名称"""
|
||||
command_description: str = ""
|
||||
"""Command组件的描述"""
|
||||
|
||||
# 默认命令设置(子类可以覆盖)
|
||||
command_pattern: str = ""
|
||||
"""命令匹配的正则表达式"""
|
||||
command_help: str = ""
|
||||
"""命令帮助信息"""
|
||||
command_examples: List[str] = []
|
||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
||||
intercept_message: bool = True
|
||||
"""是否拦截信息,默认拦截,不进行后续处理"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化Command组件
|
||||
|
||||
@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
|
||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||
"""
|
||||
|
||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
||||
handler_name: str = "" # 处理器名称
|
||||
event_type: EventType = EventType.UNKNOWN
|
||||
"""事件类型,默认为未知"""
|
||||
handler_name: str = ""
|
||||
"""处理器名称"""
|
||||
handler_description: str = ""
|
||||
weight: int = 0 # 权重,数值越大优先级越高
|
||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
||||
"""处理器描述"""
|
||||
weight: int = 0
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = "" # 对应插件名
|
||||
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
||||
self.plugin_name = ""
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
|
||||
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
]
|
||||
|
||||
@@ -418,7 +418,7 @@ class ComponentRegistry:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -439,8 +439,7 @@ class ComponentRegistry:
|
||||
return (
|
||||
self._command_registry[command_name],
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info.intercept_message,
|
||||
command_info.plugin_name,
|
||||
command_info,
|
||||
)
|
||||
|
||||
# === 事件处理器特定查询方法 ===
|
||||
|
||||
@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
@@ -53,6 +54,10 @@ class EventsManager:
|
||||
continue_flag = True
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self._events_subscribers.get(event_type, []):
|
||||
if message.chat_stream and message.chat_stream.stream_id:
|
||||
stream_id = message.chat_stream.stream_id
|
||||
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||
continue
|
||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||
if handler.intercept_message:
|
||||
try:
|
||||
|
||||
90
src/plugin_system/core/global_announcement_manager.py
Normal file
90
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
|
||||
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
if chat_id not in self._user_disabled_actions:
|
||||
self._user_disabled_actions[chat_id] = []
|
||||
if action_name in self._user_disabled_actions[chat_id]:
|
||||
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_actions[chat_id].append(action_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""启用特定聊天的某个动作"""
|
||||
if chat_id in self._user_disabled_actions:
|
||||
try:
|
||||
self._user_disabled_actions[chat_id].remove(action_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""禁用特定聊天的某个命令"""
|
||||
if chat_id not in self._user_disabled_commands:
|
||||
self._user_disabled_commands[chat_id] = []
|
||||
if command_name in self._user_disabled_commands[chat_id]:
|
||||
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_commands[chat_id].append(command_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""启用特定聊天的某个命令"""
|
||||
if chat_id in self._user_disabled_commands:
|
||||
try:
|
||||
self._user_disabled_commands[chat_id].remove(command_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""禁用特定聊天的某个事件处理器"""
|
||||
if chat_id not in self._user_disabled_event_handlers:
|
||||
self._user_disabled_event_handlers[chat_id] = []
|
||||
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""启用特定聊天的某个事件处理器"""
|
||||
if chat_id in self._user_disabled_event_handlers:
|
||||
try:
|
||||
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
Reference in New Issue
Block a user