初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
parent ff7d1177fa
commit 2d4745cd58
257 changed files with 69069 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
"""
插件核心管理模块
提供插件的加载、注册和管理功能
"""
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"hot_reload_manager",
]

View File

@@ -0,0 +1,688 @@
import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
EventHandlerInfo,
PluginInfo,
ComponentType,
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry")
class ComponentRegistry:
"""统一的组件注册中心
负责管理所有插件组件的注册、查询和生命周期管理
"""
def __init__(self):
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
self._components: Dict[str, ComponentInfo] = {}
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
self._plugins: Dict[str, PluginInfo] = {}
"""插件名 -> 插件信息"""
# Action特定注册表
self._action_registry: Dict[str, Type[BaseAction]] = {}
"""Action注册表 action名 -> action类"""
self._default_actions: Dict[str, ActionInfo] = {}
"""默认动作集即启用的Action集用于重置ActionManager状态"""
# Command特定注册表
self._command_registry: Dict[str, Type[BaseCommand]] = {}
"""Command类注册表 command名 -> command类"""
self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类"""
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
"""启用的事件处理器 event_handler名 -> event_handler类"""
logger.info("组件注册中心初始化完成")
# == 注册方法 ==
def register_plugin(self, plugin_info: PluginInfo) -> bool:
"""注册插件
Args:
plugin_info: 插件信息
Returns:
bool: 是否注册成功
"""
plugin_name = plugin_info.name
if plugin_name in self._plugins:
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
return False
self._plugins[plugin_name] = plugin_info
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
return True
def register_component(
self,
component_info: ComponentInfo,
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
Args:
component_info (ComponentInfo): 组件信息
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
Returns:
bool: 是否注册成功
"""
component_name = component_info.name
component_type = component_info.component_type
plugin_name = getattr(component_info, "plugin_name", "unknown")
if "." in component_name:
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
return False
if "." in plugin_name:
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
return False
namespaced_name = f"{component_type}.{component_name}"
if namespaced_name in self._components:
existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning(
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
)
return False
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
self._components_classes[namespaced_name] = component_class
# 根据组件类型进行特定注册(使用原始名称)
match component_type:
case ComponentType.ACTION:
assert isinstance(component_info, ActionInfo)
assert issubclass(component_class, BaseAction)
ret = self._register_action_component(component_info, component_class)
case ComponentType.COMMAND:
assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
ret = self._register_event_handler_component(component_info, component_class)
case _:
logger.warning(f"未知组件类型: {component_type}")
if not ret:
return False
logger.debug(
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
f"({component_class.__name__}) [插件: {plugin_name}]"
)
return True
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
"""注册Action组件到Action特定注册表"""
if not (action_name := action_info.name):
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
return False
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
logger.error(f"注册失败: {action_name} 不是有效的Action")
return False
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info
return True
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
"""注册Command组件到Command特定注册表"""
if not (command_name := command_info.name):
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
return False
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
logger.error(f"注册失败: {command_name} 不是有效的Command")
return False
self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
if command_info.enabled and command_info.command_pattern:
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
if pattern not in self._command_patterns:
self._command_patterns[pattern] = command_name
else:
logger.warning(
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
)
return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
if not (handler_name := handler_info.name):
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
return False
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
logger.warning(f"EventHandler组件 {handler_name} 未启用")
return True # 未启用,但是也是注册成功
from .events_manager import events_manager # 延迟导入防止循环导入问题
if events_manager.register_event_subscriber(handler_info, handler_class):
self._enabled_event_handlers[handler_name] = handler_class
return True
else:
logger.error(f"注册事件处理器 {handler_name} 失败")
return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
target_component_class = self.get_component_class(component_name, component_type)
if not target_component_class:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return False
try:
# 根据组件类型进行特定的清理操作
match component_type:
case ComponentType.ACTION:
# 移除Action注册
self._action_registry.pop(component_name, None)
self._default_actions.pop(component_name, None)
logger.debug(f"已移除Action组件: {component_name}")
case ComponentType.COMMAND:
# 移除Command注册和模式
self._command_registry.pop(component_name, None)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key, None)
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
case ComponentType.TOOL:
# 移除Tool注册
self._tool_registry.pop(component_name, None)
self._llm_available_tools.pop(component_name, None)
logger.debug(f"已移除Tool组件: {component_name}")
case ComponentType.EVENT_HANDLER:
# 移除EventHandler注册和事件订阅
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name, None)
self._enabled_event_handlers.pop(component_name, None)
try:
await events_manager.unregister_event_subscriber(component_name)
logger.debug(f"已移除EventHandler组件: {component_name}")
except Exception as e:
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
case _:
logger.warning(f"未知的组件类型: {component_type}")
return False
# 移除通用注册信息
namespaced_name = f"{component_type}.{component_name}"
self._components.pop(namespaced_name, None)
self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(namespaced_name, None)
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
return True
except Exception as e:
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
return False
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功移除
"""
if plugin_name not in self._plugins:
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
return False
del self._plugins[plugin_name]
logger.info(f"插件 {plugin_name} 已移除")
return True
# === 组件全局启用/禁用方法 ===
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的启用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 启用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法启用")
return False
target_component_info.enabled = True
match component_type:
case ComponentType.ACTION:
assert isinstance(target_component_info, ActionInfo)
self._default_actions[component_name] = target_component_info
case ComponentType.COMMAND:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.TOOL:
assert isinstance(target_component_info, ToolInfo)
assert issubclass(target_component_class, BaseTool)
self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
self._enabled_event_handlers[component_name] = target_component_class
from .events_manager import events_manager # 延迟导入防止循环导入问题
events_manager.register_event_subscriber(target_component_info, target_component_class)
namespaced_name = f"{component_type}.{component_name}"
self._components[namespaced_name].enabled = True
self._components_by_type[component_type][component_name].enabled = True
logger.info(f"组件 {component_name} 已启用")
return True
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
"""全局的禁用某个组件
Parameters:
component_name: 组件名称
component_type: 组件类型
Returns:
bool: 禁用成功返回True失败返回False
"""
target_component_class = self.get_component_class(component_name, component_type)
target_component_info = self.get_component_info(component_name, component_type)
if not target_component_class or not target_component_info:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
try:
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.TOOL:
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
except KeyError as e:
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
return False
# === 组件查询方法 ===
def get_component_info(
self, component_name: str, component_type: Optional[ComponentType] = None
) -> Optional[ComponentInfo]:
# sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[ComponentInfo]: 组件信息或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type}.{component_name}"
return self._components.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_info := self._components.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_info))
if len(candidates) == 1:
# 只有一个匹配,直接返回
return candidates[0][2]
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_component_class(
self,
component_name: str,
component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
component_name: 组件名称,可以是原始名称或命名空间化的名称
component_type: 组件类型,如果提供则优先在该类型中查找
Returns:
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
"""
# 1. 如果已经是命名空间化的名称,直接查找
if "." in component_name:
return self._components_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找
if component_type:
namespaced_name = f"{component_type.value}.{component_name}"
return self._components_classes.get(namespaced_name)
# 3. 如果没有指定类型,尝试在所有命名空间中查找
candidates = []
for namespace_prefix in [types.value for types in ComponentType]:
namespaced_name = f"{namespace_prefix}.{component_name}"
if component_class := self._components_classes.get(namespaced_name):
candidates.append((namespace_prefix, namespaced_name, component_class))
if len(candidates) == 1:
# 只有一个匹配,直接返回
_, full_name, cls = candidates[0]
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
return cls
elif len(candidates) > 1:
# 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates]
logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
)
return candidates[0][2]
# 4. 都没找到
return None
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有组件"""
return self._components_by_type.get(component_type, {}).copy()
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取指定类型的所有启用组件"""
components = self.get_components_by_type(component_type)
return {name: info for name, info in components.items() if info.enabled}
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表"""
return self._action_registry.copy()
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
"""获取Action信息"""
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表"""
return self._command_registry.copy()
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
"""获取Command信息"""
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def get_command_patterns(self) -> Dict[Pattern, str]:
"""获取Command模式注册表"""
return self._command_patterns.copy()
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
Args:
text: 输入文本
Returns:
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
"""
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
if not candidates:
return None
if len(candidates) > 1:
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
command_name = self._command_patterns[candidates[0]]
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
return (
self._command_registry[command_name],
candidates[0].match(text).groupdict(), # type: ignore
command_info,
)
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取事件处理器注册表"""
return self._event_handler_registry.copy()
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
"""获取事件处理器信息"""
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
return info if isinstance(info, EventHandlerInfo) else None
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
"""获取启用的事件处理器"""
return self._enabled_event_handlers.copy()
# === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
"""获取插件信息"""
return self._plugins.get(plugin_name)
def get_all_plugins(self) -> Dict[str, PluginInfo]:
"""获取所有插件"""
return self._plugins.copy()
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
# """获取所有启用的插件"""
# return {name: info for name, info in self._plugins.items() if info.enabled}
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
"""获取插件配置
Args:
plugin_name: 插件名称
Returns:
Optional[dict]: 插件配置字典或None
"""
# 从插件管理器获取插件实例的配置
from src.plugin_system.core.plugin_manager import plugin_manager
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
return plugin_instance.config if plugin_instance else None
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0
tool_components: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
elif component.component_type == ComponentType.COMMAND:
command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
return {
"action_components": action_components,
"command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers,
"total_components": len(self._components),
"total_plugins": len(self._plugins),
"components_by_type": {
component_type.value: len(components) for component_type, components in self._components_by_type.items()
},
"enabled_components": len([c for c in self._components.values() if c.enabled]),
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
}
# === 组件移除相关 ===
async def unregister_plugin(self, plugin_name: str) -> bool:
"""卸载插件及其所有组件
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功卸载
"""
plugin_info = self.get_plugin_info(plugin_name)
if not plugin_info:
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
return False
logger.info(f"开始卸载插件: {plugin_name}")
# 记录卸载失败的组件
failed_components = []
# 逐个移除插件的所有组件
for component_info in plugin_info.components:
try:
success = await self.remove_component(
component_info.name,
component_info.component_type,
plugin_name,
)
if not success:
failed_components.append(f"{component_info.component_type}.{component_info.name}")
except Exception as e:
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
failed_components.append(f"{component_info.component_type}.{component_info.name}")
# 移除插件注册信息
plugin_removed = self.remove_plugin_registry(plugin_name)
if failed_components:
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
return False
elif not plugin_removed:
logger.error(f"插件 {plugin_name} 注册信息移除失败")
return False
else:
logger.info(f"插件 {plugin_name} 卸载成功")
return True
# 创建全局组件注册中心实例
component_registry = ComponentRegistry()

View File

@@ -0,0 +1,262 @@
import asyncio
import contextlib
from typing import List, Dict, Optional, Type, Tuple, Any
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
from src.plugin_system.base.base_events_handler import BaseEventHandler
from .global_announcement_manager import global_announcement_manager
logger = get_logger("events_manager")
class EventsManager:
def __init__(self):
# 有权重的 events 订阅者注册表
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
"""注册事件处理器
Args:
handler_info (EventHandlerInfo): 事件处理器信息
handler_class (Type[BaseEventHandler]): 事件处理器类
Returns:
bool: 是否注册成功
"""
handler_name = handler_info.name
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return True
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")
return False
self._handler_mapping[handler_name] = handler_class
return self._insert_event_handler(handler_class, handler_info)
async def handle_mai_events(
self,
event_type: EventType,
message: Optional[MessageRecv] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional[Dict[str, Any]] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> bool:
"""处理 events"""
from src.plugin_system.core import component_registry
continue_flag = True
transformed_message: Optional[MaiMessages] = None
if not message:
assert stream_id, "如果没有消息必须提供流ID"
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
transformed_message = self._build_message_from_stream(stream_id, llm_prompt, llm_response)
else:
transformed_message = self._transform_event_without_message(
stream_id, llm_prompt, llm_response, action_usage
)
else:
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
for handler in self._events_subscribers.get(event_type, []):
if transformed_message.stream_id:
stream_id = transformed_message.stream_id
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
continue
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
if handler.intercept_message:
try:
success, continue_processing, result = await handler.execute(transformed_message)
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
else:
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
continue_flag = continue_flag and continue_processing
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
continue
else:
try:
handler_task = asyncio.create_task(handler.execute(transformed_message))
handler_task.add_done_callback(self._task_done_callback)
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
if handler.handler_name not in self._handler_tasks:
self._handler_tasks[handler.handler_name] = []
self._handler_tasks[handler.handler_name].append(handler_task)
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
continue
return continue_flag
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
if handler_class.event_type == EventType.UNKNOWN:
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
return False
handler_instance = handler_class()
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
self._events_subscribers[handler_class.event_type].append(handler_instance)
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
return True
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
"""从事件类型列表中移除事件处理器"""
display_handler_name = handler_class.handler_name or handler_class.__name__
if handler_class.event_type == EventType.UNKNOWN:
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
return False
handlers = self._events_subscribers[handler_class.event_type]
for i, handler in enumerate(handlers):
if isinstance(handler, handler_class):
del handlers[i]
logger.debug(f"事件处理器 {display_handler_name} 已移除")
return True
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
return False
def _transform_event_message(
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
) -> MaiMessages:
"""转换事件消息格式"""
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
llm_response_content=llm_response.get("content") if llm_response else None,
llm_response_reasoning=llm_response.get("reasoning") if llm_response else None,
llm_response_model=llm_response.get("model") if llm_response else None,
llm_response_tool_call=llm_response.get("tool_calls") if llm_response else None,
raw_message=message.raw_message,
additional_data=message.message_info.additional_config or {},
)
# 消息段处理
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
# stream_id 处理
if hasattr(message, "chat_stream") and message.chat_stream:
transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if hasattr(message, "message_info") and message.message_info:
if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
}
)
return transformed_message
def _build_message_from_stream(
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
) -> MaiMessages:
"""从流ID构建消息"""
chat_stream = get_chat_manager().get_stream(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
message = chat_stream.context.get_last_message()
return self._transform_event_message(message, llm_prompt, llm_response)
def _transform_event_without_message(
self,
stream_id: str,
llm_prompt: Optional[str] = None,
llm_response: Optional[Dict[str, Any]] = None,
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有message对象时进行转换"""
chat_stream = get_chat_manager().get_stream(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
llm_response_content=(llm_response.get("content") if llm_response else None),
llm_response_reasoning=(llm_response.get("reasoning") if llm_response else None),
llm_response_model=llm_response.get("model") if llm_response else None,
llm_response_tool_call=(llm_response.get("tool_calls") if llm_response else None),
is_group_message=(not (not chat_stream.group_info)),
is_private_message=(not chat_stream.group_info),
action_usage=action_usage,
additional_data={"response_is_processed": True},
)
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
"""任务完成回调"""
task_name = task.get_name() or "Unknown Task"
try:
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else:
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
finally:
with contextlib.suppress(ValueError, KeyError):
self._handler_tasks[task_name].remove(task)
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
for task in remaining_tasks:
task.cancel()
try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
except asyncio.TimeoutError:
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
except Exception as e:
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
if handler_name in self._handler_tasks:
del self._handler_tasks[handler_name]
async def unregister_event_subscriber(self, handler_name: str) -> bool:
"""取消注册事件处理器"""
if handler_name not in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
return False
await self.cancel_handler_tasks(handler_name)
handler_class = self._handler_mapping.pop(handler_name)
if not self._remove_event_handler_instance(handler_class):
return False
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
return True
events_manager = EventsManager()

View File

@@ -0,0 +1,120 @@
from typing import List, Dict
from src.common.logger import get_logger
logger = get_logger("global_announcement_manager")
class GlobalAnnouncementManager:
def __init__(self) -> None:
# 用户禁用的动作chat_id -> [action_name]
self._user_disabled_actions: Dict[str, List[str]] = {}
# 用户禁用的命令chat_id -> [command_name]
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
# 用户禁用的工具chat_id -> [tool_name]
self._user_disabled_tools: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
if chat_id not in self._user_disabled_actions:
self._user_disabled_actions[chat_id] = []
if action_name in self._user_disabled_actions[chat_id]:
logger.warning(f"动作 {action_name} 已经被禁用")
return False
self._user_disabled_actions[chat_id].append(action_name)
return True
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""启用特定聊天的某个动作"""
if chat_id in self._user_disabled_actions:
try:
self._user_disabled_actions[chat_id].remove(action_name)
return True
except ValueError:
logger.warning(f"动作 {action_name} 不在禁用列表中")
return False
return False
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""禁用特定聊天的某个命令"""
if chat_id not in self._user_disabled_commands:
self._user_disabled_commands[chat_id] = []
if command_name in self._user_disabled_commands[chat_id]:
logger.warning(f"命令 {command_name} 已经被禁用")
return False
self._user_disabled_commands[chat_id].append(command_name)
return True
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
"""启用特定聊天的某个命令"""
if chat_id in self._user_disabled_commands:
try:
self._user_disabled_commands[chat_id].remove(command_name)
return True
except ValueError:
logger.warning(f"命令 {command_name} 不在禁用列表中")
return False
return False
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""禁用特定聊天的某个事件处理器"""
if chat_id not in self._user_disabled_event_handlers:
self._user_disabled_event_handlers[chat_id] = []
if handler_name in self._user_disabled_event_handlers[chat_id]:
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
return False
self._user_disabled_event_handlers[chat_id].append(handler_name)
return True
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
"""启用特定聊天的某个事件处理器"""
if chat_id in self._user_disabled_event_handlers:
try:
self._user_disabled_event_handlers[chat_id].remove(handler_name)
return True
except ValueError:
logger.warning(f"事件处理器 {handler_name} 不在禁用列表中")
return False
return False
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""禁用特定聊天的某个工具"""
if chat_id not in self._user_disabled_tools:
self._user_disabled_tools[chat_id] = []
if tool_name in self._user_disabled_tools[chat_id]:
logger.warning(f"工具 {tool_name} 已经被禁用")
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
try:
self._user_disabled_tools[chat_id].remove(tool_name)
return True
except ValueError:
logger.warning(f"工具 {tool_name} 不在禁用列表中")
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有命令"""
return self._user_disabled_commands.get(chat_id, []).copy()
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()

View File

@@ -0,0 +1,242 @@
"""
插件热重载模块
使用 Watchdog 监听插件目录变化,自动重载插件
"""
import os
import time
from pathlib import Path
from threading import Thread
from typing import Dict, Set
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from src.common.logger import get_logger
from .plugin_manager import plugin_manager
logger = get_logger("plugin_hot_reload")
class PluginFileHandler(FileSystemEventHandler):
"""插件文件变化处理器"""
def __init__(self, hot_reload_manager):
super().__init__()
self.hot_reload_manager = hot_reload_manager
self.pending_reloads: Set[str] = set() # 待重载的插件名称
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
self.debounce_delay = 1.0 # 防抖延迟(秒)
def on_modified(self, event):
"""文件修改事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
"""处理文件变化"""
try:
# 获取插件名称
plugin_name = self._get_plugin_name_from_path(file_path)
if not plugin_name:
return
current_time = time.time()
last_time = self.last_reload_time.get(plugin_name, 0)
# 防抖处理,避免频繁重载
if current_time - last_time < self.debounce_delay:
return
file_name = Path(file_path).name
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
# 如果是删除事件,处理关键文件删除
if change_type == "deleted":
if file_name == "plugin.py":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
elif file_name == "manifest.toml":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
# 对于修改和创建事件,都进行重载
# 添加到待重载列表
self.pending_reloads.add(plugin_name)
self.last_reload_time[plugin_name] = current_time
# 延迟重载,避免文件正在写入时重载
reload_thread = Thread(
target=self._delayed_reload,
args=(plugin_name,),
daemon=True
)
reload_thread.start()
except Exception as e:
logger.error(f"❌ 处理文件变化时发生错误: {e}")
def _delayed_reload(self, plugin_name: str):
"""延迟重载插件"""
try:
time.sleep(self.debounce_delay)
if plugin_name in self.pending_reloads:
self.pending_reloads.remove(plugin_name)
self.hot_reload_manager._reload_plugin(plugin_name)
except Exception as e:
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
def _get_plugin_name_from_path(self, file_path: str) -> str:
"""从文件路径获取插件名称"""
try:
path = Path(file_path)
# 检查是否在监听的插件目录中
plugin_root = Path(self.hot_reload_manager.watch_directory)
if not path.is_relative_to(plugin_root):
return ""
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml
plugin_dir = plugin_root / plugin_name
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
return plugin_name
return ""
except Exception:
return ""
class PluginHotReloadManager:
"""插件热重载管理器"""
def __init__(self, watch_directory: str = None):
print("fuck")
print(os.getcwd())
self.watch_directory = os.path.join(os.getcwd(), "plugins")
self.observer = None
self.file_handler = None
self.is_running = False
# 确保监听目录存在
if not os.path.exists(self.watch_directory):
os.makedirs(self.watch_directory, exist_ok=True)
logger.info(f"创建插件监听目录: {self.watch_directory}")
def start(self):
"""启动热重载监听"""
if self.is_running:
logger.warning("插件热重载已经在运行中")
return
try:
self.observer = Observer()
self.file_handler = PluginFileHandler(self)
self.observer.schedule(
self.file_handler,
self.watch_directory,
recursive=True
)
self.observer.start()
self.is_running = True
logger.info("🚀 插件热重载已启动,监听目录: plugins")
except Exception as e:
logger.error(f"❌ 启动插件热重载失败: {e}")
self.is_running = False
def stop(self):
"""停止热重载监听"""
if not self.is_running:
return
if self.observer:
self.observer.stop()
self.observer.join()
self.is_running = False
def _reload_plugin(self, plugin_name: str):
"""重载指定插件"""
try:
logger.info(f"🔄 开始重载插件: {plugin_name}")
if plugin_manager.reload_plugin(plugin_name):
logger.info(f"✅ 插件重载成功: {plugin_name}")
else:
logger.error(f"❌ 插件重载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
def _unload_plugin(self, plugin_name: str):
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
else:
logger.error(f"❌ 插件卸载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
def reload_all_plugins(self):
"""重载所有插件"""
try:
logger.info("🔄 开始重载所有插件...")
# 获取当前已加载的插件列表
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
success_count = 0
fail_count = 0
for plugin_name in loaded_plugins:
if plugin_manager.reload_plugin(plugin_name):
success_count += 1
else:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
except Exception as e:
logger.error(f"❌ 重载所有插件时发生错误: {e}")
def get_status(self) -> dict:
"""获取热重载状态"""
return {
"is_running": self.is_running,
"watch_directory": self.watch_directory,
"loaded_plugins": len(plugin_manager.loaded_plugins),
"failed_plugins": len(plugin_manager.failed_plugins),
}
# 全局热重载管理器实例
hot_reload_manager = PluginHotReloadManager()

View File

@@ -0,0 +1,593 @@
import os
import traceback
import sys
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
logger = get_logger("plugin_manager")
class PluginManager:
"""
插件管理器类
负责加载,重载和卸载插件,同时管理插件的所有组件
"""
def __init__(self):
self.plugin_directories: List[str] = [] # 插件根目录列表
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
# 确保插件目录存在
self._ensure_plugin_directories()
logger.info("插件管理器初始化完成")
# === 插件目录管理 ===
def add_plugin_directory(self, directory: str) -> bool:
"""添加插件目录"""
if os.path.exists(directory):
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件目录: {directory}")
return True
else:
logger.warning(f"插件不可重复加载: {directory}")
else:
logger.warning(f"插件目录不存在: {directory}")
return False
# === 插件加载管理 ===
def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件
Returns:
tuple[int, int]: (插件数量, 组件数量)
"""
logger.debug("开始加载所有插件...")
# 第一阶段:加载所有插件模块(注册插件类)
total_loaded_modules = 0
total_failed_modules = 0
for directory in self.plugin_directories:
loaded, failed = self._load_plugin_modules_from_directory(directory)
total_loaded_modules += loaded
total_failed_modules += failed
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
total_registered = 0
total_failed_registration = 0
for plugin_name in self.plugin_classes.keys():
load_status, count = self.load_registered_plugin_classes(plugin_name)
if load_status:
total_registered += 1
else:
total_failed_registration += count
self._show_stats(total_registered, total_failed_registration)
return total_registered, total_failed_registration
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
# sourcery skip: extract-duplicate-method, extract-method
"""
加载已经注册的插件类
"""
plugin_class = self.plugin_classes.get(plugin_name)
if not plugin_class:
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
return False, 1
try:
# 使用记录的插件目录路径
plugin_dir = self.plugin_paths.get(plugin_name)
# 如果没有记录,直接返回失败
if not plugin_dir:
return False, 1
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件可能因为缺少manifest而失败
if not plugin_instance:
logger.error(f"插件 {plugin_name} 实例化失败")
return False, 1
# 检查插件是否启用
if not plugin_instance.enable_plugin:
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
return False, 0
# 检查版本兼容性
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
plugin_name, plugin_instance.manifest_data
)
if not is_compatible:
self.failed_plugins[plugin_name] = compatibility_error
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
return False, 1
if plugin_instance.register_plugin():
self.loaded_plugins[plugin_name] = plugin_instance
self._show_plugin_components(plugin_name)
return True, 1
else:
self.failed_plugins[plugin_name] = "插件注册失败"
logger.error(f"❌ 插件注册失败: {plugin_name}")
return False, 1
except FileNotFoundError as e:
# manifest文件缺失
error_msg = f"缺少manifest文件: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except ValueError as e:
# manifest文件格式错误或验证失败
traceback.print_exc()
error_msg = f"manifest验证失败: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
return False, 1
except Exception as e:
# 其他错误
error_msg = f"未知错误: {str(e)}"
self.failed_plugins[plugin_name] = error_msg
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
logger.debug("详细错误信息: ", exc_info=True)
return False, 1
async def remove_registered_plugin(self, plugin_name: str) -> bool:
"""
禁用插件模块
"""
if not plugin_name:
raise ValueError("插件名称不能为空")
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载")
return False
plugin_instance = self.loaded_plugins[plugin_name]
plugin_info = plugin_instance.plugin_info
success = True
for component in plugin_info.components:
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
success &= component_registry.remove_plugin_registry(plugin_name)
del self.loaded_plugins[plugin_name]
return success
async def reload_registered_plugin(self, plugin_name: str) -> bool:
"""
重载插件模块
"""
if not await self.remove_registered_plugin(plugin_name):
return False
if not self.load_registered_plugin_classes(plugin_name)[0]:
return False
logger.debug(f"插件 {plugin_name} 重载成功")
return True
def rescan_plugin_directory(self) -> Tuple[int, int]:
"""
重新扫描插件根目录
"""
total_success = 0
total_fail = 0
for directory in self.plugin_directories:
if os.path.exists(directory):
logger.debug(f"重新扫描插件根目录: {directory}")
success, fail = self._load_plugin_modules_from_directory(directory)
total_success += success
total_fail += fail
else:
logger.warning(f"插件根目录不存在: {directory}")
return total_success, total_fail
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例
Args:
plugin_name: 插件名称
Returns:
Optional[BasePlugin]: 插件实例或None
"""
return self.loaded_plugins.get(plugin_name)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
"""
列出所有当前加载的插件。
Returns:
list: 当前加载的插件名称列表。
"""
return list(self.loaded_plugins.keys())
def list_registered_plugins(self) -> List[str]:
"""
列出所有已注册的插件类。
Returns:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
Args:
plugin_name: 插件名称
Returns:
Optional[str]: 插件目录的绝对路径如果插件不存在则返回None。
"""
return self.plugin_paths.get(plugin_name)
# === 私有方法 ===
# == 目录管理 ==
def _ensure_plugin_directories(self) -> None:
"""确保所有插件根目录存在,如果不存在则创建"""
default_directories = ["src/plugins/built_in", "plugins"]
for directory in default_directories:
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
logger.info(f"创建插件根目录: {directory}")
if directory not in self.plugin_directories:
self.plugin_directories.append(directory)
logger.debug(f"已添加插件根目录: {directory}")
else:
logger.warning(f"根目录不可重复加载: {directory}")
# == 插件加载 ==
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
"""从指定目录加载插件模块"""
loaded_count = 0
failed_count = 0
if not os.path.exists(directory):
logger.warning(f"插件根目录不存在: {directory}")
return 0, 1
logger.debug(f"正在扫描插件根目录: {directory}")
# 遍历目录中的所有包
for item in os.listdir(directory):
item_path = os.path.join(directory, item)
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
plugin_file = os.path.join(item_path, "plugin.py")
if os.path.exists(plugin_file):
if self._load_plugin_module_file(plugin_file):
loaded_count += 1
else:
failed_count += 1
return loaded_count, failed_count
def _load_plugin_module_file(self, plugin_file: str) -> bool:
# sourcery skip: extract-method
"""加载单个插件模块文件
Args:
plugin_file: 插件文件路径
plugin_name: 插件名称
plugin_dir: 插件目录路径
"""
# 生成模块名
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
try:
# 动态导入插件模块
spec = spec_from_file_location(module_name, plugin_file)
if spec is None or spec.loader is None:
logger.error(f"无法创建模块规范: {plugin_file}")
return False
module = module_from_spec(spec)
module.__package__ = module_name # 设置模块包名
spec.loader.exec_module(module)
logger.debug(f"插件模块加载成功: {plugin_file}")
return True
except Exception as e:
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
logger.error(error_msg)
self.failed_plugins[module_name] = error_msg
return False
# == 兼容性检查 ==
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
"""检查插件版本兼容性
Args:
plugin_name: 插件名称
manifest_data: manifest数据
Returns:
Tuple[bool, str]: (是否兼容, 错误信息)
"""
if "host_application" not in manifest_data:
return True, "" # 没有版本要求,默认兼容
host_app = manifest_data["host_application"]
if not isinstance(host_app, dict):
return True, ""
min_version = host_app.get("min_version", "")
max_version = host_app.get("max_version", "")
if not min_version and not max_version:
return True, "" # 没有版本要求,默认兼容
try:
current_version = VersionComparator.get_current_host_version()
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
if not is_compatible:
return False, f"版本不兼容: {error_msg}"
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
return True, ""
except Exception as e:
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
# == 显示统计与插件信息 ==
def _show_stats(self, total_registered: int, total_failed_registration: int):
# sourcery skip: low-code-quality
# 获取组件统计信息
stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0)
tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
)
# 显示详细的插件列表
logger.info("📋 已加载插件详情:")
for plugin_name in self.loaded_plugins.keys():
if plugin_info := component_registry.get_plugin_info(plugin_name):
# 插件基本信息
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
info_parts = [part for part in [version_info, author_info, license_info] if part]
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
# Manifest信息
if plugin_info.manifest_data:
"""
if plugin_info.keywords:
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
if plugin_info.categories:
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
"""
if plugin_info.homepage_url:
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
# 组件列表
if plugin_info.components:
action_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
]
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
if action_components:
action_names = [c.name for c in action_components]
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if tool_components:
tool_names = [c.name for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
# 依赖信息
if plugin_info.dependencies:
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
# 配置文件信息
if plugin_info.config_file:
config_status = "" if self.plugin_paths.get(plugin_name) else ""
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
root_path = Path(__file__)
# 查找项目根目录
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
root_path = root_path.parent
# 显示目录统计
logger.info("📂 加载目录统计:")
for directory in self.plugin_directories:
if os.path.exists(directory):
plugins_in_dir = []
for plugin_name in self.loaded_plugins.keys():
plugin_path = self.plugin_paths.get(plugin_name, "")
if (
Path(plugin_path)
.resolve()
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
):
plugins_in_dir.append(plugin_name)
if plugins_in_dir:
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
else:
logger.info(f" 📁 {directory}: 0个插件")
# 失败信息
if total_failed_registration > 0:
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
for failed_plugin, error in self.failed_plugins.items():
logger.info(f"{failed_plugin}: {error}")
else:
logger.warning("😕 没有成功加载任何插件")
def _show_plugin_components(self, plugin_name: str) -> None:
if plugin_info := component_registry.get_plugin_info(plugin_name):
component_types = {}
for comp in plugin_info.components:
comp_type = comp.component_type.name
component_types[comp_type] = component_types.get(comp_type, 0) + 1
components_str = ", ".join([f"{count}{ctype}" for ctype, count in component_types.items()])
# 显示manifest信息
manifest_info = ""
if plugin_info.license:
manifest_info += f" [{plugin_info.license}]"
if plugin_info.keywords:
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
if len(plugin_info.keywords) > 3:
manifest_info += "..."
logger.info(
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
)
else:
logger.info(f"✅ 插件加载成功: {plugin_name}")
# === 插件卸载和重载管理 ===
def unload_plugin(self, plugin_name: str) -> bool:
"""卸载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 卸载是否成功
"""
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
return False
try:
# 获取插件实例
plugin_instance = self.loaded_plugins[plugin_name]
# 调用插件的清理方法(如果有的话)
if hasattr(plugin_instance, 'on_unload'):
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
component_registry.unregister_plugin(plugin_name)
# 从已加载插件中移除
del self.loaded_plugins[plugin_name]
# 从失败列表中移除(如果存在)
if plugin_name in self.failed_plugins:
del self.failed_plugins[plugin_name]
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
except Exception as e:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
return False
def reload_plugin(self, plugin_name: str) -> bool:
"""重载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
try:
# 先卸载插件
if plugin_name in self.loaded_plugins:
self.unload_plugin(plugin_name)
# 清除Python模块缓存
plugin_path = self.plugin_paths.get(plugin_name)
if plugin_path:
plugin_file = os.path.join(plugin_path, "plugin.py")
if os.path.exists(plugin_file):
# 从sys.modules中移除相关模块
modules_to_remove = []
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
for module_name in sys.modules:
if module_name.startswith(plugin_module_prefix):
modules_to_remove.append(module_name)
for module_name in modules_to_remove:
del sys.modules[module_name]
# 从插件类注册表中移除
if plugin_name in self.plugin_classes:
del self.plugin_classes[plugin_name]
# 重新加载插件模块
if self._load_plugin_module_file(plugin_file):
# 重新加载插件实例
success, _ = self.load_registered_plugin_classes(plugin_name)
if success:
logger.info(f"🔄 插件重载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
logger.debug("详细错误信息: ", exc_info=True)
return False
# 全局插件管理器实例
plugin_manager = PluginManager()

View File

@@ -0,0 +1,421 @@
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
logger = get_logger("tool_use")
def init_tool_executor_prompt():
"""初始化工具执行器的提示词"""
tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}
群里正在进行的聊天内容:
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的工具使用指令
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
"""
Prompt(tool_executor_prompt, "tool_executor_prompt")
# 初始化提示词
init_tool_executor_prompt()
class ToolExecutor:
"""独立的工具执行器组件
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
"""
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
"""初始化工具执行器
Args:
executor_id: 执行器标识符,用于日志记录
enable_cache: 是否启用缓存机制
cache_ttl: 缓存生存时间(周期数)
"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'}TTL={cache_ttl}")
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
# 首先检查缓存
cache_key = self._generate_cache_key(target_message, chat_history, sender)
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self._get_tool_definitions()
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
# 构建工具调用提示词
prompt = await global_prompt_manager.format_prompt(
"tool_executor_prompt",
target_message=target_message,
chat_history=chat_history,
sender=sender,
bot_name=bot_name,
time_now=time_now,
)
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False
)
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
# 缓存结果
if tool_results:
self._set_cache(cache_key, tool_results)
if used_tools:
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
if return_details:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [definition for name, definition in all_tools if name not in user_disabled_tools]
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
Args:
tool_calls: LLM返回的工具调用列表
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results: List[Dict[str, Any]] = []
used_tools = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return [], []
# 提取tool_calls中的函数名称
func_names = [call.func_name for call in tool_calls if call.func_name]
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
# 执行每个工具调用
for tool_call in tool_calls:
try:
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.execute_tool_call(tool_call)
if result:
tool_info = {
"type": result.get("type", "unknown_type"),
"id": result.get("id", f"tool_exec_{time.time()}"),
"content": result.get("content", ""),
"tool_name": tool_name,
"timestamp": time.time(),
}
content = tool_info["content"]
if not isinstance(content, (str, list, tuple)):
tool_info["content"] = str(content)
tool_results.append(tool_info)
used_tools.append(tool_name)
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
preview = content[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
error_info = {
"type": "tool_error",
"id": f"tool_error_{time.time()}",
"content": f"工具{tool_name}执行失败: {str(e)}",
"tool_name": tool_name,
"timestamp": time.time(),
}
tool_results.append(error_info)
return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
Args:
tool_call: 工具调用对象
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
function_name = tool_call.func_name
function_args = tool_call.args or {}
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = tool_instance or get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
return {
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": "function",
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
raise e
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
Args:
target_message: 目标消息内容
chat_history: 聊天历史
sender: 发送者
Returns:
str: 缓存键
"""
import hashlib
# 使用消息内容和群聊状态生成唯一缓存键
content = f"{target_message}_{chat_history}_{sender}"
return hashlib.md5(content.encode()).hexdigest()
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
"""从缓存获取结果
Args:
cache_key: 缓存键
Returns:
Optional[List[Dict]]: 缓存的结果如果不存在或过期则返回None
"""
if not self.enable_cache or cache_key not in self.tool_cache:
return None
cache_item = self.tool_cache[cache_key]
if cache_item["ttl"] <= 0:
# 缓存过期,删除
del self.tool_cache[cache_key]
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
return None
# 减少TTL
cache_item["ttl"] -= 1
logger.debug(f"{self.log_prefix}使用缓存结果剩余TTL: {cache_item['ttl']}")
return cache_item["result"]
def _set_cache(self, cache_key: str, result: List[Dict]):
"""设置缓存
Args:
cache_key: 缓存键
result: 要缓存的结果
"""
if not self.enable_cache:
return
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
logger.debug(f"{self.log_prefix}设置缓存TTL: {self.cache_ttl}")
def _cleanup_expired_cache(self):
"""清理过期的缓存"""
if not self.enable_cache:
return
expired_keys = []
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
for key in expired_keys:
del self.tool_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具
Args:
tool_name: 工具名称
tool_args: 工具参数
validate_args: 是否验证参数
Returns:
Optional[Dict]: 工具执行结果失败时返回None
"""
try:
tool_call = ToolCall(
call_id=f"direct_tool_{time.time()}",
func_name=tool_name,
args=tool_args,
)
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.execute_tool_call(tool_call)
if result:
tool_info = {
"type": result.get("type", "unknown_type"),
"id": result.get("id", f"direct_tool_{time.time()}"),
"content": result.get("content", ""),
"tool_name": tool_name,
"timestamp": time.time(),
}
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
return tool_info
except Exception as e:
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
return None
def clear_cache(self):
"""清空所有缓存"""
if self.enable_cache:
cache_count = len(self.tool_cache)
self.tool_cache.clear()
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
def get_cache_status(self) -> Dict:
"""获取缓存状态信息
Returns:
Dict: 包含缓存统计信息的字典
"""
if not self.enable_cache:
return {"enabled": False, "cache_count": 0}
# 清理过期缓存
self._cleanup_expired_cache()
total_count = len(self.tool_cache)
ttl_distribution = {}
for cache_item in self.tool_cache.values():
ttl = cache_item["ttl"]
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
return {
"enabled": True,
"cache_count": total_count,
"cache_ttl": self.cache_ttl,
"ttl_distribution": ttl_distribution,
}
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
"""动态修改缓存配置
Args:
enable_cache: 是否启用缓存
cache_ttl: 缓存TTL
"""
if enable_cache is not None:
self.enable_cache = enable_cache
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
if cache_ttl > 0:
self.cache_ttl = cache_ttl
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
"""
ToolExecutor使用示例
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
results, _, _ = await executor.execute_from_chat_message(
talking_message_str="今天天气怎么样?现在几点了?",
is_group_chat=False
)
# 2. 禁用缓存的执行器
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
# 3. 自定义缓存TTL
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
# 4. 获取详细信息
results, used_tools, prompt = await executor.execute_from_chat_message(
talking_message_str="帮我查询Python相关知识",
is_group_chat=False,
return_details=True
)
# 5. 直接执行特定工具
result = await executor.execute_specific_tool_simple(
tool_name="get_knowledge",
tool_args={"query": "机器学习"}
)
# 6. 缓存管理
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
"""