修复代码格式和文件名大小写问题
This commit is contained in:
@@ -34,7 +34,9 @@ class ComponentRegistry:
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]] = {}
|
||||
self._components_classes: Dict[
|
||||
str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]
|
||||
] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
@@ -166,7 +168,7 @@ class ComponentRegistry:
|
||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
|
||||
action_class.plugin_name = action_info.plugin_name
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
@@ -200,7 +202,9 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_plus_command_component(self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]) -> bool:
|
||||
def _register_plus_command_component(
|
||||
self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]
|
||||
) -> bool:
|
||||
"""注册PlusCommand组件到特定注册表"""
|
||||
plus_command_name = plus_command_info.name
|
||||
|
||||
@@ -212,7 +216,7 @@ class ComponentRegistry:
|
||||
return False
|
||||
|
||||
# 创建专门的PlusCommand注册表(如果还没有)
|
||||
if not hasattr(self, '_plus_command_registry'):
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
@@ -249,10 +253,11 @@ class ComponentRegistry:
|
||||
if not handler_info.enabled:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
|
||||
handler_class.plugin_name = handler_info.plugin_name
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
return event_manager.register_event_handler(handler_class)
|
||||
|
||||
# === 组件移除相关 ===
|
||||
@@ -281,7 +286,7 @@ class ComponentRegistry:
|
||||
|
||||
case ComponentType.PLUS_COMMAND:
|
||||
# 移除PlusCommand注册
|
||||
if hasattr(self, '_plus_command_registry'):
|
||||
if hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry.pop(component_name, None)
|
||||
logger.debug(f"已移除PlusCommand组件: {component_name}")
|
||||
|
||||
@@ -371,6 +376,7 @@ class ComponentRegistry:
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .event_manager import event_manager # 延迟导入防止循环导入问题
|
||||
|
||||
event_manager.register_event_handler(component_name)
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
@@ -572,7 +578,7 @@ class ComponentRegistry:
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info,
|
||||
)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
@@ -599,7 +605,7 @@ class ComponentRegistry:
|
||||
# === PlusCommand 特定查询方法 ===
|
||||
def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]:
|
||||
"""获取PlusCommand注册表"""
|
||||
if not hasattr(self, '_plus_command_registry'):
|
||||
if not hasattr(self, "_plus_command_registry"):
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
return self._plus_command_registry.copy()
|
||||
|
||||
|
||||
@@ -2,55 +2,57 @@
|
||||
事件管理器 - 实现Event和EventHandler的单例管理
|
||||
提供统一的事件注册、管理和触发接口
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, List, Optional, Any, Union
|
||||
from threading import Lock
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection, HandlerResult
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
logger = get_logger("event_manager")
|
||||
|
||||
|
||||
class EventManager:
|
||||
"""事件管理器单例类
|
||||
|
||||
|
||||
负责管理所有事件和事件处理器的注册、订阅、触发等操作
|
||||
使用单例模式确保全局只有一个事件管理实例
|
||||
"""
|
||||
|
||||
_instance: Optional['EventManager'] = None
|
||||
|
||||
_instance: Optional["EventManager"] = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls) -> 'EventManager':
|
||||
|
||||
def __new__(cls) -> "EventManager":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
|
||||
self._events: Dict[str, BaseEvent] = {}
|
||||
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
|
||||
self._initialized = True
|
||||
logger.info("EventManager 单例初始化完成")
|
||||
|
||||
|
||||
def register_event(
|
||||
self,
|
||||
event_name: Union[EventType, str],
|
||||
allowed_subscribers: List[str]=None,
|
||||
allowed_triggers: List[str]=None
|
||||
) -> bool:
|
||||
self,
|
||||
event_name: Union[EventType, str],
|
||||
allowed_subscribers: List[str] = None,
|
||||
allowed_triggers: List[str] = None,
|
||||
) -> bool:
|
||||
"""注册一个新的事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
allowed_subscribers: List[str]: 事件订阅者白名单,
|
||||
allowed_subscribers: List[str]: 事件订阅者白名单,
|
||||
allowed_triggers: List[str]: 事件触发插件白名单
|
||||
Returns:
|
||||
bool: 注册成功返回True,已存在返回False
|
||||
@@ -62,57 +64,57 @@ class EventManager:
|
||||
if event_name in self._events:
|
||||
logger.warning(f"事件 {event_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
event = BaseEvent(event_name,allowed_subscribers,allowed_triggers)
|
||||
|
||||
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
|
||||
self._events[event_name] = event
|
||||
logger.info(f"事件 {event_name} 注册成功")
|
||||
|
||||
|
||||
# 检查是否有缓存的订阅需要处理
|
||||
self._process_pending_subscriptions(event_name)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
|
||||
"""获取指定事件实例
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
BaseEvent: 事件实例,不存在返回None
|
||||
"""
|
||||
return self._events.get(event_name)
|
||||
|
||||
|
||||
def get_all_events(self) -> Dict[str, BaseEvent]:
|
||||
"""获取所有已注册的事件
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEvent]: 所有事件的字典
|
||||
"""
|
||||
return self._events.copy()
|
||||
|
||||
|
||||
def get_enabled_events(self) -> Dict[str, BaseEvent]:
|
||||
"""获取所有已启用的事件
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEvent]: 已启用事件的字典
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if event.enabled}
|
||||
|
||||
|
||||
def get_disabled_events(self) -> Dict[str, BaseEvent]:
|
||||
"""获取所有已禁用的事件
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEvent]: 已禁用事件的字典
|
||||
"""
|
||||
return {name: event for name, event in self._events.items() if not event.enabled}
|
||||
|
||||
|
||||
def enable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""启用指定事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 成功返回True,事件不存在返回False
|
||||
"""
|
||||
@@ -120,17 +122,17 @@ class EventManager:
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法启用")
|
||||
return False
|
||||
|
||||
|
||||
event.enabled = True
|
||||
logger.info(f"事件 {event_name} 已启用")
|
||||
return True
|
||||
|
||||
|
||||
def disable_event(self, event_name: Union[EventType, str]) -> bool:
|
||||
"""禁用指定事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 成功返回True,事件不存在返回False
|
||||
"""
|
||||
@@ -138,38 +140,38 @@ class EventManager:
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法禁用")
|
||||
return False
|
||||
|
||||
|
||||
event.enabled = False
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
|
||||
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
|
||||
Args:
|
||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 注册成功返回True,已存在返回False
|
||||
"""
|
||||
handler_name = handler_class.handler_name or handler_class.__name__.lower().replace("handler", "")
|
||||
|
||||
|
||||
if EventType.UNKNOWN in handler_class.init_subscribe:
|
||||
logger.error(f"事件处理器 {handler_name} 不能订阅 UNKNOWN 事件")
|
||||
return False
|
||||
if handler_name in self._event_handlers:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
|
||||
self._event_handlers[handler_name] = handler_class()
|
||||
|
||||
|
||||
# 处理init_subscribe,缓存失败的订阅
|
||||
if self._event_handlers[handler_name].init_subscribe:
|
||||
failed_subscriptions = []
|
||||
for event_name in self._event_handlers[handler_name].init_subscribe:
|
||||
if not self.subscribe_handler_to_event(handler_name, event_name):
|
||||
failed_subscriptions.append(event_name)
|
||||
|
||||
|
||||
# 缓存失败的订阅
|
||||
if failed_subscriptions:
|
||||
self._pending_subscriptions[handler_name] = failed_subscriptions
|
||||
@@ -177,33 +179,33 @@ class EventManager:
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 注册成功")
|
||||
return True
|
||||
|
||||
|
||||
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
|
||||
"""获取指定事件处理器实例
|
||||
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
|
||||
|
||||
Returns:
|
||||
Type[BaseEventHandler]: 处理器实例,不存在返回None
|
||||
"""
|
||||
return self._event_handlers.get(handler_name)
|
||||
|
||||
|
||||
def get_all_event_handlers(self) -> Dict[str, BaseEventHandler]:
|
||||
"""获取所有已注册的事件处理器
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Type[BaseEventHandler]]: 所有处理器的字典
|
||||
"""
|
||||
return self._event_handlers.copy()
|
||||
|
||||
|
||||
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
"""订阅事件处理器到指定事件
|
||||
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 订阅成功返回True
|
||||
"""
|
||||
@@ -211,36 +213,36 @@ class EventManager:
|
||||
if handler_instance is None:
|
||||
logger.error(f"事件处理器 {handler_name} 不存在,无法订阅到事件 {event_name}")
|
||||
return False
|
||||
|
||||
|
||||
event = self.get_event(event_name)
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法订阅事件处理器 {handler_name}")
|
||||
return False
|
||||
|
||||
|
||||
if handler_instance in event.subscribers:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经订阅了事件 {event_name},跳过重复订阅")
|
||||
return True
|
||||
|
||||
|
||||
# 白名单检查
|
||||
if event.allowed_subscribers and handler_name not in event.allowed_subscribers:
|
||||
logger.warning(f"事件处理器 {handler_name} 不在事件 {event_name} 的订阅者白名单中,无法订阅")
|
||||
return False
|
||||
|
||||
|
||||
event.subscribers.append(handler_instance)
|
||||
|
||||
|
||||
# 按权重从高到低排序订阅者
|
||||
event.subscribers.sort(key=lambda h: getattr(h, 'weight', 0), reverse=True)
|
||||
|
||||
event.subscribers.sort(key=lambda h: getattr(h, "weight", 0), reverse=True)
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
|
||||
return True
|
||||
|
||||
|
||||
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
|
||||
"""从指定事件取消订阅事件处理器
|
||||
|
||||
|
||||
Args:
|
||||
handler_name (str): 处理器名称
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 取消订阅成功返回True
|
||||
"""
|
||||
@@ -248,55 +250,57 @@ class EventManager:
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法取消订阅")
|
||||
return False
|
||||
|
||||
|
||||
# 查找并移除处理器实例
|
||||
removed = False
|
||||
for subscriber in event.subscribers[:]:
|
||||
if hasattr(subscriber, 'handler_name') and subscriber.handler_name == handler_name:
|
||||
if hasattr(subscriber, "handler_name") and subscriber.handler_name == handler_name:
|
||||
event.subscribers.remove(subscriber)
|
||||
removed = True
|
||||
break
|
||||
|
||||
|
||||
if removed:
|
||||
logger.info(f"事件处理器 {handler_name} 成功从事件 {event_name} 取消订阅")
|
||||
else:
|
||||
logger.warning(f"事件处理器 {handler_name} 未订阅事件 {event_name}")
|
||||
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
|
||||
"""获取订阅指定事件的所有事件处理器
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseEventHandler]: 处理器字典,键为处理器名称,值为处理器实例
|
||||
"""
|
||||
event = self.get_event(event_name)
|
||||
if event is None:
|
||||
return {}
|
||||
|
||||
|
||||
return {handler.handler_name: handler for handler in event.subscribers}
|
||||
|
||||
async def trigger_event(self, event_name: Union[EventType, str], plugin_name: Optional[str]="", **kwargs) -> Optional[HandlerResultsCollection]:
|
||||
|
||||
async def trigger_event(
|
||||
self, event_name: Union[EventType, str], plugin_name: Optional[str] = "", **kwargs
|
||||
) -> Optional[HandlerResultsCollection]:
|
||||
"""触发指定事件
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
plugin_name str: 触发事件的插件名
|
||||
**kwargs: 传递给处理器的参数
|
||||
|
||||
|
||||
Returns:
|
||||
HandlerResultsCollection: 所有处理器的执行结果,事件不存在返回None
|
||||
"""
|
||||
params = kwargs or {}
|
||||
|
||||
|
||||
event = self.get_event(event_name)
|
||||
if event is None:
|
||||
logger.error(f"事件 {event_name} 不存在,无法触发")
|
||||
return None
|
||||
|
||||
|
||||
# 插件白名单检查
|
||||
if event.allowed_triggers and not plugin_name:
|
||||
logger.warning(f"事件 {event_name} 存在触发者白名单,缺少plugin_name无法验证权限,已拒绝触发!")
|
||||
@@ -304,9 +308,9 @@ class EventManager:
|
||||
elif event.allowed_triggers and plugin_name not in event.allowed_triggers:
|
||||
logger.warning(f"插件 {plugin_name} 没有权限触发事件 {event_name},已拒绝触发!")
|
||||
return None
|
||||
|
||||
|
||||
return await event.activate(params)
|
||||
|
||||
|
||||
def init_default_events(self) -> None:
|
||||
"""初始化默认事件"""
|
||||
default_events = [
|
||||
@@ -317,29 +321,29 @@ class EventManager:
|
||||
EventType.POST_LLM,
|
||||
EventType.AFTER_LLM,
|
||||
EventType.POST_SEND,
|
||||
EventType.AFTER_SEND
|
||||
EventType.AFTER_SEND,
|
||||
]
|
||||
|
||||
|
||||
for event_name in default_events:
|
||||
self.register_event(event_name,allowed_triggers=["SYSTEM"])
|
||||
|
||||
self.register_event(event_name, allowed_triggers=["SYSTEM"])
|
||||
|
||||
logger.info("默认事件初始化完成")
|
||||
|
||||
|
||||
def clear_all_events(self) -> None:
|
||||
"""清除所有事件和处理器(主要用于测试)"""
|
||||
self._events.clear()
|
||||
self._event_handlers.clear()
|
||||
logger.info("所有事件和处理器已清除")
|
||||
|
||||
|
||||
def get_event_summary(self) -> Dict[str, Any]:
|
||||
"""获取事件系统摘要
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含事件系统统计信息的字典
|
||||
"""
|
||||
enabled_events = self.get_enabled_events()
|
||||
disabled_events = self.get_disabled_events()
|
||||
|
||||
|
||||
return {
|
||||
"total_events": len(self._events),
|
||||
"enabled_events": len(enabled_events),
|
||||
@@ -347,58 +351,58 @@ class EventManager:
|
||||
"total_handlers": len(self._event_handlers),
|
||||
"event_names": list(self._events.keys()),
|
||||
"handler_names": list(self._event_handlers.keys()),
|
||||
"pending_subscriptions": len(self._pending_subscriptions)
|
||||
"pending_subscriptions": len(self._pending_subscriptions),
|
||||
}
|
||||
|
||||
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
|
||||
"""处理指定事件的缓存订阅
|
||||
|
||||
|
||||
Args:
|
||||
event_name Union[EventType, str]: 事件名称
|
||||
"""
|
||||
handlers_to_remove = []
|
||||
|
||||
|
||||
for handler_name, pending_events in self._pending_subscriptions.items():
|
||||
if event_name in pending_events:
|
||||
if self.subscribe_handler_to_event(handler_name, event_name):
|
||||
pending_events.remove(event_name)
|
||||
logger.info(f"成功处理缓存订阅: {handler_name} -> {event_name}")
|
||||
|
||||
|
||||
# 如果该处理器没有更多待处理订阅,标记为移除
|
||||
if not pending_events:
|
||||
handlers_to_remove.append(handler_name)
|
||||
|
||||
|
||||
# 清理已完成的处理器缓存
|
||||
for handler_name in handlers_to_remove:
|
||||
del self._pending_subscriptions[handler_name]
|
||||
|
||||
def process_all_pending_subscriptions(self) -> int:
|
||||
"""处理所有缓存的订阅
|
||||
|
||||
|
||||
Returns:
|
||||
int: 成功处理的订阅数量
|
||||
"""
|
||||
processed_count = 0
|
||||
|
||||
|
||||
# 复制待处理订阅,避免在迭代时修改字典
|
||||
pending_copy = dict(self._pending_subscriptions)
|
||||
|
||||
|
||||
for handler_name, pending_events in pending_copy.items():
|
||||
for event_name in pending_events[:]: # 使用切片避免修改列表
|
||||
if self.subscribe_handler_to_event(handler_name, event_name):
|
||||
pending_events.remove(event_name)
|
||||
processed_count += 1
|
||||
|
||||
|
||||
# 清理已完成的处理器缓存
|
||||
handlers_to_remove = [name for name, events in self._pending_subscriptions.items() if not events]
|
||||
for handler_name in handlers_to_remove:
|
||||
del self._pending_subscriptions[handler_name]
|
||||
|
||||
|
||||
if processed_count > 0:
|
||||
logger.info(f"批量处理缓存订阅完成,共处理 {processed_count} 个订阅")
|
||||
|
||||
|
||||
return processed_count
|
||||
|
||||
|
||||
# 创建全局事件管理器实例
|
||||
event_manager = EventManager()
|
||||
event_manager = EventManager()
|
||||
|
||||
@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
@@ -19,14 +19,14 @@ logger = get_logger(__name__)
|
||||
|
||||
class PermissionManager(IPermissionManager):
|
||||
"""权限管理器实现类"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.engine = get_engine()
|
||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
||||
self._master_users: Set[Tuple[str, str]] = set()
|
||||
self._load_master_users()
|
||||
logger.info("权限管理器初始化完成")
|
||||
|
||||
|
||||
def _load_master_users(self):
|
||||
"""从配置文件加载Master用户列表"""
|
||||
try:
|
||||
@@ -40,19 +40,19 @@ class PermissionManager(IPermissionManager):
|
||||
except Exception as e:
|
||||
logger.warning(f"加载Master用户配置失败: {e}")
|
||||
self._master_users = set()
|
||||
|
||||
|
||||
def reload_master_users(self):
|
||||
"""重新加载Master用户配置"""
|
||||
self._load_master_users()
|
||||
logger.info("Master用户配置已重新加载")
|
||||
|
||||
|
||||
def is_master(self, user: UserInfo) -> bool:
|
||||
"""
|
||||
检查用户是否为Master用户
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为Master用户
|
||||
"""
|
||||
@@ -61,15 +61,15 @@ class PermissionManager(IPermissionManager):
|
||||
if is_master:
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户")
|
||||
return is_master
|
||||
|
||||
|
||||
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
检查用户是否拥有指定权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有权限
|
||||
"""
|
||||
@@ -78,46 +78,50 @@ class PermissionManager(IPermissionManager):
|
||||
if self.is_master(user):
|
||||
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
|
||||
with self.SessionLocal() as session:
|
||||
# 检查权限节点是否存在
|
||||
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||
if not node:
|
||||
logger.warning(f"权限节点 {permission_node} 不存在")
|
||||
return False
|
||||
|
||||
|
||||
# 检查用户是否有明确的权限设置
|
||||
user_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node
|
||||
).first()
|
||||
|
||||
user_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_perm:
|
||||
# 有明确设置,返回设置的值
|
||||
result = user_perm.granted
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}")
|
||||
logger.debug(
|
||||
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# 没有明确设置,使用默认值
|
||||
result = node.default_granted
|
||||
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}")
|
||||
logger.debug(
|
||||
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"检查权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def register_permission_node(self, node: PermissionNode) -> bool:
|
||||
"""
|
||||
注册权限节点
|
||||
|
||||
|
||||
Args:
|
||||
node: 权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 注册是否成功
|
||||
"""
|
||||
@@ -133,20 +137,20 @@ class PermissionManager(IPermissionManager):
|
||||
session.commit()
|
||||
logger.debug(f"更新权限节点: {node.node_name}")
|
||||
return True
|
||||
|
||||
|
||||
# 创建新节点
|
||||
new_node = PermissionNodes(
|
||||
node_name=node.node_name,
|
||||
description=node.description,
|
||||
plugin_name=node.plugin_name,
|
||||
default_granted=node.default_granted,
|
||||
created_at=datetime.utcnow()
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(new_node)
|
||||
session.commit()
|
||||
logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})")
|
||||
return True
|
||||
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(f"注册权限节点时发生完整性错误: {e}")
|
||||
return False
|
||||
@@ -156,15 +160,15 @@ class PermissionManager(IPermissionManager):
|
||||
except Exception as e:
|
||||
logger.error(f"注册权限节点时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
授权用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 授权是否成功
|
||||
"""
|
||||
@@ -175,14 +179,14 @@ class PermissionManager(IPermissionManager):
|
||||
if not node:
|
||||
logger.error(f"尝试授权不存在的权限节点: {permission_node}")
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否已有权限记录
|
||||
existing_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node
|
||||
).first()
|
||||
|
||||
existing_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_perm:
|
||||
# 更新现有记录
|
||||
existing_perm.granted = True
|
||||
@@ -194,29 +198,29 @@ class PermissionManager(IPermissionManager):
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node,
|
||||
granted=True,
|
||||
granted_at=datetime.utcnow()
|
||||
granted_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(new_perm)
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"授权权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"授权权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
|
||||
"""
|
||||
撤销用户权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 撤销是否成功
|
||||
"""
|
||||
@@ -227,14 +231,14 @@ class PermissionManager(IPermissionManager):
|
||||
if not node:
|
||||
logger.error(f"尝试撤销不存在的权限节点: {permission_node}")
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否已有权限记录
|
||||
existing_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node
|
||||
).first()
|
||||
|
||||
existing_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_perm:
|
||||
# 更新现有记录
|
||||
existing_perm.granted = False
|
||||
@@ -246,28 +250,28 @@ class PermissionManager(IPermissionManager):
|
||||
user_id=user.user_id,
|
||||
permission_node=permission_node,
|
||||
granted=False,
|
||||
granted_at=datetime.utcnow()
|
||||
granted_at=datetime.utcnow(),
|
||||
)
|
||||
session.add(new_perm)
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"撤销权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"撤销权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_user_permissions(self, user: UserInfo) -> List[str]:
|
||||
"""
|
||||
获取用户拥有的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
user: 用户信息
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 权限节点列表
|
||||
"""
|
||||
@@ -277,21 +281,21 @@ class PermissionManager(IPermissionManager):
|
||||
with self.SessionLocal() as session:
|
||||
all_nodes = session.query(PermissionNodes.node_name).all()
|
||||
return [node.node_name for node in all_nodes]
|
||||
|
||||
|
||||
permissions = []
|
||||
|
||||
|
||||
with self.SessionLocal() as session:
|
||||
# 获取所有权限节点
|
||||
all_nodes = session.query(PermissionNodes).all()
|
||||
|
||||
|
||||
for node in all_nodes:
|
||||
# 检查用户是否有明确的权限设置
|
||||
user_perm = session.query(UserPermissions).filter_by(
|
||||
platform=user.platform,
|
||||
user_id=user.user_id,
|
||||
permission_node=node.node_name
|
||||
).first()
|
||||
|
||||
user_perm = (
|
||||
session.query(UserPermissions)
|
||||
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_perm:
|
||||
# 有明确设置,使用设置的值
|
||||
if user_perm.granted:
|
||||
@@ -300,20 +304,20 @@ class PermissionManager(IPermissionManager):
|
||||
# 没有明确设置,使用默认值
|
||||
if node.default_granted:
|
||||
permissions.append(node.node_name)
|
||||
|
||||
|
||||
return permissions
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取用户权限时数据库错误: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户权限时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_all_permission_nodes(self) -> List[PermissionNode]:
|
||||
"""
|
||||
获取所有已注册的权限节点
|
||||
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
@@ -325,25 +329,25 @@ class PermissionManager(IPermissionManager):
|
||||
node_name=node.node_name,
|
||||
description=node.description,
|
||||
plugin_name=node.plugin_name,
|
||||
default_granted=node.default_granted
|
||||
default_granted=node.default_granted,
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取所有权限节点时数据库错误: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有权限节点时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
|
||||
"""
|
||||
获取指定插件的所有权限节点
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
List[PermissionNode]: 权限节点列表
|
||||
"""
|
||||
@@ -355,25 +359,25 @@ class PermissionManager(IPermissionManager):
|
||||
node_name=node.node_name,
|
||||
description=node.description,
|
||||
plugin_name=node.plugin_name,
|
||||
default_granted=node.default_granted
|
||||
default_granted=node.default_granted,
|
||||
)
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取插件权限节点时数据库错误: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件权限节点时发生未知错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def delete_plugin_permissions(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
删除指定插件的所有权限节点(用于插件卸载时清理)
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
@@ -382,68 +386,71 @@ class PermissionManager(IPermissionManager):
|
||||
# 获取插件的所有权限节点
|
||||
plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all()
|
||||
node_names = [node.node_name for node in plugin_nodes]
|
||||
|
||||
|
||||
if not node_names:
|
||||
logger.info(f"插件 {plugin_name} 没有注册任何权限节点")
|
||||
return True
|
||||
|
||||
|
||||
# 删除用户权限记录
|
||||
deleted_user_perms = session.query(UserPermissions).filter(
|
||||
UserPermissions.permission_node.in_(node_names)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
deleted_user_perms = (
|
||||
session.query(UserPermissions)
|
||||
.filter(UserPermissions.permission_node.in_(node_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
# 删除权限节点
|
||||
deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete()
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录")
|
||||
logger.info(
|
||||
f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"删除插件权限时数据库错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"删除插件权限时发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
获取拥有指定权限的所有用户
|
||||
|
||||
|
||||
Args:
|
||||
permission_node: 权限节点名称
|
||||
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 用户列表,格式为 [(platform, user_id), ...]
|
||||
"""
|
||||
try:
|
||||
users = []
|
||||
|
||||
|
||||
with self.SessionLocal() as session:
|
||||
# 检查权限节点是否存在
|
||||
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
|
||||
if not node:
|
||||
logger.warning(f"权限节点 {permission_node} 不存在")
|
||||
return users
|
||||
|
||||
|
||||
# 获取明确授权的用户
|
||||
granted_users = session.query(UserPermissions).filter_by(
|
||||
permission_node=permission_node,
|
||||
granted=True
|
||||
).all()
|
||||
|
||||
granted_users = (
|
||||
session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all()
|
||||
)
|
||||
|
||||
for user_perm in granted_users:
|
||||
users.append((user_perm.platform, user_perm.user_id))
|
||||
|
||||
|
||||
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
|
||||
# 但这里我们只返回明确授权的用户,避免返回所有用户
|
||||
|
||||
|
||||
# 添加Master用户(他们拥有所有权限)
|
||||
users.extend(list(self._master_users))
|
||||
|
||||
|
||||
# 去重
|
||||
return list(set(users))
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"获取拥有权限的用户时数据库错误: {e}")
|
||||
return []
|
||||
|
||||
@@ -36,21 +36,21 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith(('.py', '.toml')):
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith(('.py', '.toml')):
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory:
|
||||
file_path = str(event.src_path)
|
||||
if file_path.endswith(('.py', '.toml')):
|
||||
if file_path.endswith((".py", ".toml")):
|
||||
self._handle_file_change(file_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
@@ -63,14 +63,14 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
|
||||
plugin_name, source_type = plugin_info
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 文件变化缓存,避免重复处理同一文件的快速连续变化
|
||||
file_cache_key = f"{file_path}_{change_type}"
|
||||
last_file_time = self.file_change_cache.get(file_cache_key, 0)
|
||||
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
|
||||
return
|
||||
self.file_change_cache[file_cache_key] = current_time
|
||||
|
||||
|
||||
# 插件级别的防抖处理
|
||||
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
|
||||
if current_time - last_plugin_time < self.debounce_delay:
|
||||
@@ -85,20 +85,28 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
if change_type == "deleted":
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
|
||||
|
||||
|
||||
if file_name == "plugin.py":
|
||||
if actual_plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
self.hot_reload_manager._unload_plugin(actual_plugin_name)
|
||||
else:
|
||||
logger.info(f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
return
|
||||
elif file_name in ("manifest.toml", "_manifest.json"):
|
||||
if actual_plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
self.hot_reload_manager._unload_plugin(actual_plugin_name)
|
||||
else:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
|
||||
logger.info(
|
||||
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
|
||||
)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
@@ -108,9 +116,7 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
|
||||
# 延迟重载,确保文件写入完成
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload,
|
||||
args=(plugin_name, source_type, current_time),
|
||||
daemon=True
|
||||
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
@@ -126,14 +132,14 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
# 检查是否还需要重载(可能在等待期间有更新的变化)
|
||||
if plugin_name not in self.pending_reloads:
|
||||
return
|
||||
|
||||
|
||||
# 检查是否有更新的重载请求
|
||||
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
|
||||
return
|
||||
|
||||
self.pending_reloads.discard(plugin_name)
|
||||
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
|
||||
|
||||
|
||||
# 执行深度重载
|
||||
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
|
||||
if success:
|
||||
@@ -146,7 +152,7 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
|
||||
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
|
||||
"""从文件路径获取插件信息
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[插件名称, 源类型] 或 None
|
||||
"""
|
||||
@@ -162,12 +168,12 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
source_type = "built-in"
|
||||
else:
|
||||
source_type = "external"
|
||||
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
if len(relative_path.parts) == 0:
|
||||
continue
|
||||
|
||||
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# 确认这是一个有效的插件目录
|
||||
@@ -175,9 +181,10 @@ class PluginFileHandler(FileSystemEventHandler):
|
||||
if plugin_dir.is_dir():
|
||||
# 检查是否有插件主文件或配置文件
|
||||
has_plugin_py = (plugin_dir / "plugin.py").exists()
|
||||
has_manifest = ((plugin_dir / "manifest.toml").exists() or
|
||||
(plugin_dir / "_manifest.json").exists())
|
||||
|
||||
has_manifest = (plugin_dir / "manifest.toml").exists() or (
|
||||
plugin_dir / "_manifest.json"
|
||||
).exists()
|
||||
|
||||
if has_plugin_py or has_manifest:
|
||||
return plugin_name, source_type
|
||||
|
||||
@@ -195,11 +202,11 @@ class PluginHotReloadManager:
|
||||
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
|
||||
self.watch_directories = [
|
||||
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
|
||||
os.path.join(os.getcwd(), "src", "plugins", "built_in") # 内置插件目录
|
||||
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
|
||||
]
|
||||
else:
|
||||
self.watch_directories = watch_directories
|
||||
|
||||
|
||||
self.observers = []
|
||||
self.file_handlers = []
|
||||
self.is_running = False
|
||||
@@ -221,13 +228,9 @@ class PluginHotReloadManager:
|
||||
for watch_dir in self.watch_directories:
|
||||
observer = Observer()
|
||||
file_handler = PluginFileHandler(self)
|
||||
|
||||
observer.schedule(
|
||||
file_handler,
|
||||
watch_dir,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
|
||||
observer.schedule(file_handler, watch_dir, recursive=True)
|
||||
|
||||
observer.start()
|
||||
self.observers.append(observer)
|
||||
self.file_handlers.append(file_handler)
|
||||
@@ -296,26 +299,26 @@ class PluginHotReloadManager:
|
||||
if folder_name in plugin_manager.plugin_classes:
|
||||
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
|
||||
return folder_name
|
||||
|
||||
|
||||
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
|
||||
matched_plugins = []
|
||||
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
|
||||
# 检查路径是否包含该文件夹名
|
||||
if folder_name in plugin_path:
|
||||
matched_plugins.append((plugin_name, plugin_path))
|
||||
|
||||
|
||||
# 在匹配的插件中,优先选择在插件类中存在的
|
||||
for plugin_name, plugin_path in matched_plugins:
|
||||
if plugin_name in plugin_manager.plugin_classes:
|
||||
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
|
||||
return plugin_name
|
||||
|
||||
|
||||
# 如果还是没找到在插件类中存在的,返回第一个匹配项
|
||||
if matched_plugins:
|
||||
plugin_name, plugin_path = matched_plugins[0]
|
||||
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
|
||||
return plugin_name
|
||||
|
||||
|
||||
# 如果还是没找到,返回原文件夹名
|
||||
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
|
||||
return folder_name
|
||||
@@ -326,13 +329,13 @@ class PluginHotReloadManager:
|
||||
# 解析实际的插件名称
|
||||
actual_plugin_name = self._resolve_plugin_name(plugin_name)
|
||||
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
|
||||
|
||||
|
||||
# 强制清理相关模块缓存
|
||||
self._force_clear_plugin_modules(plugin_name)
|
||||
|
||||
|
||||
# 使用插件管理器的强制重载功能
|
||||
success = plugin_manager.force_reload_plugin(actual_plugin_name)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
|
||||
return True
|
||||
@@ -348,15 +351,15 @@ class PluginHotReloadManager:
|
||||
|
||||
def _force_clear_plugin_modules(self, plugin_name: str):
|
||||
"""强制清理插件相关的模块缓存"""
|
||||
|
||||
|
||||
# 找到所有相关的模块名
|
||||
modules_to_remove = []
|
||||
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
|
||||
|
||||
|
||||
for module_name in list(sys.modules.keys()):
|
||||
if plugin_module_prefix in module_name:
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
|
||||
# 删除模块缓存
|
||||
for module_name in modules_to_remove:
|
||||
if module_name in sys.modules:
|
||||
@@ -369,7 +372,7 @@ class PluginHotReloadManager:
|
||||
# 使用插件管理器的重载功能
|
||||
success = plugin_manager.reload_plugin(plugin_name)
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -378,7 +381,7 @@ class PluginHotReloadManager:
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
@@ -409,7 +412,7 @@ class PluginHotReloadManager:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
|
||||
# 清理全局缓存
|
||||
importlib.invalidate_caches()
|
||||
|
||||
@@ -420,21 +423,21 @@ class PluginHotReloadManager:
|
||||
"""手动强制重载指定插件(委托给插件管理器)"""
|
||||
try:
|
||||
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
|
||||
|
||||
|
||||
# 清理待重载列表中的该插件(避免重复重载)
|
||||
for handler in self.file_handlers:
|
||||
handler.pending_reloads.discard(plugin_name)
|
||||
|
||||
|
||||
# 使用插件管理器的强制重载功能
|
||||
success = plugin_manager.force_reload_plugin(plugin_name)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -457,19 +460,15 @@ class PluginHotReloadManager:
|
||||
try:
|
||||
observer = Observer()
|
||||
file_handler = PluginFileHandler(self)
|
||||
|
||||
observer.schedule(
|
||||
file_handler,
|
||||
directory,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
|
||||
observer.schedule(file_handler, directory, recursive=True)
|
||||
|
||||
observer.start()
|
||||
self.observers.append(observer)
|
||||
self.file_handlers.append(file_handler)
|
||||
|
||||
|
||||
logger.info(f"📂 已添加新的监听目录: {directory}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
|
||||
self.watch_directories.remove(directory)
|
||||
@@ -480,7 +479,7 @@ class PluginHotReloadManager:
|
||||
if self.file_handlers:
|
||||
for handler in self.file_handlers:
|
||||
pending_reloads.update(handler.pending_reloads)
|
||||
|
||||
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directories": self.watch_directories,
|
||||
@@ -495,11 +494,11 @@ class PluginHotReloadManager:
|
||||
"""清理所有Python模块缓存"""
|
||||
try:
|
||||
logger.info("🧹 开始清理所有Python模块缓存...")
|
||||
|
||||
|
||||
# 重新扫描所有插件目录,这会重新加载模块
|
||||
plugin_manager.rescan_plugin_directory()
|
||||
logger.info("✅ 模块缓存清理完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class PluginManager:
|
||||
return False # 目标文件不存在,视为不同
|
||||
|
||||
# 使用 'rb' 模式以二进制方式读取文件,确保哈希值计算的一致性
|
||||
with open(file1, 'rb') as f1, open(file2, 'rb') as f2:
|
||||
with open(file1, "rb") as f1, open(file2, "rb") as f2:
|
||||
return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest()
|
||||
|
||||
# === 插件目录管理 ===
|
||||
@@ -300,7 +300,7 @@ class PluginManager:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
@@ -366,7 +366,7 @@ class PluginManager:
|
||||
# 生成模块名和插件信息
|
||||
plugin_path = Path(plugin_file)
|
||||
plugin_dir = plugin_path.parent # 插件目录
|
||||
plugin_name = plugin_dir.name # 插件名称
|
||||
plugin_name = plugin_dir.name # 插件名称
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
|
||||
try:
|
||||
@@ -386,7 +386,7 @@ class PluginManager:
|
||||
except Exception as e:
|
||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||
logger.error(error_msg)
|
||||
self.failed_plugins[plugin_name if 'plugin_name' in locals() else module_name] = error_msg
|
||||
self.failed_plugins[plugin_name if "plugin_name" in locals() else module_name] = error_msg
|
||||
return False
|
||||
|
||||
# == 兼容性检查 ==
|
||||
@@ -478,9 +478,7 @@ class PluginManager:
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
tool_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
|
||||
]
|
||||
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
@@ -591,7 +589,7 @@ class PluginManager:
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
|
||||
# 调用插件的清理方法(如果有的话)
|
||||
if hasattr(plugin_instance, 'on_unload'):
|
||||
if hasattr(plugin_instance, "on_unload"):
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
@@ -654,10 +652,10 @@ class PluginManager:
|
||||
|
||||
def force_reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""强制重载插件(使用简化的方法)
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
|
||||
@@ -129,17 +129,17 @@ class ToolExecutor:
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = []
|
||||
for call in tool_calls:
|
||||
try:
|
||||
if hasattr(call, 'func_name'):
|
||||
if hasattr(call, "func_name"):
|
||||
func_names.append(call.func_name)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
if func_names:
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
else:
|
||||
@@ -185,9 +185,11 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用,并处理缓存"""
|
||||
|
||||
|
||||
function_args = tool_call.args or {}
|
||||
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
|
||||
|
||||
@@ -206,7 +208,7 @@ class ToolExecutor:
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
semantic_query=semantic_query
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
if cached_result:
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
|
||||
@@ -223,14 +225,14 @@ class ToolExecutor:
|
||||
semantic_query = None
|
||||
if tool_instance.semantic_cache_query_key:
|
||||
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
|
||||
|
||||
|
||||
await tool_cache.set(
|
||||
tool_name=tool_call.func_name,
|
||||
function_args=function_args,
|
||||
tool_file_path=tool_file_path,
|
||||
data=result,
|
||||
ttl=tool_instance.cache_ttl,
|
||||
semantic_query=semantic_query
|
||||
semantic_query=semantic_query,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
|
||||
@@ -238,12 +240,16 @@ class ToolExecutor:
|
||||
|
||||
return result
|
||||
|
||||
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
async def _original_execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""执行单个工具调用的原始逻辑"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
||||
)
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
@@ -261,7 +267,7 @@ class ToolExecutor:
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result.get("content", "")
|
||||
"content": result.get("content", ""),
|
||||
}
|
||||
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
|
||||
return None
|
||||
@@ -308,7 +314,6 @@ class ToolExecutor:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
"""
|
||||
ToolExecutor使用示例:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user