refactor(chat): 简化Action和PlusCommand的调用预处理
移除 `ChatBot` 和 `ActionModifier` 中用于过滤禁用组件的模板代码。 这两个模块现在直接从 `ComponentRegistry` 获取为当前聊天会话(`stream_id`)定制的可用组件列表。所有关于组件是否启用的判断逻辑都已下沉到 `plugin_system` 核心中,使得上层调用代码更清晰,且不再需要依赖 `global_announcement_manager` 来进行手动过滤。
This commit is contained in:
@@ -14,7 +14,7 @@ from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.base import BaseCommand, EventType, ComponentType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
@@ -118,20 +118,18 @@ class ChatBot:
|
||||
args_text = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
# 查找匹配的PlusCommand
|
||||
plus_command_registry = component_registry.get_plus_command_registry()
|
||||
available_commands_info = component_registry.get_available_plus_commands_info(chat.stream_id)
|
||||
matching_commands = []
|
||||
|
||||
for plus_command_name, plus_command_class in plus_command_registry.items():
|
||||
plus_command_info = component_registry.get_registered_plus_command_info(plus_command_name)
|
||||
if not plus_command_info:
|
||||
continue
|
||||
|
||||
for plus_command_name, plus_command_info in available_commands_info.items():
|
||||
# 检查命令名是否匹配(命令名和别名)
|
||||
all_commands = [plus_command_name.lower()] + [
|
||||
all_aliases = [plus_command_name.lower()] + [
|
||||
alias.lower() for alias in plus_command_info.command_aliases
|
||||
]
|
||||
if command_word in all_commands:
|
||||
matching_commands.append((plus_command_class, plus_command_info, plus_command_name))
|
||||
if command_word in all_aliases:
|
||||
plus_command_class = component_registry.get_component_class(plus_command_name, ComponentType.PLUS_COMMAND)
|
||||
if plus_command_class:
|
||||
matching_commands.append((plus_command_class, plus_command_info, plus_command_name))
|
||||
|
||||
if not matching_commands:
|
||||
return False, None, True # 没有找到匹配的PlusCommand,继续处理
|
||||
@@ -145,16 +143,6 @@ class ChatBot:
|
||||
|
||||
plus_command_class, plus_command_info, plus_command_name = matching_commands[0]
|
||||
|
||||
# 检查命令是否被禁用
|
||||
if (
|
||||
chat
|
||||
and chat.stream_id
|
||||
and plus_command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的PlusCommand,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
|
||||
# 获取插件配置
|
||||
|
||||
@@ -27,11 +27,9 @@ class ChatterActionManager:
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
# 当前正在使用的动作集合,在规划开始时加载
|
||||
self._using_actions: dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
self.chat_id: str | None = None
|
||||
|
||||
self.log_prefix: str = "ChatterActionManager"
|
||||
# 批量存储支持
|
||||
@@ -39,6 +37,12 @@ class ChatterActionManager:
|
||||
self._pending_actions = []
|
||||
self._current_chat_id = None
|
||||
|
||||
async def load_actions(self, stream_id: str | None):
|
||||
"""根据 stream_id 加载当前可用的动作"""
|
||||
self.chat_id = stream_id
|
||||
self._using_actions = component_registry.get_default_actions(stream_id)
|
||||
logger.debug(f"已为 stream '{stream_id}' 加载 {len(self._using_actions)} 个可用动作: {list(self._using_actions.keys())}")
|
||||
|
||||
# === 执行Action方法 ===
|
||||
|
||||
@staticmethod
|
||||
@@ -133,11 +137,12 @@ class ChatterActionManager:
|
||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
async def restore_actions(self) -> None:
|
||||
"""恢复到当前 stream_id 的默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
# 使用 self.chat_id 来恢复当前上下文的动作
|
||||
await self.load_actions(self.chat_id)
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到 stream '{self.chat_id}' 的默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
async def execute_action(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
@@ -68,6 +68,16 @@ class ActionModifier:
|
||||
"""
|
||||
# 初始化log_prefix
|
||||
await self._initialize_log_prefix()
|
||||
# 根据 stream_id 加载当前可用的动作
|
||||
await self.action_manager.load_actions(self.chat_id)
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
# 计算并记录禁用的动作数量
|
||||
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
loaded_actions_count = len(self.action_manager.get_using_actions())
|
||||
disabled_actions_count = len(all_registered_actions) - loaded_actions_count
|
||||
if disabled_actions_count > 0:
|
||||
logger.info(f"{self.log_prefix} 用户禁用了 {disabled_actions_count} 个动作。")
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
@@ -75,7 +85,6 @@ class ActionModifier:
|
||||
removals_s2: list[tuple[str, str]] = []
|
||||
removals_s3: list[tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# === 第0阶段:根据聊天类型过滤动作 ===
|
||||
@@ -126,15 +135,6 @@ class ActionModifier:
|
||||
if message_content:
|
||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||
|
||||
# === 第一阶段:去除用户自行禁用的 ===
|
||||
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||
if disabled_actions:
|
||||
for disabled_action_name in disabled_actions:
|
||||
if disabled_action_name in all_actions:
|
||||
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
if not self.chat_stream:
|
||||
logger.error(f"{self.log_prefix} chat_stream 未初始化,无法执行第二阶段")
|
||||
|
||||
@@ -30,7 +30,7 @@ def get_tool_instance(tool_name: str, chat_stream: Any = None) -> BaseTool | Non
|
||||
return tool_class(plugin_config, chat_stream) if tool_class else None
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||
def get_llm_available_tool_definitions(stream_id : str | None) -> list[dict[str, Any]]:
|
||||
"""获取LLM可用的工具定义列表(包括 MCP 工具)
|
||||
|
||||
Returns:
|
||||
@@ -38,7 +38,7 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
llm_available_tools = component_registry.get_llm_available_tools(stream_id)
|
||||
tool_definitions = []
|
||||
|
||||
# 获取常规工具定义
|
||||
|
||||
@@ -857,6 +857,23 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(command_name, ComponentType.PLUS_COMMAND)
|
||||
return info if isinstance(info, PlusCommandInfo) else None
|
||||
|
||||
def get_available_plus_commands_info(self, stream_id: str | None = None) -> dict[str, PlusCommandInfo]:
|
||||
"""获取在指定上下文中所有可用的PlusCommand信息
|
||||
|
||||
Args:
|
||||
stream_id: 可选的流ID,用于检查局部组件状态
|
||||
|
||||
Returns:
|
||||
一个字典,键是命令名,值是 PlusCommandInfo 对象
|
||||
"""
|
||||
all_plus_commands = self.get_components_by_type(ComponentType.PLUS_COMMAND)
|
||||
available_commands = {
|
||||
name: info
|
||||
for name, info in all_plus_commands.items()
|
||||
if self.is_component_available(name, ComponentType.PLUS_COMMAND, stream_id)
|
||||
}
|
||||
return cast(dict[str, PlusCommandInfo], available_commands)
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> dict[str, type[BaseEventHandler]]:
|
||||
|
||||
@@ -217,12 +217,11 @@ class ToolExecutor:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
all_tools = get_llm_available_tool_definitions(self.chat_id)
|
||||
|
||||
# 获取基础工具定义(包括二步工具的第一步)
|
||||
tool_definitions = [
|
||||
definition for definition in all_tools if definition.get("function", {}).get("name") not in user_disabled_tools
|
||||
definition for definition in all_tools if definition.get("function", {}).get("name")
|
||||
]
|
||||
|
||||
# 检查是否有待处理的二步工具第二步调用
|
||||
|
||||
Reference in New Issue
Block a user