初始化
This commit is contained in:
19
src/plugin_system/core/__init__.py
Normal file
19
src/plugin_system/core/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
插件核心管理模块
|
||||
|
||||
提供插件的加载、注册和管理功能
|
||||
"""
|
||||
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
"hot_reload_manager",
|
||||
]
|
||||
688
src/plugin_system/core/component_registry.py
Normal file
688
src/plugin_system/core/component_registry.py
Normal file
@@ -0,0 +1,688 @@
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
ToolInfo,
|
||||
CommandInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""统一的组件注册中心
|
||||
|
||||
负责管理所有插件组件的注册、查询和生命周期管理
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._components: Dict[str, ComponentInfo] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, PluginInfo] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
# == 注册方法 ==
|
||||
|
||||
def register_plugin(self, plugin_info: PluginInfo) -> bool:
|
||||
"""注册插件
|
||||
|
||||
Args:
|
||||
plugin_info: 插件信息
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
plugin_name = plugin_info.name
|
||||
|
||||
if plugin_name in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._plugins[plugin_name] = plugin_info
|
||||
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
|
||||
return True
|
||||
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
Args:
|
||||
component_info (ComponentInfo): 组件信息
|
||||
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
component_name = component_info.name
|
||||
component_type = component_info.component_type
|
||||
plugin_name = getattr(component_info, "plugin_name", "unknown")
|
||||
if "." in component_name:
|
||||
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
|
||||
if namespaced_name in self._components:
|
||||
existing_info = self._components[namespaced_name]
|
||||
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
|
||||
|
||||
logger.warning(
|
||||
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
|
||||
)
|
||||
return False
|
||||
|
||||
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
|
||||
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
|
||||
self._components_classes[namespaced_name] = component_class
|
||||
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(component_info, ActionInfo)
|
||||
assert issubclass(component_class, BaseAction)
|
||||
ret = self._register_action_component(component_info, component_class)
|
||||
case ComponentType.COMMAND:
|
||||
assert isinstance(component_info, CommandInfo)
|
||||
assert issubclass(component_class, BaseCommand)
|
||||
ret = self._register_command_component(component_info, component_class)
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(component_info, ToolInfo)
|
||||
assert issubclass(component_class, BaseTool)
|
||||
ret = self._register_tool_component(component_info, component_class)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(component_info, EventHandlerInfo)
|
||||
assert issubclass(component_class, BaseEventHandler)
|
||||
ret = self._register_event_handler_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
|
||||
if not ret:
|
||||
return False
|
||||
logger.debug(
|
||||
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
|
||||
f"({component_class.__name__}) [插件: {plugin_name}]"
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = action_info
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
|
||||
logger.error(f"注册失败: {command_name} 不是有效的Command")
|
||||
return False
|
||||
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
# 如果启用了且有匹配模式
|
||||
if command_info.enabled and command_info.command_pattern:
|
||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
if pattern not in self._command_patterns:
|
||||
self._command_patterns[pattern] = command_name
|
||||
else:
|
||||
logger.warning(
|
||||
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
if tool_info.enabled:
|
||||
self._llm_available_tools[tool_name] = tool_class
|
||||
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
|
||||
return False
|
||||
|
||||
self._event_handler_registry[handler_name] = handler_class
|
||||
|
||||
if not handler_info.enabled:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
||||
self._enabled_event_handlers[handler_name] = handler_class
|
||||
return True
|
||||
else:
|
||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
||||
return False
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
if not target_component_class:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
return False
|
||||
try:
|
||||
# 根据组件类型进行特定的清理操作
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
# 移除Action注册
|
||||
self._action_registry.pop(component_name, None)
|
||||
self._default_actions.pop(component_name, None)
|
||||
logger.debug(f"已移除Action组件: {component_name}")
|
||||
|
||||
case ComponentType.COMMAND:
|
||||
# 移除Command注册和模式
|
||||
self._command_registry.pop(component_name, None)
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key, None)
|
||||
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
|
||||
|
||||
case ComponentType.TOOL:
|
||||
# 移除Tool注册
|
||||
self._tool_registry.pop(component_name, None)
|
||||
self._llm_available_tools.pop(component_name, None)
|
||||
logger.debug(f"已移除Tool组件: {component_name}")
|
||||
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
# 移除EventHandler注册和事件订阅
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
self._event_handler_registry.pop(component_name, None)
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
try:
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
logger.debug(f"已移除EventHandler组件: {component_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||
|
||||
case _:
|
||||
logger.warning(f"未知的组件类型: {component_type}")
|
||||
return False
|
||||
|
||||
# 移除通用注册信息
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components.pop(namespaced_name, None)
|
||||
self._components_by_type[component_type].pop(component_name, None)
|
||||
self._components_classes.pop(namespaced_name, None)
|
||||
|
||||
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
"""移除插件注册信息
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功移除
|
||||
"""
|
||||
if plugin_name not in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
|
||||
return False
|
||||
del self._plugins[plugin_name]
|
||||
logger.info(f"插件 {plugin_name} 已移除")
|
||||
return True
|
||||
|
||||
# === 组件全局启用/禁用方法 ===
|
||||
|
||||
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的启用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 启用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
||||
return False
|
||||
target_component_info.enabled = True
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(target_component_info, ActionInfo)
|
||||
self._default_actions[component_name] = target_component_info
|
||||
case ComponentType.COMMAND:
|
||||
assert isinstance(target_component_info, CommandInfo)
|
||||
pattern = target_component_info.command_pattern
|
||||
self._command_patterns[re.compile(pattern)] = component_name
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(target_component_info, ToolInfo)
|
||||
assert issubclass(target_component_class, BaseTool)
|
||||
self._llm_available_tools[component_name] = target_component_class
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(target_component_info, EventHandlerInfo)
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components[namespaced_name].enabled = True
|
||||
self._components_by_type[component_type][component_name].enabled = True
|
||||
logger.info(f"组件 {component_name} 已启用")
|
||||
return True
|
||||
|
||||
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的禁用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 禁用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||
return False
|
||||
target_component_info.enabled = False
|
||||
try:
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.TOOL:
|
||||
self._llm_available_tools.pop(component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._components[component_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
except KeyError as e:
|
||||
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
) -> Optional[ComponentInfo]:
|
||||
# sourcery skip: class-extract-method
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
||||
|
||||
Returns:
|
||||
Optional[ComponentInfo]: 组件信息或None
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if "." in component_name:
|
||||
return self._components.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
return self._components.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in [types.value for types in ComponentType]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
if component_info := self._components.get(namespaced_name):
|
||||
candidates.append((namespace_prefix, namespaced_name, component_info))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
return candidates[0][2]
|
||||
elif len(candidates) > 1:
|
||||
# 多个匹配,记录警告并返回第一个
|
||||
namespaces = [ns for ns, _, _ in candidates]
|
||||
logger.warning(
|
||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
|
||||
)
|
||||
return candidates[0][2]
|
||||
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_component_class(
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
||||
|
||||
Returns:
|
||||
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if "." in component_name:
|
||||
return self._components_classes.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
return self._components_classes.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in [types.value for types in ComponentType]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
if component_class := self._components_classes.get(namespaced_name):
|
||||
candidates.append((namespace_prefix, namespaced_name, component_class))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
_, full_name, cls = candidates[0]
|
||||
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
|
||||
return cls
|
||||
elif len(candidates) > 1:
|
||||
# 多个匹配,记录警告并返回第一个
|
||||
namespaces = [ns for ns, _, _ in candidates]
|
||||
logger.warning(
|
||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
|
||||
)
|
||||
return candidates[0][2]
|
||||
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取默认动作集"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, str]:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
||||
"""
|
||||
|
||||
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
|
||||
command_name = self._command_patterns[candidates[0]]
|
||||
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
|
||||
return (
|
||||
self._command_registry[command_name],
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info,
|
||||
)
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
||||
"""获取Tool信息
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
ToolInfo: 工具信息对象,如果工具不存在则返回 None
|
||||
"""
|
||||
info = self.get_component_info(tool_name, ComponentType.TOOL)
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
||||
"""获取事件处理器信息"""
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === 插件查询方法 ===
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, PluginInfo]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
"""获取插件配置
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 插件配置字典或None
|
||||
"""
|
||||
# 从插件管理器获取插件实例的配置
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
||||
return plugin_instance.config if plugin_instance else None
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
action_components: int = 0
|
||||
command_components: int = 0
|
||||
tool_components: int = 0
|
||||
events_handlers: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
elif component.component_type == ComponentType.COMMAND:
|
||||
command_components += 1
|
||||
elif component.component_type == ComponentType.TOOL:
|
||||
tool_components += 1
|
||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
||||
events_handlers += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
"tool_components": tool_components,
|
||||
"event_handlers": events_handlers,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
component_type.value: len(components) for component_type, components in self._components_by_type.items()
|
||||
},
|
||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
}
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def unregister_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载插件及其所有组件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功卸载
|
||||
"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
if not plugin_info:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
|
||||
return False
|
||||
|
||||
logger.info(f"开始卸载插件: {plugin_name}")
|
||||
|
||||
# 记录卸载失败的组件
|
||||
failed_components = []
|
||||
|
||||
# 逐个移除插件的所有组件
|
||||
for component_info in plugin_info.components:
|
||||
try:
|
||||
success = await self.remove_component(
|
||||
component_info.name,
|
||||
component_info.component_type,
|
||||
plugin_name,
|
||||
)
|
||||
if not success:
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
|
||||
# 移除插件注册信息
|
||||
plugin_removed = self.remove_plugin_registry(plugin_name)
|
||||
|
||||
if failed_components:
|
||||
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
|
||||
return False
|
||||
elif not plugin_removed:
|
||||
logger.error(f"插件 {plugin_name} 注册信息移除失败")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"插件 {plugin_name} 卸载成功")
|
||||
return True
|
||||
|
||||
|
||||
# 创建全局组件注册中心实例
|
||||
component_registry = ComponentRegistry()
|
||||
262
src/plugin_system/core/events_manager.py
Normal file
262
src/plugin_system/core/events_manager.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import List, Dict, Optional, Type, Tuple, Any
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
|
||||
class EventsManager:
|
||||
def __init__(self):
|
||||
# 有权重的 events 订阅者注册表
|
||||
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
||||
|
||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
handler_info (EventHandlerInfo): 事件处理器信息
|
||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
handler_name = handler_info.name
|
||||
|
||||
if handler_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return True
|
||||
|
||||
if not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
return False
|
||||
|
||||
self._handler_mapping[handler_name] = handler_class
|
||||
return self._insert_event_handler(handler_class, handler_info)
|
||||
|
||||
async def handle_mai_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
message: Optional[MessageRecv] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> bool:
|
||||
"""处理 events"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
continue_flag = True
|
||||
transformed_message: Optional[MaiMessages] = None
|
||||
if not message:
|
||||
assert stream_id, "如果没有消息,必须提供流ID"
|
||||
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
|
||||
transformed_message = self._build_message_from_stream(stream_id, llm_prompt, llm_response)
|
||||
else:
|
||||
transformed_message = self._transform_event_without_message(
|
||||
stream_id, llm_prompt, llm_response, action_usage
|
||||
)
|
||||
else:
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self._events_subscribers.get(event_type, []):
|
||||
if transformed_message.stream_id:
|
||||
stream_id = transformed_message.stream_id
|
||||
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||
continue
|
||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||
if handler.intercept_message:
|
||||
try:
|
||||
success, continue_processing, result = await handler.execute(transformed_message)
|
||||
if not success:
|
||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
|
||||
else:
|
||||
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
|
||||
continue_flag = continue_flag and continue_processing
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||
handler_task.add_done_callback(self._task_done_callback)
|
||||
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
|
||||
if handler.handler_name not in self._handler_tasks:
|
||||
self._handler_tasks[handler.handler_name] = []
|
||||
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||
except Exception as e:
|
||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||
continue
|
||||
return continue_flag
|
||||
|
||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
||||
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
||||
return False
|
||||
|
||||
handler_instance = handler_class()
|
||||
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
||||
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
||||
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
||||
|
||||
return True
|
||||
|
||||
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""从事件类型列表中移除事件处理器"""
|
||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
|
||||
return False
|
||||
|
||||
handlers = self._events_subscribers[handler_class.event_type]
|
||||
for i, handler in enumerate(handlers):
|
||||
if isinstance(handler, handler_class):
|
||||
del handlers[i]
|
||||
logger.debug(f"事件处理器 {display_handler_name} 已移除")
|
||||
return True
|
||||
|
||||
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
|
||||
return False
|
||||
|
||||
def _transform_event_message(
|
||||
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
|
||||
) -> MaiMessages:
|
||||
"""转换事件消息格式"""
|
||||
# 直接赋值部分内容
|
||||
transformed_message = MaiMessages(
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=llm_response.get("content") if llm_response else None,
|
||||
llm_response_reasoning=llm_response.get("reasoning") if llm_response else None,
|
||||
llm_response_model=llm_response.get("model") if llm_response else None,
|
||||
llm_response_tool_call=llm_response.get("tool_calls") if llm_response else None,
|
||||
raw_message=message.raw_message,
|
||||
additional_data=message.message_info.additional_config or {},
|
||||
)
|
||||
|
||||
# 消息段处理
|
||||
if message.message_segment.type == "seglist":
|
||||
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
|
||||
else:
|
||||
transformed_message.message_segments = [message.message_segment]
|
||||
|
||||
# stream_id 处理
|
||||
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||
transformed_message.stream_id = message.chat_stream.stream_id
|
||||
|
||||
# 处理后文本
|
||||
transformed_message.plain_text = message.processed_plain_text
|
||||
|
||||
# 基本信息
|
||||
if hasattr(message, "message_info") and message.message_info:
|
||||
if message.message_info.platform:
|
||||
transformed_message.message_base_info["platform"] = message.message_info.platform
|
||||
if message.message_info.group_info:
|
||||
transformed_message.is_group_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
}
|
||||
)
|
||||
if message.message_info.user_info:
|
||||
if not transformed_message.is_group_message:
|
||||
transformed_message.is_private_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
||||
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
||||
}
|
||||
)
|
||||
|
||||
return transformed_message
|
||||
|
||||
def _build_message_from_stream(
|
||||
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
|
||||
) -> MaiMessages:
|
||||
"""从流ID构建消息"""
|
||||
chat_stream = get_chat_manager().get_stream(stream_id)
|
||||
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
|
||||
message = chat_stream.context.get_last_message()
|
||||
return self._transform_event_message(message, llm_prompt, llm_response)
|
||||
|
||||
def _transform_event_without_message(
|
||||
self,
|
||||
stream_id: str,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> MaiMessages:
|
||||
"""没有message对象时进行转换"""
|
||||
chat_stream = get_chat_manager().get_stream(stream_id)
|
||||
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
|
||||
return MaiMessages(
|
||||
stream_id=stream_id,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=(llm_response.get("content") if llm_response else None),
|
||||
llm_response_reasoning=(llm_response.get("reasoning") if llm_response else None),
|
||||
llm_response_model=llm_response.get("model") if llm_response else None,
|
||||
llm_response_tool_call=(llm_response.get("tool_calls") if llm_response else None),
|
||||
is_group_message=(not (not chat_stream.group_info)),
|
||||
is_private_message=(not chat_stream.group_info),
|
||||
action_usage=action_usage,
|
||||
additional_data={"response_is_processed": True},
|
||||
)
|
||||
|
||||
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
|
||||
"""任务完成回调"""
|
||||
task_name = task.get_name() or "Unknown Task"
|
||||
try:
|
||||
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
||||
if success:
|
||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||
else:
|
||||
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
||||
finally:
|
||||
with contextlib.suppress(ValueError, KeyError):
|
||||
self._handler_tasks[task_name].remove(task)
|
||||
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
|
||||
for task in remaining_tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
||||
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
||||
except Exception as e:
|
||||
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
||||
if handler_name in self._handler_tasks:
|
||||
del self._handler_tasks[handler_name]
|
||||
|
||||
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
||||
"""取消注册事件处理器"""
|
||||
if handler_name not in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
||||
return False
|
||||
|
||||
await self.cancel_handler_tasks(handler_name)
|
||||
|
||||
handler_class = self._handler_mapping.pop(handler_name)
|
||||
if not self._remove_event_handler_instance(handler_class):
|
||||
return False
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
||||
return True
|
||||
|
||||
|
||||
events_manager = EventsManager()
|
||||
120
src/plugin_system/core/global_announcement_manager.py
Normal file
120
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
|
||||
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
# 用户禁用的工具,chat_id -> [tool_name]
|
||||
self._user_disabled_tools: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
if chat_id not in self._user_disabled_actions:
|
||||
self._user_disabled_actions[chat_id] = []
|
||||
if action_name in self._user_disabled_actions[chat_id]:
|
||||
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_actions[chat_id].append(action_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""启用特定聊天的某个动作"""
|
||||
if chat_id in self._user_disabled_actions:
|
||||
try:
|
||||
self._user_disabled_actions[chat_id].remove(action_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"动作 {action_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""禁用特定聊天的某个命令"""
|
||||
if chat_id not in self._user_disabled_commands:
|
||||
self._user_disabled_commands[chat_id] = []
|
||||
if command_name in self._user_disabled_commands[chat_id]:
|
||||
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_commands[chat_id].append(command_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""启用特定聊天的某个命令"""
|
||||
if chat_id in self._user_disabled_commands:
|
||||
try:
|
||||
self._user_disabled_commands[chat_id].remove(command_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"命令 {command_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""禁用特定聊天的某个事件处理器"""
|
||||
if chat_id not in self._user_disabled_event_handlers:
|
||||
self._user_disabled_event_handlers[chat_id] = []
|
||||
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""启用特定聊天的某个事件处理器"""
|
||||
if chat_id in self._user_disabled_event_handlers:
|
||||
try:
|
||||
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"事件处理器 {handler_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""禁用特定聊天的某个工具"""
|
||||
if chat_id not in self._user_disabled_tools:
|
||||
self._user_disabled_tools[chat_id] = []
|
||||
if tool_name in self._user_disabled_tools[chat_id]:
|
||||
logger.warning(f"工具 {tool_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
try:
|
||||
self._user_disabled_tools[chat_id].remove(tool_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"工具 {tool_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
插件热重载模块
|
||||
|
||||
使用 Watchdog 监听插件目录变化,自动重载插件
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Dict, Set
|
||||
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .plugin_manager import plugin_manager
|
||||
|
||||
logger = get_logger("plugin_hot_reload")
|
||||
|
||||
|
||||
class PluginFileHandler(FileSystemEventHandler):
|
||||
"""插件文件变化处理器"""
|
||||
|
||||
def __init__(self, hot_reload_manager):
|
||||
super().__init__()
|
||||
self.hot_reload_manager = hot_reload_manager
|
||||
self.pending_reloads: Set[str] = set() # 待重载的插件名称
|
||||
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
|
||||
self.debounce_delay = 1.0 # 防抖延迟(秒)
|
||||
|
||||
def on_modified(self, event):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
"""处理文件变化"""
|
||||
try:
|
||||
# 获取插件名称
|
||||
plugin_name = self._get_plugin_name_from_path(file_path)
|
||||
if not plugin_name:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
last_time = self.last_reload_time.get(plugin_name, 0)
|
||||
|
||||
# 防抖处理,避免频繁重载
|
||||
if current_time - last_time < self.debounce_delay:
|
||||
return
|
||||
|
||||
file_name = Path(file_path).name
|
||||
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
|
||||
|
||||
# 如果是删除事件,处理关键文件删除
|
||||
if change_type == "deleted":
|
||||
if file_name == "plugin.py":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
elif file_name == "manifest.toml":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
# 添加到待重载列表
|
||||
self.pending_reloads.add(plugin_name)
|
||||
self.last_reload_time[plugin_name] = current_time
|
||||
|
||||
# 延迟重载,避免文件正在写入时重载
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload,
|
||||
args=(plugin_name,),
|
||||
daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理文件变化时发生错误: {e}")
|
||||
|
||||
def _delayed_reload(self, plugin_name: str):
|
||||
"""延迟重载插件"""
|
||||
try:
|
||||
time.sleep(self.debounce_delay)
|
||||
|
||||
if plugin_name in self.pending_reloads:
|
||||
self.pending_reloads.remove(plugin_name)
|
||||
self.hot_reload_manager._reload_plugin(plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _get_plugin_name_from_path(self, file_path: str) -> str:
|
||||
"""从文件路径获取插件名称"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
|
||||
# 检查是否在监听的插件目录中
|
||||
plugin_root = Path(self.hot_reload_manager.watch_directory)
|
||||
if not path.is_relative_to(plugin_root):
|
||||
return ""
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml)
|
||||
plugin_dir = plugin_root / plugin_name
|
||||
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
|
||||
return plugin_name
|
||||
|
||||
return ""
|
||||
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class PluginHotReloadManager:
|
||||
"""插件热重载管理器"""
|
||||
|
||||
def __init__(self, watch_directory: str = None):
|
||||
print("fuck")
|
||||
print(os.getcwd())
|
||||
self.watch_directory = os.path.join(os.getcwd(), "plugins")
|
||||
self.observer = None
|
||||
self.file_handler = None
|
||||
self.is_running = False
|
||||
|
||||
# 确保监听目录存在
|
||||
if not os.path.exists(self.watch_directory):
|
||||
os.makedirs(self.watch_directory, exist_ok=True)
|
||||
logger.info(f"创建插件监听目录: {self.watch_directory}")
|
||||
|
||||
def start(self):
|
||||
"""启动热重载监听"""
|
||||
if self.is_running:
|
||||
logger.warning("插件热重载已经在运行中")
|
||||
return
|
||||
|
||||
try:
|
||||
self.observer = Observer()
|
||||
self.file_handler = PluginFileHandler(self)
|
||||
|
||||
self.observer.schedule(
|
||||
self.file_handler,
|
||||
self.watch_directory,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
self.observer.start()
|
||||
self.is_running = True
|
||||
|
||||
logger.info("🚀 插件热重载已启动,监听目录: plugins")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动插件热重载失败: {e}")
|
||||
self.is_running = False
|
||||
|
||||
def stop(self):
|
||||
"""停止热重载监听"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
if self.observer:
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
def _reload_plugin(self, plugin_name: str):
|
||||
"""重载指定插件"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _unload_plugin(self, plugin_name: str):
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def reload_all_plugins(self):
|
||||
"""重载所有插件"""
|
||||
try:
|
||||
logger.info("🔄 开始重载所有插件...")
|
||||
|
||||
# 获取当前已加载的插件列表
|
||||
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for plugin_name in loaded_plugins:
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载所有插件时发生错误: {e}")
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""获取热重载状态"""
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directory": self.watch_directory,
|
||||
"loaded_plugins": len(plugin_manager.loaded_plugins),
|
||||
"failed_plugins": len(plugin_manager.failed_plugins),
|
||||
}
|
||||
|
||||
|
||||
# 全局热重载管理器实例
|
||||
hot_reload_manager = PluginHotReloadManager()
|
||||
593
src/plugin_system/core/plugin_manager.py
Normal file
593
src/plugin_system/core/plugin_manager.py
Normal file
@@ -0,0 +1,593 @@
|
||||
import os
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""
|
||||
插件管理器类
|
||||
|
||||
负责加载,重载和卸载插件,同时管理插件的所有组件
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_directories: List[str] = [] # 插件根目录列表
|
||||
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
|
||||
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
|
||||
# 确保插件目录存在
|
||||
self._ensure_plugin_directories()
|
||||
logger.info("插件管理器初始化完成")
|
||||
|
||||
# === 插件目录管理 ===
|
||||
|
||||
def add_plugin_directory(self, directory: str) -> bool:
|
||||
"""添加插件目录"""
|
||||
if os.path.exists(directory):
|
||||
if directory not in self.plugin_directories:
|
||||
self.plugin_directories.append(directory)
|
||||
logger.debug(f"已添加插件目录: {directory}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"插件不可重复加载: {directory}")
|
||||
else:
|
||||
logger.warning(f"插件目录不存在: {directory}")
|
||||
return False
|
||||
|
||||
# === 插件加载管理 ===
|
||||
|
||||
def load_all_plugins(self) -> Tuple[int, int]:
|
||||
"""加载所有插件
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (插件数量, 组件数量)
|
||||
"""
|
||||
logger.debug("开始加载所有插件...")
|
||||
|
||||
# 第一阶段:加载所有插件模块(注册插件类)
|
||||
total_loaded_modules = 0
|
||||
total_failed_modules = 0
|
||||
|
||||
for directory in self.plugin_directories:
|
||||
loaded, failed = self._load_plugin_modules_from_directory(directory)
|
||||
total_loaded_modules += loaded
|
||||
total_failed_modules += failed
|
||||
|
||||
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
||||
|
||||
total_registered = 0
|
||||
total_failed_registration = 0
|
||||
|
||||
for plugin_name in self.plugin_classes.keys():
|
||||
load_status, count = self.load_registered_plugin_classes(plugin_name)
|
||||
if load_status:
|
||||
total_registered += 1
|
||||
else:
|
||||
total_failed_registration += count
|
||||
|
||||
self._show_stats(total_registered, total_failed_registration)
|
||||
|
||||
return total_registered, total_failed_registration
|
||||
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||
# sourcery skip: extract-duplicate-method, extract-method
|
||||
"""
|
||||
加载已经注册的插件类
|
||||
"""
|
||||
plugin_class = self.plugin_classes.get(plugin_name)
|
||||
if not plugin_class:
|
||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||
return False, 1
|
||||
try:
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
|
||||
# 如果没有记录,直接返回失败
|
||||
if not plugin_dir:
|
||||
return False, 1
|
||||
|
||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||
if not plugin_instance:
|
||||
logger.error(f"插件 {plugin_name} 实例化失败")
|
||||
return False, 1
|
||||
# 检查插件是否启用
|
||||
if not plugin_instance.enable_plugin:
|
||||
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
||||
return False, 0
|
||||
|
||||
# 检查版本兼容性
|
||||
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
|
||||
plugin_name, plugin_instance.manifest_data
|
||||
)
|
||||
if not is_compatible:
|
||||
self.failed_plugins[plugin_name] = compatibility_error
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
||||
return False, 1
|
||||
if plugin_instance.register_plugin():
|
||||
self.loaded_plugins[plugin_name] = plugin_instance
|
||||
self._show_plugin_components(plugin_name)
|
||||
return True, 1
|
||||
else:
|
||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
||||
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
||||
return False, 1
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# manifest文件缺失
|
||||
error_msg = f"缺少manifest文件: {str(e)}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
|
||||
except ValueError as e:
|
||||
# manifest文件格式错误或验证失败
|
||||
traceback.print_exc()
|
||||
error_msg = f"manifest验证失败: {str(e)}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
error_msg = f"未知错误: {str(e)}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False, 1
|
||||
|
||||
async def remove_registered_plugin(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
禁用插件模块
|
||||
"""
|
||||
if not plugin_name:
|
||||
raise ValueError("插件名称不能为空")
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载")
|
||||
return False
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
plugin_info = plugin_instance.plugin_info
|
||||
success = True
|
||||
for component in plugin_info.components:
|
||||
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
|
||||
success &= component_registry.remove_plugin_registry(plugin_name)
|
||||
del self.loaded_plugins[plugin_name]
|
||||
return success
|
||||
|
||||
async def reload_registered_plugin(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
重载插件模块
|
||||
"""
|
||||
if not await self.remove_registered_plugin(plugin_name):
|
||||
return False
|
||||
if not self.load_registered_plugin_classes(plugin_name)[0]:
|
||||
return False
|
||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||
return True
|
||||
|
||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件根目录
|
||||
"""
|
||||
total_success = 0
|
||||
total_fail = 0
|
||||
for directory in self.plugin_directories:
|
||||
if os.path.exists(directory):
|
||||
logger.debug(f"重新扫描插件根目录: {directory}")
|
||||
success, fail = self._load_plugin_modules_from_directory(directory)
|
||||
total_success += success
|
||||
total_fail += fail
|
||||
else:
|
||||
logger.warning(f"插件根目录不存在: {directory}")
|
||||
return total_success, total_fail
|
||||
|
||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||
"""获取插件实例
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[BasePlugin]: 插件实例或None
|
||||
"""
|
||||
return self.loaded_plugins.get(plugin_name)
|
||||
|
||||
# === 查询方法 ===
|
||||
def list_loaded_plugins(self) -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
list: 当前加载的插件名称列表。
|
||||
"""
|
||||
return list(self.loaded_plugins.keys())
|
||||
|
||||
def list_registered_plugins(self) -> List[str]:
|
||||
"""
|
||||
列出所有已注册的插件类。
|
||||
|
||||
Returns:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。
|
||||
"""
|
||||
return self.plugin_paths.get(plugin_name)
|
||||
|
||||
# === 私有方法 ===
|
||||
# == 目录管理 ==
|
||||
def _ensure_plugin_directories(self) -> None:
|
||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||
default_directories = ["src/plugins/built_in", "plugins"]
|
||||
|
||||
for directory in default_directories:
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.info(f"创建插件根目录: {directory}")
|
||||
if directory not in self.plugin_directories:
|
||||
self.plugin_directories.append(directory)
|
||||
logger.debug(f"已添加插件根目录: {directory}")
|
||||
else:
|
||||
logger.warning(f"根目录不可重复加载: {directory}")
|
||||
|
||||
# == 插件加载 ==
|
||||
|
||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||
"""从指定目录加载插件模块"""
|
||||
loaded_count = 0
|
||||
failed_count = 0
|
||||
|
||||
if not os.path.exists(directory):
|
||||
logger.warning(f"插件根目录不存在: {directory}")
|
||||
return 0, 1
|
||||
|
||||
logger.debug(f"正在扫描插件根目录: {directory}")
|
||||
|
||||
# 遍历目录中的所有包
|
||||
for item in os.listdir(directory):
|
||||
item_path = os.path.join(directory, item)
|
||||
|
||||
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
||||
plugin_file = os.path.join(item_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
loaded_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
return loaded_count, failed_count
|
||||
|
||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||
# sourcery skip: extract-method
|
||||
"""加载单个插件模块文件
|
||||
|
||||
Args:
|
||||
plugin_file: 插件文件路径
|
||||
plugin_name: 插件名称
|
||||
plugin_dir: 插件目录路径
|
||||
"""
|
||||
# 生成模块名
|
||||
plugin_path = Path(plugin_file)
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
|
||||
try:
|
||||
# 动态导入插件模块
|
||||
spec = spec_from_file_location(module_name, plugin_file)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error(f"无法创建模块规范: {plugin_file}")
|
||||
return False
|
||||
|
||||
module = module_from_spec(spec)
|
||||
module.__package__ = module_name # 设置模块包名
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||
logger.error(error_msg)
|
||||
self.failed_plugins[module_name] = error_msg
|
||||
return False
|
||||
|
||||
# == 兼容性检查 ==
|
||||
|
||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""检查插件版本兼容性
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
manifest_data: manifest数据
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否兼容, 错误信息)
|
||||
"""
|
||||
if "host_application" not in manifest_data:
|
||||
return True, "" # 没有版本要求,默认兼容
|
||||
|
||||
host_app = manifest_data["host_application"]
|
||||
if not isinstance(host_app, dict):
|
||||
return True, ""
|
||||
|
||||
min_version = host_app.get("min_version", "")
|
||||
max_version = host_app.get("max_version", "")
|
||||
|
||||
if not min_version and not max_version:
|
||||
return True, "" # 没有版本要求,默认兼容
|
||||
|
||||
try:
|
||||
current_version = VersionComparator.get_current_host_version()
|
||||
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
|
||||
if not is_compatible:
|
||||
return False, f"版本不兼容: {error_msg}"
|
||||
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
||||
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
||||
|
||||
# == 显示统计与插件信息 ==
|
||||
|
||||
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
||||
# sourcery skip: low-code-quality
|
||||
# 获取组件统计信息
|
||||
stats = component_registry.get_registry_stats()
|
||||
action_count = stats.get("action_components", 0)
|
||||
command_count = stats.get("command_components", 0)
|
||||
tool_count = stats.get("tool_components", 0)
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
logger.info("📋 已加载插件详情:")
|
||||
for plugin_name in self.loaded_plugins.keys():
|
||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
# 插件基本信息
|
||||
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
|
||||
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
|
||||
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
|
||||
info_parts = [part for part in [version_info, author_info, license_info] if part]
|
||||
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
|
||||
|
||||
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
|
||||
|
||||
# Manifest信息
|
||||
if plugin_info.manifest_data:
|
||||
"""
|
||||
if plugin_info.keywords:
|
||||
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
|
||||
if plugin_info.categories:
|
||||
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
|
||||
"""
|
||||
if plugin_info.homepage_url:
|
||||
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
action_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||
]
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
tool_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
|
||||
]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
||||
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
if tool_components:
|
||||
tool_names = [c.name for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
|
||||
# 依赖信息
|
||||
if plugin_info.dependencies:
|
||||
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
||||
|
||||
# 配置文件信息
|
||||
if plugin_info.config_file:
|
||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
||||
|
||||
root_path = Path(__file__)
|
||||
|
||||
# 查找项目根目录
|
||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||
root_path = root_path.parent
|
||||
|
||||
# 显示目录统计
|
||||
logger.info("📂 加载目录统计:")
|
||||
for directory in self.plugin_directories:
|
||||
if os.path.exists(directory):
|
||||
plugins_in_dir = []
|
||||
for plugin_name in self.loaded_plugins.keys():
|
||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
||||
if (
|
||||
Path(plugin_path)
|
||||
.resolve()
|
||||
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
|
||||
):
|
||||
plugins_in_dir.append(plugin_name)
|
||||
|
||||
if plugins_in_dir:
|
||||
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
|
||||
else:
|
||||
logger.info(f" 📁 {directory}: 0个插件")
|
||||
|
||||
# 失败信息
|
||||
if total_failed_registration > 0:
|
||||
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
|
||||
for failed_plugin, error in self.failed_plugins.items():
|
||||
logger.info(f" ❌ {failed_plugin}: {error}")
|
||||
else:
|
||||
logger.warning("😕 没有成功加载任何插件")
|
||||
|
||||
def _show_plugin_components(self, plugin_name: str) -> None:
|
||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
component_types = {}
|
||||
for comp in plugin_info.components:
|
||||
comp_type = comp.component_type.name
|
||||
component_types[comp_type] = component_types.get(comp_type, 0) + 1
|
||||
|
||||
components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()])
|
||||
|
||||
# 显示manifest信息
|
||||
manifest_info = ""
|
||||
if plugin_info.license:
|
||||
manifest_info += f" [{plugin_info.license}]"
|
||||
if plugin_info.keywords:
|
||||
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
|
||||
if len(plugin_info.keywords) > 3:
|
||||
manifest_info += "..."
|
||||
|
||||
logger.info(
|
||||
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
||||
|
||||
# === 插件卸载和重载管理 ===
|
||||
|
||||
def unload_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 卸载是否成功
|
||||
"""
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 获取插件实例
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
|
||||
# 调用插件的清理方法(如果有的话)
|
||||
if hasattr(plugin_instance, 'on_unload'):
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
component_registry.unregister_plugin(plugin_name)
|
||||
|
||||
# 从已加载插件中移除
|
||||
del self.loaded_plugins[plugin_name]
|
||||
|
||||
# 从失败列表中移除(如果存在)
|
||||
if plugin_name in self.failed_plugins:
|
||||
del self.failed_plugins[plugin_name]
|
||||
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
try:
|
||||
# 先卸载插件
|
||||
if plugin_name in self.loaded_plugins:
|
||||
self.unload_plugin(plugin_name)
|
||||
|
||||
# 清除Python模块缓存
|
||||
plugin_path = self.plugin_paths.get(plugin_name)
|
||||
if plugin_path:
|
||||
plugin_file = os.path.join(plugin_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
# 从sys.modules中移除相关模块
|
||||
modules_to_remove = []
|
||||
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
|
||||
|
||||
for module_name in sys.modules:
|
||||
if module_name.startswith(plugin_module_prefix):
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
for module_name in modules_to_remove:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# 从插件类注册表中移除
|
||||
if plugin_name in self.plugin_classes:
|
||||
del self.plugin_classes[plugin_name]
|
||||
|
||||
# 重新加载插件模块
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
# 重新加载插件实例
|
||||
success, _ = self.load_registered_plugin_classes(plugin_name)
|
||||
if success:
|
||||
logger.info(f"🔄 插件重载成功: {plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
421
src/plugin_system/core/tool_use.py
Normal file
421
src/plugin_system/core/tool_use.py
Normal file
@@ -0,0 +1,421 @@
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
"""初始化工具执行器的提示词"""
|
||||
tool_executor_prompt = """
|
||||
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的工具使用指令
|
||||
|
||||
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
|
||||
"""
|
||||
Prompt(tool_executor_prompt, "tool_executor_prompt")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
executor_id: 执行器标识符,用于日志记录
|
||||
enable_cache: 是否启用缓存机制
|
||||
cache_ttl: 缓存生存时间(周期数)
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
# 缓存配置
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
||||
|
||||
Returns:
|
||||
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
|
||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
||||
"""
|
||||
|
||||
# 首先检查缓存
|
||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], ""
|
||||
|
||||
# 从缓存结果中提取工具名称
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, ""
|
||||
|
||||
# 缓存未命中,执行工具调用
|
||||
# 获取可用工具
|
||||
tools = self._get_tool_definitions()
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
# 构建工具调用提示词
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"tool_executor_prompt",
|
||||
target_message=target_message,
|
||||
chat_history=chat_history,
|
||||
sender=sender,
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
# 调用LLM进行工具决策
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
# 缓存结果
|
||||
if tool_results:
|
||||
self._set_cache(cache_key, tool_results)
|
||||
|
||||
if used_tools:
|
||||
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
|
||||
|
||||
if return_details:
|
||||
return tool_results, used_tools, prompt
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [definition for name, definition in all_tools if name not in user_disabled_tools]
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: LLM返回的工具调用列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in tool_calls:
|
||||
try:
|
||||
tool_name = tool_call.func_name
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
tool_info["content"] = str(content)
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
preview = content[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(error_info)
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# 使用消息内容和群聊状态生成唯一缓存键
|
||||
content = f"{target_message}_{chat_history}_{sender}"
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||
"""从缓存获取结果
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
|
||||
Returns:
|
||||
Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None
|
||||
"""
|
||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||
return None
|
||||
|
||||
cache_item = self.tool_cache[cache_key]
|
||||
if cache_item["ttl"] <= 0:
|
||||
# 缓存过期,删除
|
||||
del self.tool_cache[cache_key]
|
||||
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
|
||||
return None
|
||||
|
||||
# 减少TTL
|
||||
cache_item["ttl"] -= 1
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}")
|
||||
return cache_item["result"]
|
||||
|
||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||
"""设置缓存
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
result: 要缓存的结果
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||
logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}")
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
"""清理过期的缓存"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
expired_keys = []
|
||||
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
|
||||
for key in expired_keys:
|
||||
del self.tool_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_args: 工具参数
|
||||
validate_args: 是否验证参数
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具执行结果,失败时返回None
|
||||
"""
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=f"direct_tool_{time.time()}",
|
||||
func_name=tool_name,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"direct_tool_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空所有缓存"""
|
||||
if self.enable_cache:
|
||||
cache_count = len(self.tool_cache)
|
||||
self.tool_cache.clear()
|
||||
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
|
||||
|
||||
def get_cache_status(self) -> Dict:
|
||||
"""获取缓存状态信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含缓存统计信息的字典
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return {"enabled": False, "cache_count": 0}
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
total_count = len(self.tool_cache)
|
||||
ttl_distribution = {}
|
||||
|
||||
for cache_item in self.tool_cache.values():
|
||||
ttl = cache_item["ttl"]
|
||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"cache_count": total_count,
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"ttl_distribution": ttl_distribution,
|
||||
}
|
||||
|
||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||
"""动态修改缓存配置
|
||||
|
||||
Args:
|
||||
enable_cache: 是否启用缓存
|
||||
cache_ttl: 缓存TTL
|
||||
"""
|
||||
if enable_cache is not None:
|
||||
self.enable_cache = enable_cache
|
||||
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
||||
|
||||
if cache_ttl > 0:
|
||||
self.cache_ttl = cache_ttl
|
||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||
|
||||
|
||||
"""
|
||||
ToolExecutor使用示例:
|
||||
|
||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
||||
executor = ToolExecutor(executor_id="my_executor")
|
||||
results, _, _ = await executor.execute_from_chat_message(
|
||||
talking_message_str="今天天气怎么样?现在几点了?",
|
||||
is_group_chat=False
|
||||
)
|
||||
|
||||
# 2. 禁用缓存的执行器
|
||||
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
|
||||
|
||||
# 3. 自定义缓存TTL
|
||||
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
|
||||
|
||||
# 4. 获取详细信息
|
||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||
talking_message_str="帮我查询Python相关知识",
|
||||
is_group_chat=False,
|
||||
return_details=True
|
||||
)
|
||||
|
||||
# 5. 直接执行特定工具
|
||||
result = await executor.execute_specific_tool_simple(
|
||||
tool_name="get_knowledge",
|
||||
tool_args={"query": "机器学习"}
|
||||
)
|
||||
|
||||
# 6. 缓存管理
|
||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
||||
executor.clear_cache() # 清空缓存
|
||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
||||
"""
|
||||
Reference in New Issue
Block a user