修复代码格式和文件名大小写问题

This commit is contained in:
Windpicker-owo
2025-08-31 20:50:17 +08:00
parent df29014e41
commit 8149731925
218 changed files with 6913 additions and 8257 deletions

View File

@@ -34,7 +34,9 @@ class ComponentRegistry:
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]] = {}
self._components_classes: Dict[
str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]
] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
@@ -166,7 +168,7 @@ class ComponentRegistry:
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
action_class.plugin_name = action_info.plugin_name
self._action_registry[action_name] = action_class
@@ -200,7 +202,9 @@ class ComponentRegistry:
return True
def _register_plus_command_component(self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]) -> bool:
def _register_plus_command_component(
self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]
) -> bool:
"""注册PlusCommand组件到特定注册表"""
plus_command_name = plus_command_info.name
@@ -212,7 +216,7 @@ class ComponentRegistry:
return False
# 创建专门的PlusCommand注册表如果还没有
if not hasattr(self, '_plus_command_registry'):
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
plus_command_class.plugin_name = plus_command_info.plugin_name
@@ -249,10 +253,11 @@ class ComponentRegistry:
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
handler_class.plugin_name = handler_info.plugin_name
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class)
# === 组件移除相关 ===
@@ -281,7 +286,7 @@ class ComponentRegistry:
case ComponentType.PLUS_COMMAND:
# 移除PlusCommand注册
if hasattr(self, '_plus_command_registry'):
if hasattr(self, "_plus_command_registry"):
self._plus_command_registry.pop(component_name, None)
logger.debug(f"已移除PlusCommand组件: {component_name}")
@@ -371,6 +376,7 @@ class ComponentRegistry:
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .event_manager import event_manager # 延迟导入防止循环导入问题
event_manager.register_event_handler(component_name)
namespaced_name = f"{component_type}.{component_name}"
@@ -572,7 +578,7 @@ class ComponentRegistry:
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
return None
# === Tool 特定查询方法 ===
@@ -599,7 +605,7 @@ class ComponentRegistry:
# === PlusCommand 特定查询方法 ===
def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]:
"""获取PlusCommand注册表"""
if not hasattr(self, '_plus_command_registry'):
if not hasattr(self, "_plus_command_registry"):
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
return self._plus_command_registry.copy()

View File

@@ -2,55 +2,57 @@
事件管理器 - 实现Event和EventHandler的单例管理
提供统一的事件注册、管理和触发接口
"""
from typing import Dict, Type, List, Optional, Any, Union
from threading import Lock
from src.common.logger import get_logger
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection, HandlerResult
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.component_types import EventType
logger = get_logger("event_manager")
class EventManager:
"""事件管理器单例类
负责管理所有事件和事件处理器的注册、订阅、触发等操作
使用单例模式确保全局只有一个事件管理实例
"""
_instance: Optional['EventManager'] = None
_instance: Optional["EventManager"] = None
_lock = Lock()
def __new__(cls) -> 'EventManager':
def __new__(cls) -> "EventManager":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._events: Dict[str, BaseEvent] = {}
self._event_handlers: Dict[str, Type[BaseEventHandler]] = {}
self._pending_subscriptions: Dict[str, List[str]] = {} # 缓存失败的订阅
self._initialized = True
logger.info("EventManager 单例初始化完成")
def register_event(
self,
event_name: Union[EventType, str],
allowed_subscribers: List[str]=None,
allowed_triggers: List[str]=None
) -> bool:
self,
event_name: Union[EventType, str],
allowed_subscribers: List[str] = None,
allowed_triggers: List[str] = None,
) -> bool:
"""注册一个新的事件
Args:
event_name Union[EventType, str]: 事件名称
allowed_subscribers: List[str]: 事件订阅者白名单,
allowed_subscribers: List[str]: 事件订阅者白名单,
allowed_triggers: List[str]: 事件触发插件白名单
Returns:
bool: 注册成功返回True已存在返回False
@@ -62,57 +64,57 @@ class EventManager:
if event_name in self._events:
logger.warning(f"事件 {event_name} 已存在,跳过注册")
return False
event = BaseEvent(event_name,allowed_subscribers,allowed_triggers)
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
self._events[event_name] = event
logger.info(f"事件 {event_name} 注册成功")
# 检查是否有缓存的订阅需要处理
self._process_pending_subscriptions(event_name)
return True
def get_event(self, event_name: Union[EventType, str]) -> Optional[BaseEvent]:
"""获取指定事件实例
Args:
event_name Union[EventType, str]: 事件名称
Returns:
BaseEvent: 事件实例不存在返回None
"""
return self._events.get(event_name)
def get_all_events(self) -> Dict[str, BaseEvent]:
"""获取所有已注册的事件
Returns:
Dict[str, BaseEvent]: 所有事件的字典
"""
return self._events.copy()
def get_enabled_events(self) -> Dict[str, BaseEvent]:
"""获取所有已启用的事件
Returns:
Dict[str, BaseEvent]: 已启用事件的字典
"""
return {name: event for name, event in self._events.items() if event.enabled}
def get_disabled_events(self) -> Dict[str, BaseEvent]:
"""获取所有已禁用的事件
Returns:
Dict[str, BaseEvent]: 已禁用事件的字典
"""
return {name: event for name, event in self._events.items() if not event.enabled}
def enable_event(self, event_name: Union[EventType, str]) -> bool:
"""启用指定事件
Args:
event_name Union[EventType, str]: 事件名称
Returns:
bool: 成功返回True事件不存在返回False
"""
@@ -120,17 +122,17 @@ class EventManager:
if event is None:
logger.error(f"事件 {event_name} 不存在,无法启用")
return False
event.enabled = True
logger.info(f"事件 {event_name} 已启用")
return True
def disable_event(self, event_name: Union[EventType, str]) -> bool:
"""禁用指定事件
Args:
event_name Union[EventType, str]: 事件名称
Returns:
bool: 成功返回True事件不存在返回False
"""
@@ -138,38 +140,38 @@ class EventManager:
if event is None:
logger.error(f"事件 {event_name} 不存在,无法禁用")
return False
event.enabled = False
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 注册成功返回True已存在返回False
"""
handler_name = handler_class.handler_name or handler_class.__name__.lower().replace("handler", "")
if EventType.UNKNOWN in handler_class.init_subscribe:
logger.error(f"事件处理器 {handler_name} 不能订阅 UNKNOWN 事件")
return False
if handler_name in self._event_handlers:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
self._event_handlers[handler_name] = handler_class()
# 处理init_subscribe缓存失败的订阅
if self._event_handlers[handler_name].init_subscribe:
failed_subscriptions = []
for event_name in self._event_handlers[handler_name].init_subscribe:
if not self.subscribe_handler_to_event(handler_name, event_name):
failed_subscriptions.append(event_name)
# 缓存失败的订阅
if failed_subscriptions:
self._pending_subscriptions[handler_name] = failed_subscriptions
@@ -177,33 +179,33 @@ class EventManager:
logger.info(f"事件处理器 {handler_name} 注册成功")
return True
def get_event_handler(self, handler_name: str) -> Optional[Type[BaseEventHandler]]:
"""获取指定事件处理器实例
Args:
handler_name (str): 处理器名称
Returns:
Type[BaseEventHandler]: 处理器实例不存在返回None
"""
return self._event_handlers.get(handler_name)
def get_all_event_handlers(self) -> Dict[str, BaseEventHandler]:
"""获取所有已注册的事件处理器
Returns:
Dict[str, Type[BaseEventHandler]]: 所有处理器的字典
"""
return self._event_handlers.copy()
def subscribe_handler_to_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
"""订阅事件处理器到指定事件
Args:
handler_name (str): 处理器名称
event_name Union[EventType, str]: 事件名称
Returns:
bool: 订阅成功返回True
"""
@@ -211,36 +213,36 @@ class EventManager:
if handler_instance is None:
logger.error(f"事件处理器 {handler_name} 不存在,无法订阅到事件 {event_name}")
return False
event = self.get_event(event_name)
if event is None:
logger.error(f"事件 {event_name} 不存在,无法订阅事件处理器 {handler_name}")
return False
if handler_instance in event.subscribers:
logger.warning(f"事件处理器 {handler_name} 已经订阅了事件 {event_name},跳过重复订阅")
return True
# 白名单检查
if event.allowed_subscribers and handler_name not in event.allowed_subscribers:
logger.warning(f"事件处理器 {handler_name} 不在事件 {event_name} 的订阅者白名单中,无法订阅")
return False
event.subscribers.append(handler_instance)
# 按权重从高到低排序订阅者
event.subscribers.sort(key=lambda h: getattr(h, 'weight', 0), reverse=True)
event.subscribers.sort(key=lambda h: getattr(h, "weight", 0), reverse=True)
logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
return True
def unsubscribe_handler_from_event(self, handler_name: str, event_name: Union[EventType, str]) -> bool:
"""从指定事件取消订阅事件处理器
Args:
handler_name (str): 处理器名称
event_name Union[EventType, str]: 事件名称
Returns:
bool: 取消订阅成功返回True
"""
@@ -248,55 +250,57 @@ class EventManager:
if event is None:
logger.error(f"事件 {event_name} 不存在,无法取消订阅")
return False
# 查找并移除处理器实例
removed = False
for subscriber in event.subscribers[:]:
if hasattr(subscriber, 'handler_name') and subscriber.handler_name == handler_name:
if hasattr(subscriber, "handler_name") and subscriber.handler_name == handler_name:
event.subscribers.remove(subscriber)
removed = True
break
if removed:
logger.info(f"事件处理器 {handler_name} 成功从事件 {event_name} 取消订阅")
else:
logger.warning(f"事件处理器 {handler_name} 未订阅事件 {event_name}")
return removed
def get_event_subscribers(self, event_name: Union[EventType, str]) -> Dict[str, BaseEventHandler]:
"""获取订阅指定事件的所有事件处理器
Args:
event_name Union[EventType, str]: 事件名称
Returns:
Dict[str, BaseEventHandler]: 处理器字典,键为处理器名称,值为处理器实例
"""
event = self.get_event(event_name)
if event is None:
return {}
return {handler.handler_name: handler for handler in event.subscribers}
async def trigger_event(self, event_name: Union[EventType, str], plugin_name: Optional[str]="", **kwargs) -> Optional[HandlerResultsCollection]:
async def trigger_event(
self, event_name: Union[EventType, str], plugin_name: Optional[str] = "", **kwargs
) -> Optional[HandlerResultsCollection]:
"""触发指定事件
Args:
event_name Union[EventType, str]: 事件名称
plugin_name str: 触发事件的插件名
**kwargs: 传递给处理器的参数
Returns:
HandlerResultsCollection: 所有处理器的执行结果事件不存在返回None
"""
params = kwargs or {}
event = self.get_event(event_name)
if event is None:
logger.error(f"事件 {event_name} 不存在,无法触发")
return None
# 插件白名单检查
if event.allowed_triggers and not plugin_name:
logger.warning(f"事件 {event_name} 存在触发者白名单缺少plugin_name无法验证权限已拒绝触发")
@@ -304,9 +308,9 @@ class EventManager:
elif event.allowed_triggers and plugin_name not in event.allowed_triggers:
logger.warning(f"插件 {plugin_name} 没有权限触发事件 {event_name},已拒绝触发!")
return None
return await event.activate(params)
def init_default_events(self) -> None:
"""初始化默认事件"""
default_events = [
@@ -317,29 +321,29 @@ class EventManager:
EventType.POST_LLM,
EventType.AFTER_LLM,
EventType.POST_SEND,
EventType.AFTER_SEND
EventType.AFTER_SEND,
]
for event_name in default_events:
self.register_event(event_name,allowed_triggers=["SYSTEM"])
self.register_event(event_name, allowed_triggers=["SYSTEM"])
logger.info("默认事件初始化完成")
def clear_all_events(self) -> None:
"""清除所有事件和处理器(主要用于测试)"""
self._events.clear()
self._event_handlers.clear()
logger.info("所有事件和处理器已清除")
def get_event_summary(self) -> Dict[str, Any]:
"""获取事件系统摘要
Returns:
Dict[str, Any]: 包含事件系统统计信息的字典
"""
enabled_events = self.get_enabled_events()
disabled_events = self.get_disabled_events()
return {
"total_events": len(self._events),
"enabled_events": len(enabled_events),
@@ -347,58 +351,58 @@ class EventManager:
"total_handlers": len(self._event_handlers),
"event_names": list(self._events.keys()),
"handler_names": list(self._event_handlers.keys()),
"pending_subscriptions": len(self._pending_subscriptions)
"pending_subscriptions": len(self._pending_subscriptions),
}
def _process_pending_subscriptions(self, event_name: Union[EventType, str]) -> None:
"""处理指定事件的缓存订阅
Args:
event_name Union[EventType, str]: 事件名称
"""
handlers_to_remove = []
for handler_name, pending_events in self._pending_subscriptions.items():
if event_name in pending_events:
if self.subscribe_handler_to_event(handler_name, event_name):
pending_events.remove(event_name)
logger.info(f"成功处理缓存订阅: {handler_name} -> {event_name}")
# 如果该处理器没有更多待处理订阅,标记为移除
if not pending_events:
handlers_to_remove.append(handler_name)
# 清理已完成的处理器缓存
for handler_name in handlers_to_remove:
del self._pending_subscriptions[handler_name]
def process_all_pending_subscriptions(self) -> int:
"""处理所有缓存的订阅
Returns:
int: 成功处理的订阅数量
"""
processed_count = 0
# 复制待处理订阅,避免在迭代时修改字典
pending_copy = dict(self._pending_subscriptions)
for handler_name, pending_events in pending_copy.items():
for event_name in pending_events[:]: # 使用切片避免修改列表
if self.subscribe_handler_to_event(handler_name, event_name):
pending_events.remove(event_name)
processed_count += 1
# 清理已完成的处理器缓存
handlers_to_remove = [name for name, events in self._pending_subscriptions.items() if not events]
for handler_name in handlers_to_remove:
del self._pending_subscriptions[handler_name]
if processed_count > 0:
logger.info(f"批量处理缓存订阅完成,共处理 {processed_count} 个订阅")
return processed_count
# 创建全局事件管理器实例
event_manager = EventManager()
event_manager = EventManager()

View File

@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()

View File

@@ -19,14 +19,14 @@ logger = get_logger(__name__)
class PermissionManager(IPermissionManager):
"""权限管理器实现类"""
def __init__(self):
self.engine = get_engine()
self.SessionLocal = sessionmaker(bind=self.engine)
self._master_users: Set[Tuple[str, str]] = set()
self._load_master_users()
logger.info("权限管理器初始化完成")
def _load_master_users(self):
"""从配置文件加载Master用户列表"""
try:
@@ -40,19 +40,19 @@ class PermissionManager(IPermissionManager):
except Exception as e:
logger.warning(f"加载Master用户配置失败: {e}")
self._master_users = set()
def reload_master_users(self):
"""重新加载Master用户配置"""
self._load_master_users()
logger.info("Master用户配置已重新加载")
def is_master(self, user: UserInfo) -> bool:
"""
检查用户是否为Master用户
Args:
user: 用户信息
Returns:
bool: 是否为Master用户
"""
@@ -61,15 +61,15 @@ class PermissionManager(IPermissionManager):
if is_master:
logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户")
return is_master
def check_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
检查用户是否拥有指定权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 是否拥有权限
"""
@@ -78,46 +78,50 @@ class PermissionManager(IPermissionManager):
if self.is_master(user):
logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}")
return True
with self.SessionLocal() as session:
# 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
if not node:
logger.warning(f"权限节点 {permission_node} 不存在")
return False
# 检查用户是否有明确的权限设置
user_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=permission_node
).first()
user_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
)
if user_perm:
# 有明确设置,返回设置的值
result = user_perm.granted
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}")
logger.debug(
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}"
)
return result
else:
# 没有明确设置,使用默认值
result = node.default_granted
logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}")
logger.debug(
f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}"
)
return result
except SQLAlchemyError as e:
logger.error(f"检查权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"检查权限时发生未知错误: {e}")
return False
def register_permission_node(self, node: PermissionNode) -> bool:
"""
注册权限节点
Args:
node: 权限节点
Returns:
bool: 注册是否成功
"""
@@ -133,20 +137,20 @@ class PermissionManager(IPermissionManager):
session.commit()
logger.debug(f"更新权限节点: {node.node_name}")
return True
# 创建新节点
new_node = PermissionNodes(
node_name=node.node_name,
description=node.description,
plugin_name=node.plugin_name,
default_granted=node.default_granted,
created_at=datetime.utcnow()
created_at=datetime.utcnow(),
)
session.add(new_node)
session.commit()
logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})")
return True
except IntegrityError as e:
logger.error(f"注册权限节点时发生完整性错误: {e}")
return False
@@ -156,15 +160,15 @@ class PermissionManager(IPermissionManager):
except Exception as e:
logger.error(f"注册权限节点时发生未知错误: {e}")
return False
def grant_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
授权用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 授权是否成功
"""
@@ -175,14 +179,14 @@ class PermissionManager(IPermissionManager):
if not node:
logger.error(f"尝试授权不存在的权限节点: {permission_node}")
return False
# 检查是否已有权限记录
existing_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=permission_node
).first()
existing_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
)
if existing_perm:
# 更新现有记录
existing_perm.granted = True
@@ -194,29 +198,29 @@ class PermissionManager(IPermissionManager):
user_id=user.user_id,
permission_node=permission_node,
granted=True,
granted_at=datetime.utcnow()
granted_at=datetime.utcnow(),
)
session.add(new_perm)
session.commit()
logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
return True
except SQLAlchemyError as e:
logger.error(f"授权权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"授权权限时发生未知错误: {e}")
return False
def revoke_permission(self, user: UserInfo, permission_node: str) -> bool:
"""
撤销用户权限节点
Args:
user: 用户信息
permission_node: 权限节点名称
Returns:
bool: 撤销是否成功
"""
@@ -227,14 +231,14 @@ class PermissionManager(IPermissionManager):
if not node:
logger.error(f"尝试撤销不存在的权限节点: {permission_node}")
return False
# 检查是否已有权限记录
existing_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=permission_node
).first()
existing_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=permission_node)
.first()
)
if existing_perm:
# 更新现有记录
existing_perm.granted = False
@@ -246,28 +250,28 @@ class PermissionManager(IPermissionManager):
user_id=user.user_id,
permission_node=permission_node,
granted=False,
granted_at=datetime.utcnow()
granted_at=datetime.utcnow(),
)
session.add(new_perm)
session.commit()
logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}")
return True
except SQLAlchemyError as e:
logger.error(f"撤销权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"撤销权限时发生未知错误: {e}")
return False
def get_user_permissions(self, user: UserInfo) -> List[str]:
"""
获取用户拥有的所有权限节点
Args:
user: 用户信息
Returns:
List[str]: 权限节点列表
"""
@@ -277,21 +281,21 @@ class PermissionManager(IPermissionManager):
with self.SessionLocal() as session:
all_nodes = session.query(PermissionNodes.node_name).all()
return [node.node_name for node in all_nodes]
permissions = []
with self.SessionLocal() as session:
# 获取所有权限节点
all_nodes = session.query(PermissionNodes).all()
for node in all_nodes:
# 检查用户是否有明确的权限设置
user_perm = session.query(UserPermissions).filter_by(
platform=user.platform,
user_id=user.user_id,
permission_node=node.node_name
).first()
user_perm = (
session.query(UserPermissions)
.filter_by(platform=user.platform, user_id=user.user_id, permission_node=node.node_name)
.first()
)
if user_perm:
# 有明确设置,使用设置的值
if user_perm.granted:
@@ -300,20 +304,20 @@ class PermissionManager(IPermissionManager):
# 没有明确设置,使用默认值
if node.default_granted:
permissions.append(node.node_name)
return permissions
except SQLAlchemyError as e:
logger.error(f"获取用户权限时数据库错误: {e}")
return []
except Exception as e:
logger.error(f"获取用户权限时发生未知错误: {e}")
return []
def get_all_permission_nodes(self) -> List[PermissionNode]:
"""
获取所有已注册的权限节点
Returns:
List[PermissionNode]: 权限节点列表
"""
@@ -325,25 +329,25 @@ class PermissionManager(IPermissionManager):
node_name=node.node_name,
description=node.description,
plugin_name=node.plugin_name,
default_granted=node.default_granted
default_granted=node.default_granted,
)
for node in nodes
]
except SQLAlchemyError as e:
logger.error(f"获取所有权限节点时数据库错误: {e}")
return []
except Exception as e:
logger.error(f"获取所有权限节点时发生未知错误: {e}")
return []
def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]:
"""
获取指定插件的所有权限节点
Args:
plugin_name: 插件名称
Returns:
List[PermissionNode]: 权限节点列表
"""
@@ -355,25 +359,25 @@ class PermissionManager(IPermissionManager):
node_name=node.node_name,
description=node.description,
plugin_name=node.plugin_name,
default_granted=node.default_granted
default_granted=node.default_granted,
)
for node in nodes
]
except SQLAlchemyError as e:
logger.error(f"获取插件权限节点时数据库错误: {e}")
return []
except Exception as e:
logger.error(f"获取插件权限节点时发生未知错误: {e}")
return []
def delete_plugin_permissions(self, plugin_name: str) -> bool:
"""
删除指定插件的所有权限节点(用于插件卸载时清理)
Args:
plugin_name: 插件名称
Returns:
bool: 删除是否成功
"""
@@ -382,68 +386,71 @@ class PermissionManager(IPermissionManager):
# 获取插件的所有权限节点
plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all()
node_names = [node.node_name for node in plugin_nodes]
if not node_names:
logger.info(f"插件 {plugin_name} 没有注册任何权限节点")
return True
# 删除用户权限记录
deleted_user_perms = session.query(UserPermissions).filter(
UserPermissions.permission_node.in_(node_names)
).delete(synchronize_session=False)
deleted_user_perms = (
session.query(UserPermissions)
.filter(UserPermissions.permission_node.in_(node_names))
.delete(synchronize_session=False)
)
# 删除权限节点
deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete()
session.commit()
logger.info(f"已删除插件 {plugin_name}{deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录")
logger.info(
f"已删除插件 {plugin_name}{deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录"
)
return True
except SQLAlchemyError as e:
logger.error(f"删除插件权限时数据库错误: {e}")
return False
except Exception as e:
logger.error(f"删除插件权限时发生未知错误: {e}")
return False
def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]:
"""
获取拥有指定权限的所有用户
Args:
permission_node: 权限节点名称
Returns:
List[Tuple[str, str]]: 用户列表,格式为 [(platform, user_id), ...]
"""
try:
users = []
with self.SessionLocal() as session:
# 检查权限节点是否存在
node = session.query(PermissionNodes).filter_by(node_name=permission_node).first()
if not node:
logger.warning(f"权限节点 {permission_node} 不存在")
return users
# 获取明确授权的用户
granted_users = session.query(UserPermissions).filter_by(
permission_node=permission_node,
granted=True
).all()
granted_users = (
session.query(UserPermissions).filter_by(permission_node=permission_node, granted=True).all()
)
for user_perm in granted_users:
users.append((user_perm.platform, user_perm.user_id))
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
# 但这里我们只返回明确授权的用户,避免返回所有用户
# 添加Master用户他们拥有所有权限
users.extend(list(self._master_users))
# 去重
return list(set(users))
except SQLAlchemyError as e:
logger.error(f"获取拥有权限的用户时数据库错误: {e}")
return []

View File

@@ -36,21 +36,21 @@ class PluginFileHandler(FileSystemEventHandler):
"""文件修改事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith(('.py', '.toml')):
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith(('.py', '.toml')):
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory:
file_path = str(event.src_path)
if file_path.endswith(('.py', '.toml')):
if file_path.endswith((".py", ".toml")):
self._handle_file_change(file_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
@@ -63,14 +63,14 @@ class PluginFileHandler(FileSystemEventHandler):
plugin_name, source_type = plugin_info
current_time = time.time()
# 文件变化缓存,避免重复处理同一文件的快速连续变化
file_cache_key = f"{file_path}_{change_type}"
last_file_time = self.file_change_cache.get(file_cache_key, 0)
if current_time - last_file_time < 0.5: # 0.5秒内的重复文件变化忽略
return
self.file_change_cache[file_cache_key] = current_time
# 插件级别的防抖处理
last_plugin_time = self.last_reload_time.get(plugin_name, 0)
if current_time - last_plugin_time < self.debounce_delay:
@@ -85,20 +85,28 @@ class PluginFileHandler(FileSystemEventHandler):
if change_type == "deleted":
# 解析实际的插件名称
actual_plugin_name = self.hot_reload_manager._resolve_plugin_name(plugin_name)
if file_name == "plugin.py":
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件主文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件主文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
elif file_name in ("manifest.toml", "_manifest.json"):
if actual_plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
self.hot_reload_manager._unload_plugin(actual_plugin_name)
else:
logger.info(f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]")
logger.info(
f"🗑️ 插件配置文件被删除,但插件未加载: {plugin_name} -> {actual_plugin_name} [{source_type}]"
)
return
# 对于修改和创建事件,都进行重载
@@ -108,9 +116,7 @@ class PluginFileHandler(FileSystemEventHandler):
# 延迟重载,确保文件写入完成
reload_thread = Thread(
target=self._delayed_reload,
args=(plugin_name, source_type, current_time),
daemon=True
target=self._delayed_reload, args=(plugin_name, source_type, current_time), daemon=True
)
reload_thread.start()
@@ -126,14 +132,14 @@ class PluginFileHandler(FileSystemEventHandler):
# 检查是否还需要重载(可能在等待期间有更新的变化)
if plugin_name not in self.pending_reloads:
return
# 检查是否有更新的重载请求
if self.last_reload_time.get(plugin_name, 0) > trigger_time:
return
self.pending_reloads.discard(plugin_name)
logger.info(f"🔄 开始延迟重载插件: {plugin_name} [{source_type}]")
# 执行深度重载
success = self.hot_reload_manager._deep_reload_plugin(plugin_name)
if success:
@@ -146,7 +152,7 @@ class PluginFileHandler(FileSystemEventHandler):
def _get_plugin_info_from_path(self, file_path: str) -> Optional[Tuple[str, str]]:
"""从文件路径获取插件信息
Returns:
tuple[插件名称, 源类型] 或 None
"""
@@ -162,12 +168,12 @@ class PluginFileHandler(FileSystemEventHandler):
source_type = "built-in"
else:
source_type = "external"
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
if len(relative_path.parts) == 0:
continue
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录
@@ -175,9 +181,10 @@ class PluginFileHandler(FileSystemEventHandler):
if plugin_dir.is_dir():
# 检查是否有插件主文件或配置文件
has_plugin_py = (plugin_dir / "plugin.py").exists()
has_manifest = ((plugin_dir / "manifest.toml").exists() or
(plugin_dir / "_manifest.json").exists())
has_manifest = (plugin_dir / "manifest.toml").exists() or (
plugin_dir / "_manifest.json"
).exists()
if has_plugin_py or has_manifest:
return plugin_name, source_type
@@ -195,11 +202,11 @@ class PluginHotReloadManager:
# 默认监听两个目录:根目录下的 plugins 和 src 下的插件目录
self.watch_directories = [
os.path.join(os.getcwd(), "plugins"), # 外部插件目录
os.path.join(os.getcwd(), "src", "plugins", "built_in") # 内置插件目录
os.path.join(os.getcwd(), "src", "plugins", "built_in"), # 内置插件目录
]
else:
self.watch_directories = watch_directories
self.observers = []
self.file_handlers = []
self.is_running = False
@@ -221,13 +228,9 @@ class PluginHotReloadManager:
for watch_dir in self.watch_directories:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(
file_handler,
watch_dir,
recursive=True
)
observer.schedule(file_handler, watch_dir, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
@@ -296,26 +299,26 @@ class PluginHotReloadManager:
if folder_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 直接匹配插件名: {folder_name}")
return folder_name
# 如果没有直接匹配,搜索路径映射,并优先返回在插件类中存在的名称
matched_plugins = []
for plugin_name, plugin_path in plugin_manager.plugin_paths.items():
# 检查路径是否包含该文件夹名
if folder_name in plugin_path:
matched_plugins.append((plugin_name, plugin_path))
# 在匹配的插件中,优先选择在插件类中存在的
for plugin_name, plugin_path in matched_plugins:
if plugin_name in plugin_manager.plugin_classes:
logger.debug(f"🔍 文件夹名 '{folder_name}' 映射到插件名 '{plugin_name}' (路径: {plugin_path})")
return plugin_name
# 如果还是没找到在插件类中存在的,返回第一个匹配项
if matched_plugins:
plugin_name, plugin_path = matched_plugins[0]
logger.warning(f"⚠️ 文件夹 '{folder_name}' 映射到 '{plugin_name}',但该插件类不存在")
return plugin_name
# 如果还是没找到,返回原文件夹名
logger.warning(f"⚠️ 无法找到文件夹 '{folder_name}' 对应的插件名,使用原名称")
return folder_name
@@ -326,13 +329,13 @@ class PluginHotReloadManager:
# 解析实际的插件名称
actual_plugin_name = self._resolve_plugin_name(plugin_name)
logger.info(f"🔄 开始深度重载插件: {plugin_name} -> {actual_plugin_name}")
# 强制清理相关模块缓存
self._force_clear_plugin_modules(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(actual_plugin_name)
if success:
logger.info(f"✅ 插件深度重载成功: {actual_plugin_name}")
return True
@@ -348,15 +351,15 @@ class PluginHotReloadManager:
def _force_clear_plugin_modules(self, plugin_name: str):
"""强制清理插件相关的模块缓存"""
# 找到所有相关的模块名
modules_to_remove = []
plugin_module_prefix = f"src.plugins.built_in.{plugin_name}"
for module_name in list(sys.modules.keys()):
if plugin_module_prefix in module_name:
modules_to_remove.append(module_name)
# 删除模块缓存
for module_name in modules_to_remove:
if module_name in sys.modules:
@@ -369,7 +372,7 @@ class PluginHotReloadManager:
# 使用插件管理器的重载功能
success = plugin_manager.reload_plugin(plugin_name)
return success
except Exception as e:
logger.error(f"❌ 强制重新导入插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@@ -378,7 +381,7 @@ class PluginHotReloadManager:
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
@@ -409,7 +412,7 @@ class PluginHotReloadManager:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
# 清理全局缓存
importlib.invalidate_caches()
@@ -420,21 +423,21 @@ class PluginHotReloadManager:
"""手动强制重载指定插件(委托给插件管理器)"""
try:
logger.info(f"🔄 手动强制重载插件: {plugin_name}")
# 清理待重载列表中的该插件(避免重复重载)
for handler in self.file_handlers:
handler.pending_reloads.discard(plugin_name)
# 使用插件管理器的强制重载功能
success = plugin_manager.force_reload_plugin(plugin_name)
if success:
logger.info(f"✅ 手动强制重载成功: {plugin_name}")
else:
logger.error(f"❌ 手动强制重载失败: {plugin_name}")
return success
except Exception as e:
logger.error(f"❌ 手动强制重载插件 {plugin_name} 时发生错误: {e}", exc_info=True)
return False
@@ -457,19 +460,15 @@ class PluginHotReloadManager:
try:
observer = Observer()
file_handler = PluginFileHandler(self)
observer.schedule(
file_handler,
directory,
recursive=True
)
observer.schedule(file_handler, directory, recursive=True)
observer.start()
self.observers.append(observer)
self.file_handlers.append(file_handler)
logger.info(f"📂 已添加新的监听目录: {directory}")
except Exception as e:
logger.error(f"❌ 添加监听目录 {directory} 失败: {e}")
self.watch_directories.remove(directory)
@@ -480,7 +479,7 @@ class PluginHotReloadManager:
if self.file_handlers:
for handler in self.file_handlers:
pending_reloads.update(handler.pending_reloads)
return {
"is_running": self.is_running,
"watch_directories": self.watch_directories,
@@ -495,11 +494,11 @@ class PluginHotReloadManager:
"""清理所有Python模块缓存"""
try:
logger.info("🧹 开始清理所有Python模块缓存...")
# 重新扫描所有插件目录,这会重新加载模块
plugin_manager.rescan_plugin_directory()
logger.info("✅ 模块缓存清理完成")
except Exception as e:
logger.error(f"❌ 清理模块缓存时发生错误: {e}", exc_info=True)

View File

@@ -104,7 +104,7 @@ class PluginManager:
return False # 目标文件不存在,视为不同
# 使用 'rb' 模式以二进制方式读取文件,确保哈希值计算的一致性
with open(file1, 'rb') as f1, open(file2, 'rb') as f2:
with open(file1, "rb") as f1, open(file2, "rb") as f2:
return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest()
# === 插件目录管理 ===
@@ -300,7 +300,7 @@ class PluginManager:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
@@ -366,7 +366,7 @@ class PluginManager:
# 生成模块名和插件信息
plugin_path = Path(plugin_file)
plugin_dir = plugin_path.parent # 插件目录
plugin_name = plugin_dir.name # 插件名称
plugin_name = plugin_dir.name # 插件名称
module_name = ".".join(plugin_path.parent.parts)
try:
@@ -386,7 +386,7 @@ class PluginManager:
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
self.failed_plugins[plugin_name if 'plugin_name' in locals() else module_name] = error_msg
self.failed_plugins[plugin_name if "plugin_name" in locals() else module_name] = error_msg
return False
# == 兼容性检查 ==
@@ -478,9 +478,7 @@ class PluginManager:
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
@@ -591,7 +589,7 @@ class PluginManager:
plugin_instance = self.loaded_plugins[plugin_name]
# 调用插件的清理方法(如果有的话)
if hasattr(plugin_instance, 'on_unload'):
if hasattr(plugin_instance, "on_unload"):
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
@@ -654,10 +652,10 @@ class PluginManager:
def force_reload_plugin(self, plugin_name: str) -> bool:
"""强制重载插件(使用简化的方法)
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""

View File

@@ -129,17 +129,17 @@ class ToolExecutor:
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return [], []
# 提取tool_calls中的函数名称
func_names = []
for call in tool_calls:
try:
if hasattr(call, 'func_name'):
if hasattr(call, "func_name"):
func_names.append(call.func_name)
except Exception as e:
logger.error(f"{self.log_prefix}获取工具名称失败: {e}")
continue
if func_names:
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
else:
@@ -185,9 +185,11 @@ class ToolExecutor:
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
"""执行单个工具调用,并处理缓存"""
function_args = tool_call.args or {}
tool_instance = tool_instance or get_tool_instance(tool_call.func_name)
@@ -206,7 +208,7 @@ class ToolExecutor:
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
semantic_query=semantic_query,
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
@@ -223,14 +225,14 @@ class ToolExecutor:
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=function_args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
semantic_query=semantic_query,
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
@@ -238,12 +240,16 @@ class ToolExecutor:
return result
async def _original_execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
async def _original_execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
"""执行单个工具调用的原始逻辑"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
logger.info(f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}")
logger.info(
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
)
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
@@ -261,7 +267,7 @@ class ToolExecutor:
"role": "tool",
"name": function_name,
"type": "function",
"content": result.get("content", "")
"content": result.get("content", ""),
}
logger.warning(f"{self.log_prefix}工具 {function_name} 返回空结果")
return None
@@ -308,7 +314,6 @@ class ToolExecutor:
return None
"""
ToolExecutor使用示例