增加了卸载和重载插件功能

This commit is contained in:
UnCLAS-Prommer
2025-07-22 18:52:11 +08:00
parent 22c7f667e9
commit 76025032a9
9 changed files with 258 additions and 271 deletions

View File

@@ -27,7 +27,7 @@ class ComponentRegistry:
def __init__(self):
# 组件注册表
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
# 类型 -> 命名空间式名称 -> 组件信息
# 类型 -> 组件原名称 -> 组件信息
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
# 命名空间式组件名 -> 组件类
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
@@ -160,7 +160,9 @@ class ComponentRegistry:
if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name
else:
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
)
return True
@@ -176,6 +178,10 @@ class ComponentRegistry:
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class):
@@ -185,6 +191,33 @@ class ComponentRegistry:
logger.error(f"注册事件处理器 {handler_name} 失败")
return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType):
target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return
match component_type:
case ComponentType.ACTION:
self._action_registry.pop(component_name, None)
self._default_actions.pop(component_name, None)
case ComponentType.COMMAND:
self._command_registry.pop(component_name, None)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key, None)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name, None)
self._enabled_event_handlers.pop(component_name, None)
await events_manager.unregister_event_subscriber(component_name)
self._components.pop(component_name, None)
self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(component_name, None)
logger.info(f"组件 {component_name} 已移除")
# === 组件查询方法 ===
def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None
@@ -287,7 +320,7 @@ class ComponentRegistry:
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表(用于兼容现有系统)"""
"""获取Action注册表"""
return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:

View File

@@ -28,18 +28,16 @@ class EventsManager:
bool: 是否注册成功
"""
handler_name = handler_info.name
plugin_name = getattr(handler_info, "plugin_name", "unknown")
namespace_name = f"{plugin_name}.{handler_name}"
if namespace_name in self._handler_mapping:
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
self._handler_mapping[namespace_name] = handler_class
self._handler_mapping[handler_name] = handler_class
return self._insert_event_handler(handler_class, handler_info)
async def handle_mai_events(
@@ -71,7 +69,7 @@ class EventsManager:
try:
handler_task = asyncio.create_task(handler.execute(transformed_message))
handler_task.add_done_callback(self._task_done_callback)
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
self._handler_tasks[handler.handler_name].append(handler_task)
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
@@ -91,7 +89,7 @@ class EventsManager:
return True
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN:
@@ -190,5 +188,20 @@ class EventsManager:
finally:
del self._handler_tasks[handler_name]
async def unregister_event_subscriber(self, handler_name: str) -> bool:
"""取消注册事件处理器"""
if handler_name not in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
return False
await self.cancel_handler_tasks(handler_name)
handler_class = self._handler_mapping.pop(handler_name)
if not self._remove_event_handler_instance(handler_class):
return False
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
return True
events_manager = EventsManager()

View File

@@ -1,5 +1,4 @@
import os
import inspect
import traceback
from typing import Dict, List, Optional, Tuple, Type, Any
@@ -8,11 +7,11 @@ from pathlib import Path
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .dependency_manager import dependency_manager
logger = get_logger("plugin_manager")
@@ -36,19 +35,7 @@ class PluginManager:
self._ensure_plugin_directories()
logger.info("插件管理器初始化完成")
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# === 插件目录管理 ===
def add_plugin_directory(self, directory: str) -> bool:
"""添加插件目录"""
@@ -63,6 +50,8 @@ class PluginManager:
logger.warning(f"插件目录不存在: {directory}")
return False
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件
@@ -86,7 +75,7 @@ class PluginManager:
total_failed_registration = 0
for plugin_name in self.plugin_classes.keys():
load_status, count = self.load_registered_plugin_classes(plugin_name)
load_status, count = self._load_registered_plugin_classes(plugin_name)
if load_status:
total_registered += 1
else:
@@ -96,90 +85,32 @@ class PluginManager:
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
async def remove_registered_plugin(self, plugin_name: str) -> None:
"""
加载已经注册的插件类
禁用插件模块
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
if not plugin_name:
raise ValueError("插件名称不能为空")
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载")
return
plugin_instance = self.loaded_plugins[plugin_name]
plugin_info = plugin_instance.plugin_info
for component in plugin_info.components:
await component_registry.remove_component(component.name, component.component_type)
del self.loaded_plugins[plugin_name]
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
def unload_registered_plugin_module(self, plugin_name: str) -> None:
"""
卸载插件模块
"""
pass
def reload_registered_plugin_module(self, plugin_name: str) -> None:
async def reload_registered_plugin_module(self, plugin_name: str) -> None:
"""
重载插件模块
"""
self.unload_registered_plugin_module(plugin_name)
self.load_registered_plugin_classes(plugin_name)
await self.remove_registered_plugin(plugin_name)
self._load_registered_plugin_classes(plugin_name)
def rescan_plugin_directory(self) -> None:
"""
重新扫描插件根目录
"""
# --------------------------------------- NEED REFACTORING ---------------------------------------
for directory in self.plugin_directories:
if os.path.exists(directory):
logger.debug(f"重新扫描插件根目录: {directory}")
@@ -195,30 +126,6 @@ class PluginManager:
"""获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values())
# def enable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """启用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = True
# # 启用插件的所有组件
# for component in plugin_info.components:
# component_registry.enable_component(component.name)
# logger.debug(f"已启用插件: {plugin_name}")
# return True
# return False
# def disable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """禁用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = False
# # 禁用插件的所有组件
# for component in plugin_info.components:
# component_registry.disable_component(component.name)
# logger.debug(f"已禁用插件: {plugin_name}")
# return True
# return False
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例
@@ -230,25 +137,6 @@ class PluginManager:
"""
return self.loaded_plugins.get(plugin_name)
def get_plugin_stats(self) -> Dict[str, Any]:
"""获取插件统计信息"""
all_plugins = component_registry.get_all_plugins()
enabled_plugins = component_registry.get_enabled_plugins()
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
return {
"total_plugins": len(all_plugins),
"enabled_plugins": len(enabled_plugins),
"failed_plugins": len(self.failed_plugins),
"total_components": len(action_components) + len(command_components),
"action_components": len(action_components),
"command_components": len(command_components),
"loaded_plugin_files": len(self.loaded_plugins),
"failed_plugin_details": self.failed_plugins.copy(),
}
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
"""检查所有插件的Python依赖包
@@ -347,6 +235,24 @@ class PluginManager:
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
# === 私有方法 ===
# == 目录管理 ==
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# == 插件加载 ==
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块"""
loaded_count = 0
@@ -372,18 +278,6 @@ class PluginManager:
return loaded_count, failed_count
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
"""查找插件类对应的目录路径"""
try:
# module = getmodule(plugin_class)
# if module and hasattr(module, "__file__") and module.__file__:
# return os.path.dirname(module.__file__)
file_path = inspect.getfile(plugin_class)
return os.path.dirname(file_path)
except Exception as e:
logger.debug(f"通过inspect获取插件目录失败: {e}")
return None
def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method
"""加载单个插件模块文件
@@ -416,6 +310,74 @@ class PluginManager:
self.failed_plugins[module_name] = error_msg
return False
def _load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
# == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
"""检查插件版本兼容性
@@ -451,6 +413,8 @@ class PluginManager:
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
# == 显示统计与插件信息 ==
def _show_stats(self, total_registered: int, total_failed_registration: int):
# sourcery skip: low-code-quality
# 获取组件统计信息
@@ -493,9 +457,15 @@ class PluginManager:
# 组件列表
if plugin_info.components:
action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION]
command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND]
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER]
action_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
]
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
if action_components:
action_names = [c.name for c in action_components]
@@ -504,7 +474,7 @@ class PluginManager:
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")