ruff
This commit is contained in:
@@ -8,6 +8,6 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
__all__ = [
|
||||
'plugin_manager',
|
||||
'component_registry',
|
||||
]
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
]
|
||||
|
||||
@@ -1,149 +1,152 @@
|
||||
from typing import Dict, List, Type, Optional, Any, Pattern
|
||||
from abc import ABC
|
||||
import re
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo, ActionInfo, CommandInfo, PluginInfo,
|
||||
ComponentType, ActionActivationType, ChatMode
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""统一的组件注册中心
|
||||
|
||||
|
||||
负责管理所有插件组件的注册、查询和生命周期管理
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# 组件注册表
|
||||
self._components: Dict[str, ComponentInfo] = {} # 组件名 -> 组件信息
|
||||
self._components: Dict[str, ComponentInfo] = {} # 组件名 -> 组件信息
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {
|
||||
ComponentType.ACTION: {},
|
||||
ComponentType.COMMAND: {},
|
||||
}
|
||||
self._component_classes: Dict[str, Type] = {} # 组件名 -> 组件类
|
||||
|
||||
self._component_classes: Dict[str, Type] = {} # 组件名 -> 组件类
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
||||
|
||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type] = {} # action名 -> action类
|
||||
self._default_actions: Dict[str, str] = {} # 启用的action名 -> 描述
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type] = {} # command名 -> command类
|
||||
self._command_patterns: Dict[Pattern, Type] = {} # 编译后的正则 -> command类
|
||||
|
||||
self._action_registry: Dict[str, Type] = {} # action名 -> action类
|
||||
self._default_actions: Dict[str, str] = {} # 启用的action名 -> 描述
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type] = {} # command名 -> command类
|
||||
self._command_patterns: Dict[Pattern, Type] = {} # 编译后的正则 -> command类
|
||||
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
|
||||
# === 通用组件注册方法 ===
|
||||
|
||||
|
||||
def register_component(self, component_info: ComponentInfo, component_class: Type) -> bool:
|
||||
"""注册组件
|
||||
|
||||
|
||||
Args:
|
||||
component_info: 组件信息
|
||||
component_class: 组件类
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
component_name = component_info.name
|
||||
component_type = component_info.component_type
|
||||
|
||||
|
||||
if component_name in self._components:
|
||||
logger.warning(f"组件 {component_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
|
||||
# 注册到通用注册表
|
||||
self._components[component_name] = component_info
|
||||
self._components_by_type[component_type][component_name] = component_info
|
||||
self._component_classes[component_name] = component_class
|
||||
|
||||
|
||||
# 根据组件类型进行特定注册
|
||||
if component_type == ComponentType.ACTION:
|
||||
self._register_action_component(component_info, component_class)
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
self._register_command_component(component_info, component_class)
|
||||
|
||||
|
||||
logger.info(f"已注册{component_type.value}组件: {component_name} ({component_class.__name__})")
|
||||
return True
|
||||
|
||||
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type):
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
action_name = action_info.name
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = action_info.description
|
||||
|
||||
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type):
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
command_name = command_info.name
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
|
||||
# 编译正则表达式并注册
|
||||
if command_info.command_pattern:
|
||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
self._command_patterns[pattern] = command_class
|
||||
|
||||
|
||||
# === 组件查询方法 ===
|
||||
|
||||
|
||||
def get_component_info(self, component_name: str) -> Optional[ComponentInfo]:
|
||||
"""获取组件信息"""
|
||||
return self._components.get(component_name)
|
||||
|
||||
|
||||
def get_component_class(self, component_name: str) -> Optional[Type]:
|
||||
"""获取组件类"""
|
||||
return self._component_classes.get(component_name)
|
||||
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type]:
|
||||
"""获取Action注册表(用于兼容现有系统)"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
|
||||
def get_default_actions(self) -> Dict[str, str]:
|
||||
"""获取默认启用的Action列表(用于兼容现有系统)"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
|
||||
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type]:
|
||||
"""获取Command注册表(用于兼容现有系统)"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, Type]:
|
||||
"""获取Command模式注册表(用于兼容现有系统)"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
|
||||
def get_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[tuple[Type, dict]]:
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[tuple[Type, dict]]: (命令类, 匹配的命名组) 或 None
|
||||
"""
|
||||
@@ -156,54 +159,54 @@ class ComponentRegistry:
|
||||
if cls == command_class:
|
||||
command_name = name
|
||||
break
|
||||
|
||||
|
||||
# 检查命令是否启用
|
||||
if command_name:
|
||||
command_info = self.get_command_info(command_name)
|
||||
if command_info and command_info.enabled:
|
||||
return command_class, match.groupdict()
|
||||
return None
|
||||
|
||||
|
||||
# === 插件管理方法 ===
|
||||
|
||||
|
||||
def register_plugin(self, plugin_info: PluginInfo) -> bool:
|
||||
"""注册插件
|
||||
|
||||
|
||||
Args:
|
||||
plugin_info: 插件信息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
plugin_name = plugin_info.name
|
||||
|
||||
|
||||
if plugin_name in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
|
||||
self._plugins[plugin_name] = plugin_info
|
||||
logger.info(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
|
||||
return True
|
||||
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, PluginInfo]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
|
||||
def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
||||
"""获取所有启用的插件"""
|
||||
return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
|
||||
|
||||
# === 状态管理方法 ===
|
||||
|
||||
|
||||
def enable_component(self, component_name: str) -> bool:
|
||||
"""启用组件"""
|
||||
if component_name in self._components:
|
||||
@@ -215,7 +218,7 @@ class ComponentRegistry:
|
||||
logger.info(f"已启用组件: {component_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def disable_component(self, component_name: str) -> bool:
|
||||
"""禁用组件"""
|
||||
if component_name in self._components:
|
||||
@@ -226,15 +229,14 @@ class ComponentRegistry:
|
||||
logger.info(f"已禁用组件: {component_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
return {
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
component_type.value: len(components)
|
||||
for component_type, components in self._components_by_type.items()
|
||||
component_type.value: len(components) for component_type, components in self._components_by_type.items()
|
||||
},
|
||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
@@ -242,4 +244,4 @@ class ComponentRegistry:
|
||||
|
||||
|
||||
# 全局组件注册中心实例
|
||||
component_registry = ComponentRegistry()
|
||||
component_registry = ComponentRegistry()
|
||||
|
||||
@@ -9,19 +9,20 @@ from src.plugin_system.base.component_types import PluginInfo, ComponentType
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器
|
||||
|
||||
|
||||
负责加载、初始化和管理所有插件及其组件
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_directories: List[str] = []
|
||||
self.loaded_plugins: Dict[str, Any] = {}
|
||||
self.failed_plugins: Dict[str, str] = {}
|
||||
|
||||
|
||||
logger.info("插件管理器初始化完成")
|
||||
|
||||
|
||||
def add_plugin_directory(self, directory: str):
|
||||
"""添加插件目录"""
|
||||
if os.path.exists(directory):
|
||||
@@ -29,141 +30,142 @@ class PluginManager:
|
||||
logger.info(f"已添加插件目录: {directory}")
|
||||
else:
|
||||
logger.warning(f"插件目录不存在: {directory}")
|
||||
|
||||
|
||||
def load_all_plugins(self) -> tuple[int, int]:
|
||||
"""加载所有插件目录中的插件
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (插件数量, 组件数量)
|
||||
"""
|
||||
logger.info("开始加载所有插件...")
|
||||
|
||||
|
||||
# 第一阶段:加载所有插件模块(注册插件类)
|
||||
total_loaded_modules = 0
|
||||
total_failed_modules = 0
|
||||
|
||||
|
||||
for directory in self.plugin_directories:
|
||||
loaded, failed = self._load_plugin_modules_from_directory(directory)
|
||||
total_loaded_modules += loaded
|
||||
total_failed_modules += failed
|
||||
|
||||
|
||||
logger.info(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
||||
|
||||
|
||||
# 第二阶段:实例化所有已注册的插件类
|
||||
from src.plugin_system.base.base_plugin import get_registered_plugin_classes, instantiate_and_register_plugin
|
||||
|
||||
|
||||
plugin_classes = get_registered_plugin_classes()
|
||||
total_registered = 0
|
||||
total_failed_registration = 0
|
||||
|
||||
|
||||
for plugin_name, plugin_class in plugin_classes.items():
|
||||
# 尝试找到插件对应的目录
|
||||
plugin_dir = self._find_plugin_directory(plugin_class)
|
||||
|
||||
|
||||
if instantiate_and_register_plugin(plugin_class, plugin_dir):
|
||||
total_registered += 1
|
||||
self.loaded_plugins[plugin_name] = plugin_class
|
||||
else:
|
||||
total_failed_registration += 1
|
||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
||||
|
||||
|
||||
logger.info(f"插件注册完成 - 成功: {total_registered}, 失败: {total_failed_registration}")
|
||||
|
||||
|
||||
# 获取组件统计信息
|
||||
stats = component_registry.get_registry_stats()
|
||||
logger.info(f"组件注册统计: {stats}")
|
||||
|
||||
|
||||
# 返回插件数量和组件数量
|
||||
return total_registered, stats.get('total_components', 0)
|
||||
|
||||
return total_registered, stats.get("total_components", 0)
|
||||
|
||||
def _find_plugin_directory(self, plugin_class) -> Optional[str]:
|
||||
"""查找插件类对应的目录路径"""
|
||||
try:
|
||||
import inspect
|
||||
|
||||
module = inspect.getmodule(plugin_class)
|
||||
if module and hasattr(module, '__file__') and module.__file__:
|
||||
if module and hasattr(module, "__file__") and module.__file__:
|
||||
return os.path.dirname(module.__file__)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||
"""从指定目录加载插件模块"""
|
||||
loaded_count = 0
|
||||
failed_count = 0
|
||||
|
||||
|
||||
if not os.path.exists(directory):
|
||||
logger.warning(f"插件目录不存在: {directory}")
|
||||
return loaded_count, failed_count
|
||||
|
||||
|
||||
logger.info(f"正在扫描插件目录: {directory}")
|
||||
|
||||
|
||||
# 遍历目录中的所有Python文件和包
|
||||
for item in os.listdir(directory):
|
||||
item_path = os.path.join(directory, item)
|
||||
|
||||
if os.path.isfile(item_path) and item.endswith('.py') and item != '__init__.py':
|
||||
|
||||
if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py":
|
||||
# 单文件插件
|
||||
if self._load_plugin_module_file(item_path):
|
||||
loaded_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
elif os.path.isdir(item_path) and not item.startswith('.') and not item.startswith('__'):
|
||||
|
||||
elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
||||
# 插件包
|
||||
plugin_file = os.path.join(item_path, 'plugin.py')
|
||||
plugin_file = os.path.join(item_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
loaded_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
|
||||
return loaded_count, failed_count
|
||||
|
||||
|
||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||
"""加载单个插件模块文件"""
|
||||
plugin_name = None
|
||||
|
||||
|
||||
# 生成模块名
|
||||
plugin_path = Path(plugin_file)
|
||||
if plugin_path.parent.name != 'plugins':
|
||||
if plugin_path.parent.name != "plugins":
|
||||
# 插件包格式:parent_dir.plugin
|
||||
module_name = f"plugins.{plugin_path.parent.name}.plugin"
|
||||
else:
|
||||
# 单文件格式:plugins.filename
|
||||
module_name = f"plugins.{plugin_path.stem}"
|
||||
|
||||
|
||||
try:
|
||||
# 动态导入插件模块
|
||||
spec = importlib.util.spec_from_file_location(module_name, plugin_file)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error(f"无法创建模块规范: {plugin_file}")
|
||||
return False
|
||||
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
|
||||
# 模块加载成功,插件类会自动通过装饰器注册
|
||||
plugin_name = plugin_path.parent.name if plugin_path.parent.name != 'plugins' else plugin_path.stem
|
||||
|
||||
plugin_name = plugin_path.parent.name if plugin_path.parent.name != "plugins" else plugin_path.stem
|
||||
|
||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||
logger.error(error_msg)
|
||||
if plugin_name:
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
return False
|
||||
|
||||
|
||||
def get_loaded_plugins(self) -> List[PluginInfo]:
|
||||
"""获取所有已加载的插件信息"""
|
||||
return list(component_registry.get_all_plugins().values())
|
||||
|
||||
|
||||
def get_enabled_plugins(self) -> List[PluginInfo]:
|
||||
"""获取所有启用的插件信息"""
|
||||
return list(component_registry.get_enabled_plugins().values())
|
||||
|
||||
|
||||
def enable_plugin(self, plugin_name: str) -> bool:
|
||||
"""启用插件"""
|
||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||
@@ -175,7 +177,7 @@ class PluginManager:
|
||||
logger.info(f"已启用插件: {plugin_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def disable_plugin(self, plugin_name: str) -> bool:
|
||||
"""禁用插件"""
|
||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||
@@ -187,15 +189,15 @@ class PluginManager:
|
||||
logger.info(f"已禁用插件: {plugin_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_plugin_stats(self) -> Dict[str, Any]:
|
||||
"""获取插件统计信息"""
|
||||
all_plugins = component_registry.get_all_plugins()
|
||||
enabled_plugins = component_registry.get_enabled_plugins()
|
||||
|
||||
|
||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
|
||||
|
||||
|
||||
return {
|
||||
"total_plugins": len(all_plugins),
|
||||
"enabled_plugins": len(enabled_plugins),
|
||||
@@ -204,9 +206,9 @@ class PluginManager:
|
||||
"action_components": len(action_components),
|
||||
"command_components": len(command_components),
|
||||
"loaded_plugin_files": len(self.loaded_plugins),
|
||||
"failed_plugin_details": self.failed_plugins.copy()
|
||||
"failed_plugin_details": self.failed_plugins.copy(),
|
||||
}
|
||||
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重新加载插件(高级功能,需要谨慎使用)"""
|
||||
# TODO: 实现插件热重载功能
|
||||
@@ -219,5 +221,5 @@ plugin_manager = PluginManager()
|
||||
|
||||
# 默认插件目录
|
||||
plugin_manager.add_plugin_directory("src/plugins/built_in")
|
||||
plugin_manager.add_plugin_directory("src/plugins/examples")
|
||||
plugin_manager.add_plugin_directory("plugins") # 用户插件目录
|
||||
plugin_manager.add_plugin_directory("src/plugins/examples")
|
||||
plugin_manager.add_plugin_directory("plugins") # 用户插件目录
|
||||
|
||||
Reference in New Issue
Block a user