re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -4,14 +4,14 @@
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_manager import plugin_manager
__all__ = [
"plugin_manager",
"component_registry",
"event_manager",
"global_announcement_manager",
"plugin_manager",
]

View File

@@ -1,27 +1,26 @@
from pathlib import Path
import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from pathlib import Path
from re import Pattern
from typing import Any, Optional, Union
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
PlusCommandInfo,
EventHandlerInfo,
ChatterInfo,
PluginInfo,
ComponentType,
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import (
ActionInfo,
ChatterInfo,
CommandInfo,
ComponentInfo,
ComponentType,
EventHandlerInfo,
PluginInfo,
PlusCommandInfo,
ToolInfo,
)
from src.plugin_system.base.plus_command import PlusCommand
logger = get_logger("component_registry")
@@ -34,46 +33,46 @@ class ComponentRegistry:
def __init__(self):
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
self._components: Dict[str, "ComponentInfo"] = {}
self._components: dict[str, "ComponentInfo"] = {}
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict["ComponentType", Dict[str, "ComponentInfo"]] = {
self._components_by_type: dict["ComponentType", dict[str, "ComponentInfo"]] = {
types: {} for types in ComponentType
}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[
str, Type[Union["BaseCommand", "BaseAction", "BaseTool", "BaseEventHandler", "PlusCommand", "BaseChatter"]]
self._components_classes: dict[
str, type["BaseCommand" | "BaseAction" | "BaseTool" | "BaseEventHandler" | "PlusCommand" | "BaseChatter"]
] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
self._plugins: Dict[str, "PluginInfo"] = {}
self._plugins: dict[str, "PluginInfo"] = {}
"""插件名 -> 插件信息"""
# Action特定注册表
self._action_registry: Dict[str, Type["BaseAction"]] = {}
self._action_registry: dict[str, type["BaseAction"]] = {}
"""Action注册表 action名 -> action类"""
self._default_actions: Dict[str, "ActionInfo"] = {}
self._default_actions: dict[str, "ActionInfo"] = {}
"""默认动作集即启用的Action集用于重置ActionManager状态"""
# Command特定注册表
self._command_registry: Dict[str, Type["BaseCommand"]] = {}
self._command_registry: dict[str, type["BaseCommand"]] = {}
"""Command类注册表 command名 -> command类"""
self._command_patterns: Dict[Pattern, str] = {}
self._command_patterns: dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type["BaseTool"]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type["BaseEventHandler"]] = {}
self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {}
"""event_handler名 -> event_handler类"""
self._enabled_event_handlers: Dict[str, Type["BaseEventHandler"]] = {}
self._enabled_event_handlers: dict[str, type["BaseEventHandler"]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类"""
self._chatter_registry: Dict[str, Type["BaseChatter"]] = {}
self._chatter_registry: dict[str, type["BaseChatter"]] = {}
"""chatter名 -> chatter类"""
self._enabled_chatter_registry: Dict[str, Type["BaseChatter"]] = {}
self._enabled_chatter_registry: dict[str, type["BaseChatter"]] = {}
"""启用的chatter名 -> chatter类"""
logger.info("组件注册中心初始化完成")
@@ -101,7 +100,7 @@ class ComponentRegistry:
def register_component(
self,
component_info: ComponentInfo,
component_class: Type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
component_class: type[Union["BaseCommand", "BaseAction", "BaseEventHandler", "BaseTool", "BaseChatter"]],
) -> bool:
"""注册组件
@@ -174,7 +173,7 @@ class ComponentRegistry:
)
return True
def _register_action_component(self, action_info: "ActionInfo", action_class: Type["BaseAction"]) -> bool:
def _register_action_component(self, action_info: "ActionInfo", action_class: type["BaseAction"]) -> bool:
"""注册Action组件到Action特定注册表"""
if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
@@ -194,7 +193,7 @@ class ComponentRegistry:
return True
def _register_command_component(self, command_info: "CommandInfo", command_class: Type["BaseCommand"]) -> bool:
def _register_command_component(self, command_info: "CommandInfo", command_class: type["BaseCommand"]) -> bool:
"""注册Command组件到Command特定注册表"""
if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
@@ -221,7 +220,7 @@ class ComponentRegistry:
return True
def _register_plus_command_component(
self, plus_command_info: "PlusCommandInfo", plus_command_class: Type["PlusCommand"]
self, plus_command_info: "PlusCommandInfo", plus_command_class: type["PlusCommand"]
) -> bool:
"""注册PlusCommand组件到特定注册表"""
plus_command_name = plus_command_info.name
@@ -235,7 +234,7 @@ class ComponentRegistry:
# 创建专门的PlusCommand注册表如果还没有
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type["PlusCommand"]] = {}
self._plus_command_registry: dict[str, type["PlusCommand"]] = {}
plus_command_class.plugin_name = plus_command_info.plugin_name
# 设置插件配置
@@ -245,7 +244,7 @@ class ComponentRegistry:
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
return True
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: Type["BaseTool"]) -> bool:
def _register_tool_component(self, tool_info: "ToolInfo", tool_class: type["BaseTool"]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
@@ -261,7 +260,7 @@ class ComponentRegistry:
return True
def _register_event_handler_component(
self, handler_info: "EventHandlerInfo", handler_class: Type["BaseEventHandler"]
self, handler_info: "EventHandlerInfo", handler_class: type["BaseEventHandler"]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
@@ -287,7 +286,7 @@ class ComponentRegistry:
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
)
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: Type["BaseChatter"]) -> bool:
def _register_chatter_component(self, chatter_info: "ChatterInfo", chatter_class: type["BaseChatter"]) -> bool:
"""注册Chatter组件到Chatter特定注册表"""
chatter_name = chatter_info.name
@@ -532,7 +531,7 @@ class ComponentRegistry:
self,
component_name: str,
component_type: Optional["ComponentType"] = None,
) -> Optional[Union[Type["BaseCommand"], Type["BaseAction"], Type["BaseEventHandler"], Type["BaseTool"]]]:
) -> type["BaseCommand"] | type["BaseAction"] | type["BaseEventHandler"] | type["BaseTool"] | None:
"""获取组件类,支持自动命名空间解析
Args:
@@ -574,18 +573,18 @@ class ComponentRegistry:
# 4. 都没找到
return None
def get_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
def get_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: "ComponentType") -> Dict[str, "ComponentInfo"]:
def get_enabled_components_by_type(self, component_type: "ComponentType") -> dict[str, "ComponentInfo"]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type["BaseAction"]]:
def get_action_registry(self) -> dict[str, type["BaseAction"]]:
"""获取Action注册表"""
return self._action_registry.copy()
@@ -594,13 +593,13 @@ class ComponentRegistry:
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
def get_default_actions(self) -> dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type["BaseCommand"]]:
def get_command_registry(self) -> dict[str, type["BaseCommand"]]:
"""获取Command注册表"""
return self._command_registry.copy()
@@ -609,11 +608,11 @@ class ComponentRegistry:
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
def get_command_patterns(self) -> dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type["BaseCommand"], dict, "CommandInfo"]]:
def find_command_by_text(self, text: str) -> tuple[type["BaseCommand"], dict, "CommandInfo"] | None:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
@@ -640,11 +639,11 @@ class ComponentRegistry:
return None
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type["BaseTool"]]:
def get_tool_registry(self) -> dict[str, type["BaseTool"]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type["BaseTool"]]:
def get_llm_available_tools(self) -> dict[str, type["BaseTool"]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
@@ -661,10 +660,10 @@ class ComponentRegistry:
return info if isinstance(info, ToolInfo) else None
# === PlusCommand 特定查询方法 ===
def get_plus_command_registry(self) -> Dict[str, Type["PlusCommand"]]:
def get_plus_command_registry(self) -> dict[str, type["PlusCommand"]]:
"""获取PlusCommand注册表"""
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
self._plus_command_registry: dict[str, type[PlusCommand]] = {}
return self._plus_command_registry.copy()
def get_registered_plus_command_info(self, command_name: str) -> Optional["PlusCommandInfo"]:
@@ -681,7 +680,7 @@ class ComponentRegistry:
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type["BaseEventHandler"]]:
def get_event_handler_registry(self) -> dict[str, type["BaseEventHandler"]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
@@ -690,21 +689,21 @@ class ComponentRegistry:
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> Dict[str, Type["BaseEventHandler"]]:
def get_enabled_event_handlers(self) -> dict[str, type["BaseEventHandler"]]:
"""获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
# === Chatter 特定查询方法 ===
def get_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
def get_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
"""获取Chatter注册表"""
if not hasattr(self, "_chatter_registry"):
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
self._chatter_registry: dict[str, type[BaseChatter]] = {}
return self._chatter_registry.copy()
def get_enabled_chatter_registry(self) -> Dict[str, Type["BaseChatter"]]:
def get_enabled_chatter_registry(self) -> dict[str, type["BaseChatter"]]:
"""获取启用的Chatter注册表"""
if not hasattr(self, "_enabled_chatter_registry"):
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
return self._enabled_chatter_registry.copy()
def get_registered_chatter_info(self, chatter_name: str) -> Optional["ChatterInfo"]:
@@ -718,7 +717,7 @@ class ComponentRegistry:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, "PluginInfo"]:
def get_all_plugins(self) -> dict[str, "PluginInfo"]:
"""获取所有插件"""
return self._plugins.copy()
@@ -726,7 +725,7 @@ class ComponentRegistry:
# """获取所有启用的插件"""
# return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List["ComponentInfo"]:
def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
@@ -753,7 +752,7 @@ class ComponentRegistry:
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
config_data = toml.load(f)
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
return config_data
@@ -762,7 +761,7 @@ class ComponentRegistry:
return {}
def get_registry_stats(self) -> Dict[str, Any]:
def get_registry_stats(self) -> dict[str, Any]:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0

View File

@@ -3,8 +3,8 @@
提供统一的事件注册、管理和触发接口
"""
from typing import Dict, Type, List, Optional, Any, Union
from threading import Lock
from typing import Any, Optional
from src.common.logger import get_logger
from src.plugin_system import BaseEventHandler
@@ -37,17 +37,17 @@ class EventManager:
if self._initialized:
return
self._events: Dict[str, BaseEvent] = {}
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
self._events: dict[str, BaseEvent] = {}
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
self._initialized = True
logger.info("EventManager 单例初始化完成")
def register_event(
self,
event_name: Union[EventType, str],
allowed_subscribers: List[str] = None,
allowed_triggers: List[str] = None,
event_name: EventType | str,
allowed_subscribers: list[str] = None,
allowed_triggers: list[str] = None,
) -> bool:
"""注册一个新的事件
@@ -75,7 +75,7 @@ class EventManager:
return True
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
def get_event(self, event_name: EventType | str) -> BaseEvent | None:
"""获取指定事件实例
Args:
@@ -86,7 +86,7 @@ class EventManager:
"""
return self._events.get(event_name)
def get_all_events(self) -> Dict[str, BaseEvent]:
def get_all_events(self) -> dict[str, BaseEvent]:
"""获取所有已注册的事件
Returns:
@@ -94,7 +94,7 @@ class EventManager:
"""
return self._events.copy()
def get_enabled_events(self) -> Dict[str, BaseEvent]:
def get_enabled_events(self) -> dict[str, BaseEvent]:
"""获取所有已启用的事件
Returns:
@@ -102,7 +102,7 @@ class EventManager:
"""
return {name: event for name, event in self._events.items() if event.enabled}
def get_disabled_events(self) -> Dict[str, BaseEvent]:
def get_disabled_events(self) -> dict[str, BaseEvent]:
"""获取所有已禁用的事件
Returns:
@@ -110,7 +110,7 @@ class EventManager:
"""
return {name: event for name, event in self._events.items() if not event.enabled}
def enable_event(self, event_name: Union[EventType, str]) -> bool:
def enable_event(self, event_name: EventType | str) -> bool:
"""启用指定事件
Args:
@@ -128,7 +128,7 @@ class EventManager:
logger.info(f"事件 {event_name} 已启用")
return True
def disable_event(self, event_name: Union[EventType, str]) -> bool:
def disable_event(self, event_name: EventType | str) -> bool:
"""禁用指定事件
Args:
@@ -146,9 +146,7 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
) -> bool:
def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool:
"""注册事件处理器
Args:
@@ -190,7 +188,7 @@ class EventManager:
logger.info(f"事件处理器 {handler_name} 注册成功")
return True
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
def get_event_handler(self, handler_name: str) -> type[BaseEventHandler] | None:
"""获取指定事件处理器实例
Args:
@@ -209,7 +207,7 @@ class EventManager:
"""
return self._event_handlers.copy()
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
def subscribe_handler_to_event(self, handler_name: str, event_name: EventType | str) -> bool:
"""订阅事件处理器到指定事件
Args:
@@ -246,7 +244,7 @@ class EventManager:
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
return True
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool:
"""从指定事件取消订阅事件处理器
Args:
@@ -276,7 +274,7 @@ class EventManager:
return removed
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
def get_event_subscribers(self, event_name: EventType | str) -> dict[str, BaseEventHandler]:
"""获取订阅指定事件的所有事件处理器
Args:
@@ -292,8 +290,8 @@ class EventManager:
return {handler.handler_name: handler for handler in event.subscribers}
async def trigger_event(
self, event_name: Union[EventType, str], permission_group: Optional[str] = "", **kwargs
) -> Optional[HandlerResultsCollection]:
self, event_name: EventType | str, permission_group: str | None = "", **kwargs
) -> HandlerResultsCollection | None:
"""触发指定事件
Args:
@@ -345,7 +343,7 @@ class EventManager:
self._event_handlers.clear()
logger.info("所有事件和处理器已清除")
def get_event_summary(self) -> Dict[str, Any]:
def get_event_summary(self) -> dict[str, Any]:
"""获取事件系统摘要
Returns:
@@ -364,7 +362,7 @@ class EventManager:
"pending_subscriptions": len(self._pending_subscriptions),
}
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
def _process_pending_subscriptions(self, event_name: EventType | str) -> None:
"""处理指定事件的缓存订阅
Args:

View File

@@ -1,5 +1,3 @@
from typing import List, Dict
from src.common.logger import get_logger
logger = get_logger("global_announcement_manager")
@@ -8,13 +6,13 @@ logger = get_logger("global_announcement_manager")
class GlobalAnnouncementManager:
def __init__(self) -> None:
# 用户禁用的动作chat_id -> [action_name]
self._user_disabled_actions: Dict[str, List[str]] = {}
self._user_disabled_actions: dict[str, list[str]] = {}
# 用户禁用的命令chat_id -> [command_name]
self._user_disabled_commands: Dict[str, List[str]] = {}
self._user_disabled_commands: dict[str, list[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
self._user_disabled_event_handlers: dict[str, list[str]] = {}
# 用户禁用的工具chat_id -> [tool_name]
self._user_disabled_tools: Dict[str, List[str]] = {}
self._user_disabled_tools: dict[str, list[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
@@ -100,19 +98,19 @@ class GlobalAnnouncementManager:
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
def get_disabled_chat_actions(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
def get_disabled_chat_commands(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有命令"""
return self._user_disabled_commands.get(chat_id, []).copy()
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
def get_disabled_chat_event_handlers(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
def get_disabled_chat_tools(self, chat_id: str) -> list[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()

View File

@@ -4,16 +4,16 @@
这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。
"""
from typing import List, Set, Tuple
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from datetime import datetime
from sqlalchemy import select, delete
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker
from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_engine, PermissionNodes, UserPermissions
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
from src.config.config import global_config
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
logger = get_logger(__name__)
@@ -24,7 +24,7 @@ class PermissionManager(IPermissionManager):
def __init__(self):
self.engine = None
self.SessionLocal = None
self._master_users: Set[Tuple[str, str]] = set()
self._master_users: set[tuple[str, str]] = set()
self._load_master_users()
async def initialize(self):
@@ -276,7 +276,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"撤销权限时发生未知错误: {e}")
return False
async def get_user_permissions(self, user: UserInfo) -> List[str]:
async def get_user_permissions(self, user: UserInfo) -> list[str]:
"""
获取用户拥有的所有权限节点
@@ -328,7 +328,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取用户权限时发生未知错误: {e}")
return []
async def get_all_permission_nodes(self) -> List[PermissionNode]:
async def get_all_permission_nodes(self) -> list[PermissionNode]:
"""
获取所有已注册的权限节点
@@ -356,7 +356,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"获取所有权限节点时发生未知错误: {e}")
return []
async def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
async def get_plugin_permission_nodes(self, plugin_name: str) -> list[PermissionNode]:
"""
获取指定插件的所有权限节点
@@ -431,7 +431,7 @@ class PermissionManager(IPermissionManager):
logger.error(f"删除插件权限时发生未知错误: {e}")
return False
async def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
async def get_users_with_permission(self, permission_node: str) -> list[tuple[str, str]]:
"""
获取拥有指定权限的所有用户

View File

@@ -1,19 +1,17 @@
import asyncio
import importlib
import os
import traceback
import importlib
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import Any, Optional
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .component_registry import component_registry
logger = get_logger("plugin_manager")
@@ -26,12 +24,12 @@ class PluginManager:
"""
def __init__(self):
self.plugin_directories: List[str] = [] # 插件根目录列表
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.plugin_directories: list[str] = [] # 插件根目录列表
self.plugin_classes: dict[str, type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
self.loaded_plugins: dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在
self._ensure_plugin_directories()
@@ -54,7 +52,7 @@ class PluginManager:
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
def load_all_plugins(self) -> tuple[int, int]:
"""加载所有插件
Returns:
@@ -87,7 +85,7 @@ class PluginManager:
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
def load_registered_plugin_classes(self, plugin_name: str) -> tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
@@ -142,7 +140,7 @@ class PluginManager:
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
error_msg = f"缺少manifest文件: {e!s}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
@@ -150,14 +148,14 @@ class PluginManager:
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
error_msg = f"manifest验证失败: {e!s}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
error_msg = f"未知错误: {e!s}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
@@ -192,7 +190,7 @@ class PluginManager:
logger.debug(f"插件 {plugin_name} 重载成功")
return True
def rescan_plugin_directory(self) -> Tuple[int, int]:
def rescan_plugin_directory(self) -> tuple[int, int]:
"""
重新扫描插件根目录
"""
@@ -220,7 +218,7 @@ class PluginManager:
return self.loaded_plugins.get(plugin_name)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
def list_loaded_plugins(self) -> list[str]:
"""
列出所有当前加载的插件。
@@ -229,7 +227,7 @@ class PluginManager:
"""
return list(self.loaded_plugins.keys())
def list_registered_plugins(self) -> List[str]:
def list_registered_plugins(self) -> list[str]:
"""
列出所有已注册的插件类。
@@ -238,7 +236,7 @@ class PluginManager:
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
def get_plugin_path(self, plugin_name: str) -> str | None:
"""
获取指定插件的路径。
@@ -329,7 +327,7 @@ class PluginManager:
# == 兼容性检查 ==
@staticmethod
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
def _check_plugin_version_compatibility(plugin_name: str, manifest_data: dict[str, Any]) -> tuple[bool, str]:
"""检查插件版本兼容性
Args:
@@ -569,7 +567,7 @@ class PluginManager:
return True
except Exception as e:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}", exc_info=True)
logger.error(f"❌ 插件卸载失败: {plugin_name} - {e!s}", exc_info=True)
return False
def reload_plugin(self, plugin_name: str) -> bool:
@@ -606,7 +604,7 @@ class PluginManager:
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}", exc_info=True)
logger.error(f"❌ 插件重载失败: {plugin_name} - {e!s}", exc_info=True)
return False
def force_reload_plugin(self, plugin_name: str) -> bool:

View File

@@ -1,16 +1,17 @@
import inspect
import time
from typing import List, Dict, Tuple, Optional, Any
from typing import Any
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.payload_content import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
logger = get_logger("tool_use")
@@ -56,14 +57,14 @@ class ToolExecutor:
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 二步工具调用状态管理
self._pending_step_two_tools: Dict[str, Dict[str, Any]] = {}
self._pending_step_two_tools: dict[str, dict[str, Any]] = {}
"""待处理的第二步工具调用,格式为 {tool_name: step_two_definition}"""
logger.info(f"{self.log_prefix}工具执行器初始化完成")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict[str, Any]], List[str], str]:
) -> tuple[list[dict[str, Any]], list[str], str]:
"""从聊天消息执行工具
Args:
@@ -113,7 +114,7 @@ class ToolExecutor:
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
def _get_tool_definitions(self) -> list[dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
@@ -129,7 +130,7 @@ class ToolExecutor:
return tool_definitions
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
Args:
@@ -138,7 +139,7 @@ class ToolExecutor:
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results: List[Dict[str, Any]] = []
tool_results: list[dict[str, Any]] = []
used_tools = []
if not tool_calls:
@@ -192,7 +193,7 @@ class ToolExecutor:
error_info = {
"type": "tool_error",
"id": f"tool_error_{time.time()}",
"content": f"工具{tool_name}执行失败: {str(e)}",
"content": f"工具{tool_name}执行失败: {e!s}",
"tool_name": tool_name,
"timestamp": time.time(),
}
@@ -201,8 +202,8 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用,并处理缓存"""
function_args = tool_call.args or {}
@@ -256,8 +257,8 @@ class ToolExecutor:
return result
async def _original_execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
self, tool_call: ToolCall, tool_instance: BaseTool | None = None
) -> dict[str, Any] | None:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
@@ -323,10 +324,10 @@ class ToolExecutor:
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
logger.error(f"执行工具调用时发生错误: {e!s}")
raise e
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
async def execute_specific_tool_simple(self, tool_name: str, tool_args: dict) -> dict | None:
"""直接执行指定工具
Args: