re-style: 格式化代码
This commit is contained in:
@@ -4,14 +4,14 @@
|
||||
提供插件的加载、注册和管理功能
|
||||
"""
|
||||
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"event_manager",
|
||||
"global_announcement_manager",
|
||||
"plugin_manager",
|
||||
]
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
ToolInfo,
|
||||
CommandInfo,
|
||||
PlusCommandInfo,
|
||||
EventHandlerInfo,
|
||||
ChatterInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
ChatterInfo,
|
||||
CommandInfo,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
@@ -34,46 +33,46 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._components: Dict[str, "ComponentInfo"] = {}
|
||||
self._components: dict[str, "ComponentInfo"] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = {
|
||||
self._components_by_type: dict["ComponentType", dict[str, "ComponentInfo"]] = {
|
||||
types: {} for types in ComponentType
|
||||
}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[
|
||||
str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]]
|
||||
self._components_classes: dict[
|
||||
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"]
|
||||
] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, "PluginInfo"] = {}
|
||||
self._plugins: dict[str, "PluginInfo"] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type["BaseAction"]] = {}
|
||||
self._action_registry: dict[str, type["BaseAction"]] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, "ActionInfo"] = {}
|
||||
self._default_actions: dict[str, "ActionInfo"] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type["BaseCommand"]] = {}
|
||||
self._command_registry: dict[str, type["BaseCommand"]] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
self._command_patterns: dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {}
|
||||
self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {}
|
||||
self._enabled_event_handlers: dict[str, type["BaseEventHandler"]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
self._chatter_registry: Dict[str, Type["BaseChatter"]] = {}
|
||||
self._chatter_registry: dict[str, type["BaseChatter"]] = {}
|
||||
"""chatter名 -> chatter类"""
|
||||
self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {}
|
||||
self._enabled_chatter_registry: dict[str, type["BaseChatter"]] = {}
|
||||
"""启用的chatter名 -> chatter类"""
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
@@ -101,7 +100,7 @@ class ComponentRegistry:
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
|
||||
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
@@ -174,7 +173,7 @@ class ComponentRegistry:
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool:
|
||||
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
@@ -194,7 +193,7 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool:
|
||||
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
@@ -221,7 +220,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"]
|
||||
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"]
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
@@ -235,7 +234,7 @@ class ComponentRegistry:
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {}
|
||||
self._plus_command_registry: dict[str, type["PlusCommand"]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
# 设置插件配置
|
||||
@@ -245,7 +244,7 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool:
|
||||
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
@@ -261,7 +260,7 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"]
|
||||
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
@@ -287,7 +286,7 @@ class ComponentRegistry:
|
||||
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
|
||||
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool:
|
||||
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool:
|
||||
"""注册Chatter组件到Chatter特定注册表"""
|
||||
chatter_name = chatter_info.name
|
||||
|
||||
@@ -532,7 +531,7 @@ class ComponentRegistry:
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional["ComponentType"] = None,
|
||||
) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]:
|
||||
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -574,18 +573,18 @@ class ComponentRegistry:
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
|
||||
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
|
||||
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type["BaseAction"]]:
|
||||
def get_action_registry(self) -> dict[str, type["BaseAction"]]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
@@ -594,13 +593,13 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
def get_default_actions(self) -> dict[str, ActionInfo]:
|
||||
"""获取默认动作集"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]:
|
||||
def get_command_registry(self) -> dict[str, type["BaseCommand"]]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
@@ -609,11 +608,11 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, str]:
|
||||
def get_command_patterns(self) -> dict[Pattern, str]:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]:
|
||||
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -640,11 +639,11 @@ class ComponentRegistry:
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]:
|
||||
def get_tool_registry(self) -> dict[str, type["BaseTool"]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]:
|
||||
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
@@ -661,10 +660,10 @@ class ComponentRegistry:
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]:
|
||||
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
self._plus_command_registry: dict[str, type[PlusCommand]] = {}
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]:
|
||||
@@ -681,7 +680,7 @@ class ComponentRegistry:
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]:
|
||||
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
@@ -690,21 +689,21 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]:
|
||||
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === Chatter 特定查询方法 ===
|
||||
def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
|
||||
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
|
||||
"""获取Chatter注册表"""
|
||||
if not hasattr(self, "_chatter_registry"):
|
||||
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
self._chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
return self._chatter_registry.copy()
|
||||
|
||||
def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
|
||||
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
|
||||
"""获取启用的Chatter注册表"""
|
||||
if not hasattr(self, "_enabled_chatter_registry"):
|
||||
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
return self._enabled_chatter_registry.copy()
|
||||
|
||||
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]:
|
||||
@@ -718,7 +717,7 @@ class ComponentRegistry:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, "PluginInfo"]:
|
||||
def get_all_plugins(self) -> dict[str, "PluginInfo"]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
@@ -726,7 +725,7 @@ class ComponentRegistry:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]:
|
||||
def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
@@ -753,7 +752,7 @@ class ComponentRegistry:
|
||||
|
||||
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
||||
return config_data
|
||||
@@ -762,7 +761,7 @@ class ComponentRegistry:
|
||||
|
||||
return {}
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
def get_registry_stats(self) -> dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
action_components: int = 0
|
||||
command_components: int = 0
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
提供统一的事件注册、管理和触发接口
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, List, Optional, Any, Union
|
||||
from threading import Lock
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler
|
||||
@@ -37,17 +37,17 @@ class EventManager:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._events: Dict[str, BaseEvent] = {}
|
||||
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
|
||||
self._events: dict[str, BaseEvent] = {}
|
||||
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
|
||||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
def register_event(
|
||||
self,
|
||||
event_name: Union[EventType, str],
|
||||
allowed_subscribers: List[str] = None,
|
||||
allowed_triggers: List[str] = None,
|
||||
event_name: EventType | str,
|
||||
allowed_subscribers: list[str] = None,
|
||||
allowed_triggers: list[str] = None,
|
||||
) -> bool:
|
||||
"""注册一个新的事件
|
||||
|
||||
@@ -75,7 +75,7 @@ class EventManager:
|
||||
|
||||
return True
|
||||
|
||||
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
|
||||
def get_event(self, event_name: EventType | str) -> BaseEvent | None:
|
||||
"""获取指定事件实例
|
||||
|
||||
Args:
|
||||
@@ -86,7 +86,7 @@ class EventManager:
|
||||
"""
|
||||
return self._events.get(event_name)
|
||||
|
||||
def get_all_events(self) -> Dict[str, BaseEvent]:
|
||||
def get_all_events(self) -> dict[str, BaseEvent]:
|
||||
"""获取所有已注册的事件
|
||||
|
||||
Returns:
|
||||
@@ -94,7 +94,7 @@ class EventManager:
|
||||
"""
|
||||
return self._events.copy()
|
||||
|
||||
def get_enabled_events(self) -> Dict[str, BaseEvent]:
|
||||
def get_enabled_events(self) -> dict[str, BaseEvent]:
|
||||
"""获取所有已启用的事件
|
||||
|
||||
Returns:
|
||||
@@ -102,7 +102,7 @@ class EventManager:
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if event.enabled}
|
||||
|
||||
def get_disabled_events(self) -> Dict[str, BaseEvent]:
|
||||
def get_disabled_events(self) -> dict[str, BaseEvent]:
|
||||
"""获取所有已禁用的事件
|
||||
|
||||
Returns:
|
||||
@@ -110,7 +110,7 @@ class EventManager:
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if not event.enabled}
|
||||
|
||||
def enable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
def enable_event(self, event_name: EventType | str) -> bool:
|
||||
"""启用指定事件
|
||||
|
||||
Args:
|
||||
@@ -128,7 +128,7 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已启用")
|
||||
return True
|
||||
|
||||
def disable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
def disable_event(self, event_name: EventType | str) -> bool:
|
||||
"""禁用指定事件
|
||||
|
||||
Args:
|
||||
@@ -146,9 +146,7 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
def register_event_handler(
|
||||
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
|
||||
) -> bool:
|
||||
def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
@@ -190,7 +188,7 @@ class EventManager:
|
||||
logger.info(f"事件处理器 {handler_name} 注册成功")
|
||||
return True
|
||||
|
||||
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
|
||||
def get_event_handler(self, handler_name: str) -> type[BaseEventHandler] | None:
|
||||
"""获取指定事件处理器实例
|
||||
|
||||
Args:
|
||||
@@ -209,7 +207,7 @@ class EventManager:
|
||||
"""
|
||||
return self._event_handlers.copy()
|
||||
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: EventType | str) -> bool:
|
||||
"""订阅事件处理器到指定事件
|
||||
|
||||
Args:
|
||||
@@ -246,7 +244,7 @@ class EventManager:
|
||||
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
|
||||
return True
|
||||
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool:
|
||||
"""从指定事件取消订阅事件处理器
|
||||
|
||||
Args:
|
||||
@@ -276,7 +274,7 @@ class EventManager:
|
||||
|
||||
return removed
|
||||
|
||||
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
|
||||
def get_event_subscribers(self, event_name: EventType | str) -> dict[str, BaseEventHandler]:
|
||||
"""获取订阅指定事件的所有事件处理器
|
||||
|
||||
Args:
|
||||
@@ -292,8 +290,8 @@ class EventManager:
|
||||
return {handler.handler_name: handler for handler in event.subscribers}
|
||||
|
||||
async def trigger_event(
|
||||
self, event_name: Union[EventType, str], permission_group: Optional[str] = "", **kwargs
|
||||
) -> Optional[HandlerResultsCollection]:
|
||||
self, event_name: EventType | str, permission_group: str | None = "", **kwargs
|
||||
) -> HandlerResultsCollection | None:
|
||||
"""触发指定事件
|
||||
|
||||
Args:
|
||||
@@ -345,7 +343,7 @@ class EventManager:
|
||||
self._event_handlers.clear()
|
||||
logger.info("所有事件和处理器已清除")
|
||||
|
||||
def get_event_summary(self) -> Dict[str, Any]:
|
||||
def get_event_summary(self) -> dict[str, Any]:
|
||||
"""获取事件系统摘要
|
||||
|
||||
Returns:
|
||||
@@ -364,7 +362,7 @@ class EventManager:
|
||||
"pending_subscriptions": len(self._pending_subscriptions),
|
||||
}
|
||||
|
||||
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
|
||||
def _process_pending_subscriptions(self, event_name: EventType | str) -> None:
|
||||
"""处理指定事件的缓存订阅
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
@@ -8,13 +6,13 @@ logger = get_logger("global_announcement_manager")
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
self._user_disabled_actions: dict[str, list[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
self._user_disabled_commands: dict[str, list[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
self._user_disabled_event_handlers: dict[str, list[str]] = {}
|
||||
# 用户禁用的工具,chat_id -> [tool_name]
|
||||
self._user_disabled_tools: Dict[str, List[str]] = {}
|
||||
self._user_disabled_tools: dict[str, list[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
@@ -100,19 +98,19 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> list[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
|
||||
@@ -4,16 +4,16 @@
|
||||
这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。
|
||||
"""
|
||||
|
||||
from typing import List, Set, Tuple
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, delete
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions
|
||||
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -24,7 +24,7 @@ class PermissionManager(IPermissionManager):
|
||||
def __init__(self):
|
||||
self.engine = None
|
||||
self.SessionLocal = None
|
||||
self._master_users: Set[Tuple[str, str]] = set()
|
||||
self._master_users: set[tuple[str, str]] = set()
|
||||
self._load_master_users()
|
||||
|
||||
async def initialize(self):
|
||||
@@ -276,7 +276,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"撤销权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||
async def get_user_permissions(self, user: UserInfo) -> list[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
@@ -328,7 +328,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"获取用户权限时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||
async def get_all_permission_nodes(self) -> list[PermissionNode]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
@@ -356,7 +356,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"获取所有权限节点时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
@@ -431,7 +431,7 @@ class PermissionManager(IPermissionManager):
|
||||
logger.error(f"删除插件权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
|
||||
async def get_users_with_permission(self, permission_node: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
获取拥有指定权限的所有用户
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import traceback
|
||||
import importlib
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
|
||||
from .component_registry import component_registry
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
@@ -26,12 +24,12 @@ class PluginManager:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_directories: List[str] = [] # 插件根目录列表
|
||||
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
self.plugin_directories: list[str] = [] # 插件根目录列表
|
||||
self.plugin_classes: dict[str, type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||
self.plugin_paths: dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
|
||||
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
|
||||
# 确保插件目录存在
|
||||
self._ensure_plugin_directories()
|
||||
@@ -54,7 +52,7 @@ class PluginManager:
|
||||
|
||||
# === 插件加载管理 ===
|
||||
|
||||
def load_all_plugins(self) -> Tuple[int, int]:
|
||||
def load_all_plugins(self) -> tuple[int, int]:
|
||||
"""加载所有插件
|
||||
|
||||
Returns:
|
||||
@@ -87,7 +85,7 @@ class PluginManager:
|
||||
|
||||
return total_registered, total_failed_registration
|
||||
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> tuple[bool, int]:
|
||||
# sourcery skip: extract-duplicate-method, extract-method
|
||||
"""
|
||||
加载已经注册的插件类
|
||||
@@ -142,7 +140,7 @@ class PluginManager:
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# manifest文件缺失
|
||||
error_msg = f"缺少manifest文件: {str(e)}"
|
||||
error_msg = f"缺少manifest文件: {e!s}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
@@ -150,14 +148,14 @@ class PluginManager:
|
||||
except ValueError as e:
|
||||
# manifest文件格式错误或验证失败
|
||||
traceback.print_exc()
|
||||
error_msg = f"manifest验证失败: {str(e)}"
|
||||
error_msg = f"manifest验证失败: {e!s}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
error_msg = f"未知错误: {str(e)}"
|
||||
error_msg = f"未知错误: {e!s}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
@@ -192,7 +190,7 @@ class PluginManager:
|
||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||
return True
|
||||
|
||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||
def rescan_plugin_directory(self) -> tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件根目录
|
||||
"""
|
||||
@@ -220,7 +218,7 @@ class PluginManager:
|
||||
return self.loaded_plugins.get(plugin_name)
|
||||
|
||||
# === 查询方法 ===
|
||||
def list_loaded_plugins(self) -> List[str]:
|
||||
def list_loaded_plugins(self) -> list[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
@@ -229,7 +227,7 @@ class PluginManager:
|
||||
"""
|
||||
return list(self.loaded_plugins.keys())
|
||||
|
||||
def list_registered_plugins(self) -> List[str]:
|
||||
def list_registered_plugins(self) -> list[str]:
|
||||
"""
|
||||
列出所有已注册的插件类。
|
||||
|
||||
@@ -238,7 +236,7 @@ class PluginManager:
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
def get_plugin_path(self, plugin_name: str) -> str | None:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
@@ -329,7 +327,7 @@ class PluginManager:
|
||||
# == 兼容性检查 ==
|
||||
|
||||
@staticmethod
|
||||
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: dict[str, Any]) -> tuple[bool, str]:
|
||||
"""检查插件版本兼容性
|
||||
|
||||
Args:
|
||||
@@ -569,7 +567,7 @@ class PluginManager:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}", exc_info=True)
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
@@ -606,7 +604,7 @@ class PluginManager:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}", exc_info=True)
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
|
||||
return False
|
||||
|
||||
def force_reload_plugin(self, plugin_name: str) -> bool:
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import inspect
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
import inspect
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
@@ -56,14 +57,14 @@ class ToolExecutor:
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
# 二步工具调用状态管理
|
||||
self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {}
|
||||
self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
|
||||
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
) -> tuple[list[dict[str, Any]], list[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
@@ -113,7 +114,7 @@ class ToolExecutor:
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
def _get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
|
||||
@@ -129,7 +130,7 @@ class ToolExecutor:
|
||||
|
||||
return tool_definitions
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
@@ -138,7 +139,7 @@ class ToolExecutor:
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
@@ -192,7 +193,7 @@ class ToolExecutor:
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"content": f"工具{tool_name}执行失败: {e!s}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
@@ -201,8 +202,8 @@ class ToolExecutor:
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
function_args = tool_call.args or {}
|
||||
@@ -256,8 +257,8 @@ class ToolExecutor:
|
||||
return result
|
||||
|
||||
async def _original_execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""执行单个工具调用的原始逻辑"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
@@ -323,10 +324,10 @@ class ToolExecutor:
|
||||
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
logger.error(f"执行工具调用时发生错误: {e!s}")
|
||||
raise e
|
||||
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: dict) -> dict | None:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user