diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 8ede9616a..8093bc885 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,9 +1,11 @@ from typing import List, Tuple, Type +from src.plugin_system.apis import tool_api from src.plugin_system import ( BasePlugin, register_plugin, BaseAction, BaseCommand, + BaseTool, ComponentInfo, ActionActivationType, ConfigField, @@ -12,6 +14,32 @@ from src.plugin_system import ( MaiMessages, ) +class HelloTool(BaseTool): + """问候工具 - 用于发送问候消息""" + + name = "hello_tool" + description = "发送问候消息" + parameters = { + "type": "object", + "properties": { + "greeting_message": { + "type": "string", + "description": "要发送的问候消息" + }, + }, + "required": ["greeting_message"] + } + available_for_llm = True + + + async def execute(self, function_args): + """执行问候工具""" + import random + greeting_message = random.choice(function_args.get("greeting_message", ["嗨!很高兴见到你!😊"])) + return { + "name": self.name, + "content": greeting_message + } # ===== Action组件 ===== class HelloAction(BaseAction): @@ -30,7 +58,10 @@ class HelloAction(BaseAction): async def execute(self) -> Tuple[bool, str]: """执行问候动作 - 这是核心功能""" # 发送问候消息 - greeting_message = self.action_data.get("greeting_message", "") + hello_tool = tool_api.get_tool_instance("hello_tool") + greeting_message = await hello_tool.execute({ + "greeting_message": self.action_data.get("greeting_message", "") + }) base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊") message = base_message + greeting_message await self.send_text(message) @@ -132,7 +163,7 @@ class HelloWorldPlugin(BasePlugin): "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), }, "greeting": { - "message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"), + "message": ConfigField(type=list, default=["嗨!很开心见到你!😊","Ciallo~(∠・ω< )⌒★"], description="默认问候消息"), "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), }, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, @@ -142,6 +173,7 @@ class HelloWorldPlugin(BasePlugin): def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: return [ (HelloAction.get_action_info(), HelloAction), + (HelloTool.get_tool_info(), HelloTool), # 添加问候工具 (ByeAction.get_action_info(), ByeAction), # 添加告别Action (TimeCommand.get_command_info(), TimeCommand), (PrintMessage.get_handler_info(), PrintMessage), diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 0e99b6b3a..51313d4e1 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -29,7 +29,6 @@ from src.chat.memory_system.instant_memory import InstantMemory from src.mood.mood_manager import mood_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.person_info import get_person_info_manager -from src.tools.tool_executor import ToolExecutor from src.plugin_system.base.component_types import ActionInfo logger = get_logger("replyer") @@ -139,6 +138,8 @@ class DefaultReplyer: self.heart_fc_sender = HeartFCSender() self.memory_activator = MemoryActivator() self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id) + + from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) def _select_weighted_model_config(self) -> Dict[str, Any]: diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index cb73d8e6c..cd13bdbab 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -9,6 +9,7 @@ from .base import ( BasePlugin, BaseAction, BaseCommand, + BaseTool, ConfigField, ComponentType, ActionActivationType, @@ -34,6 +35,7 @@ from .utils import ( from .apis import ( chat_api, + tool_api, component_manage_api, config_api, database_api, @@ -54,6 +56,7 @@ __version__ = "1.0.0" __all__ = [ # API 模块 "chat_api", + "tool_api", "component_manage_api", "config_api", "database_api", @@ -70,6 +73,7 @@ __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", + "BaseTool", "BaseEventHandler", # 类型定义 "ComponentType", diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py new file mode 100644 index 000000000..09fee548e --- /dev/null +++ b/src/plugin_system/apis/tool_api.py @@ -0,0 +1,25 @@ +from typing import Optional +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.base.component_types import ComponentType + +from src.common.logger import get_logger + +logger = get_logger("tool_api") + +def get_tool_instance(tool_name: str) -> Optional[BaseTool]: + """获取公开工具实例""" + from src.plugin_system.core import component_registry + + tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL) + if not tool_class: + return None + + return tool_class() + +def get_llm_available_tool_definitions(): + from src.plugin_system.core import component_registry + + llm_available_tools = component_registry.get_llm_available_tools() + return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()] + + diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index a95e05aed..b9a2893e4 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -6,6 +6,7 @@ from .base_plugin import BasePlugin from .base_action import BaseAction +from .base_tool import BaseTool from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .component_types import ( @@ -15,6 +16,7 @@ from .component_types import ( ComponentInfo, ActionInfo, CommandInfo, + ToolInfo, PluginInfo, PythonDependency, EventHandlerInfo, @@ -27,12 +29,14 @@ __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", + "BaseTool", "ComponentType", "ActionActivationType", "ChatMode", "ComponentInfo", "ActionInfo", "CommandInfo", + "ToolInfo", "PluginInfo", "PythonDependency", "ConfigField", diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py new file mode 100644 index 000000000..b2f219629 --- /dev/null +++ b/src/plugin_system/base/base_tool.py @@ -0,0 +1,62 @@ +from typing import List, Any, Optional, Type +from src.common.logger import get_logger +from rich.traceback import install +from src.plugin_system.base.component_types import ComponentType, ToolInfo +install(extra_lines=3) + +logger = get_logger("base_tool") + + + +class BaseTool: + """所有工具的基类""" + + # 工具名称,子类必须重写 + name = None + # 工具描述,子类必须重写 + description = None + # 工具参数定义,子类必须重写 + parameters = None + # 是否可供LLM使用,默认为False + available_for_llm = False + + @classmethod + def get_tool_definition(cls) -> dict[str, Any]: + """获取工具定义,用于LLM工具调用 + + Returns: + dict: 工具定义字典 + """ + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + + return { + "type": "function", + "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, + } + + @classmethod + def get_tool_info(cls) -> ToolInfo: + """获取工具信息""" + if not cls.name or not cls.description: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性") + + return ToolInfo( + name=cls.name, + tool_description=cls.description, + available_for_llm=cls.available_for_llm, + tool_parameters=cls.parameters, + component_type=ComponentType.TOOL, + ) + + # 工具参数定义,子类必须重写 + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + """执行工具函数 + + Args: + function_args: 工具调用参数 + + Returns: + dict: 工具执行结果 + """ + raise NotImplementedError("子类必须实现execute方法") diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index eeb2a5a08..3ecb15a0a 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -10,6 +10,7 @@ class ComponentType(Enum): ACTION = "action" # 动作组件 COMMAND = "command" # 命令组件 + TOOL = "tool" # 服务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件(预留) @@ -144,7 +145,18 @@ class CommandInfo(ComponentInfo): def __post_init__(self): super().__post_init__() self.component_type = ComponentType.COMMAND + +@dataclass +class ToolInfo(ComponentInfo): + """工具组件信息""" + tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义 + available_for_llm: bool = False # 是否可供LLM使用 + tool_description: str = "" # 工具描述 + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.TOOL @dataclass class EventHandlerInfo(ComponentInfo): diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index eb794a30b..3eecad418 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.plugin_system.core.tool_use import tool_user __all__ = [ "plugin_manager", "component_registry", "events_manager", "global_announcement_manager", + "tool_user", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 2ea89b880..832739f1d 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -6,6 +6,7 @@ from src.common.logger import get_logger from src.plugin_system.base.component_types import ( ComponentInfo, ActionInfo, + ToolInfo, CommandInfo, EventHandlerInfo, PluginInfo, @@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import ( ) from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction +from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_events_handler import BaseEventHandler logger = get_logger("component_registry") @@ -30,7 +32,7 @@ class ComponentRegistry: """组件注册表 命名空间式组件名 -> 组件信息""" self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} + self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 @@ -49,6 +51,10 @@ class ComponentRegistry: self._command_patterns: Dict[Pattern, str] = {} """编译后的正则 -> command名""" + # 工具特定注册表 + self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类 + self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类 + # EventHandler特定注册表 self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} """event_handler名 -> event_handler类""" @@ -125,6 +131,10 @@ class ComponentRegistry: assert isinstance(component_info, CommandInfo) assert issubclass(component_class, BaseCommand) ret = self._register_command_component(component_info, component_class) + case ComponentType.TOOL: + assert isinstance(component_info, ToolInfo) + assert issubclass(component_class, BaseTool) + ret = self._register_tool_component(component_info, component_class) case ComponentType.EVENT_HANDLER: assert isinstance(component_info, EventHandlerInfo) assert issubclass(component_class, BaseEventHandler) @@ -180,6 +190,17 @@ class ComponentRegistry: return True + def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool): + """注册Tool组件到Tool特定注册表""" + tool_name = tool_info.name + self._tool_registry[tool_name] = tool_class + + # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 + if tool_info.available_for_llm and tool_info.enabled: + self._llm_available_tools[tool_name] = tool_class + + return True + def _register_event_handler_component( self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] ) -> bool: @@ -475,7 +496,28 @@ class ComponentRegistry: candidates[0].match(text).groupdict(), # type: ignore command_info, ) + + # === Tool 特定查询方法 === + def get_tool_registry(self) -> Dict[str, Type[BaseTool]]: + """获取Tool注册表""" + return self._tool_registry.copy() + + def get_llm_available_tools(self) -> Dict[str, str]: + """获取LLM可用的Tool列表""" + return self._llm_available_tools.copy() + def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: + """获取Tool信息 + + Args: + tool_name: 工具名称 + + Returns: + ToolInfo: 工具信息对象,如果工具不存在则返回 None + """ + info = self.get_component_info(tool_name, ComponentType.TOOL) + return info if isinstance(info, ToolInfo) else None + # === EventHandler 特定查询方法 === def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: @@ -529,17 +571,21 @@ class ComponentRegistry: """获取注册中心统计信息""" action_components: int = 0 command_components: int = 0 - events_handlers: int = 0 + tool_components: int = 0 + events_handlers: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 elif component.component_type == ComponentType.COMMAND: command_components += 1 + elif component.component_type == ComponentType.TOOL: + tool_components += 1 elif component.component_type == ComponentType.EVENT_HANDLER: events_handlers += 1 return { "action_components": action_components, "command_components": command_components, + "tool_components": tool_components, "event_handlers": events_handlers, "total_components": len(self._components), "total_plugins": len(self._plugins), diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index ded03a18a..014b7a0cc 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -358,6 +358,7 @@ class PluginManager: stats = component_registry.get_registry_stats() action_count = stats.get("action_components", 0) command_count = stats.get("command_components", 0) + tool_count = stats.get("tool_components", 0) event_handler_count = stats.get("event_handlers", 0) total_components = stats.get("total_components", 0) @@ -365,7 +366,7 @@ class PluginManager: if total_registered > 0: logger.info("🎉 插件系统加载完成!") logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})" + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})" ) # 显示详细的插件列表 @@ -400,6 +401,9 @@ class PluginManager: command_components = [ c for c in plugin_info.components if c.component_type == ComponentType.COMMAND ] + tool_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.TOOL + ] event_handler_components = [ c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER ] @@ -411,7 +415,9 @@ class PluginManager: if command_components: command_names = [c.name for c in command_components] logger.info(f" ⚡ Command组件: {', '.join(command_names)}") - + if tool_components: + tool_names = [c.name for c in tool_components] + logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}") if event_handler_components: event_handler_names = [c.name for c in event_handler_components] logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py new file mode 100644 index 000000000..bec600190 --- /dev/null +++ b/src/plugin_system/core/tool_use.py @@ -0,0 +1,458 @@ +import json +import time +from typing import List, Dict, Tuple, Optional +from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance +from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.json_utils import process_llm_tool_calls +from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.logger import get_logger + +logger = get_logger("tool_use") + +def init_tool_executor_prompt(): + """初始化工具执行器的提示词""" + tool_executor_prompt = """ +你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。 +群里正在进行的聊天内容: +{chat_history} + +现在,{sender}发送了内容:{target_message},你想要回复ta。 +请仔细分析聊天内容,考虑以下几点: +1. 内容中是否包含需要查询信息的问题 +2. 是否有明确的工具使用指令 + +If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed". +""" + Prompt(tool_executor_prompt, "tool_executor_prompt") + +# 初始化提示词 +init_tool_executor_prompt() + +class ToolExecutor: + """独立的工具执行器组件 + + 可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。 + """ + + def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): + """初始化工具执行器 + + Args: + executor_id: 执行器标识符,用于日志记录 + enable_cache: 是否启用缓存机制 + cache_ttl: 缓存生存时间(周期数) + """ + self.chat_id = chat_id + self.chat_stream = get_chat_manager().get_stream(self.chat_id) + self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" + + self.llm_model = LLMRequest( + model=global_config.model.tool_use, + request_type="tool_executor", + ) + + # 初始化工具实例 + self.tool_instance = ToolUser() + + # 缓存配置 + self.enable_cache = enable_cache + self.cache_ttl = cache_ttl + self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}} + + logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}") + + async def execute_from_chat_message( + self, target_message: str, chat_history: str, sender: str, return_details: bool = False + ) -> Tuple[List[Dict], List[str], str]: + """从聊天消息执行工具 + + Args: + target_message: 目标消息内容 + chat_history: 聊天历史 + sender: 发送者 + return_details: 是否返回详细信息(使用的工具列表和提示词) + + Returns: + 如果return_details为False: List[Dict] - 工具执行结果列表 + 如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词) + """ + + # 首先检查缓存 + cache_key = self._generate_cache_key(target_message, chat_history, sender) + if cached_result := self._get_from_cache(cache_key): + logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") + if not return_details: + return cached_result, [], "使用缓存结果" + + # 从缓存结果中提取工具名称 + used_tools = [result.get("tool_name", "unknown") for result in cached_result] + return cached_result, used_tools, "使用缓存结果" + + # 缓存未命中,执行工具调用 + # 获取可用工具 + tools = self.tool_instance._define_tools() + + # 获取当前时间 + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + bot_name = global_config.bot.nickname + + # 构建工具调用提示词 + prompt = await global_prompt_manager.format_prompt( + "tool_executor_prompt", + target_message=target_message, + chat_history=chat_history, + sender=sender, + bot_name=bot_name, + time_now=time_now, + ) + + logger.debug(f"{self.log_prefix}开始LLM工具调用分析") + + # 调用LLM进行工具决策 + response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools) + + # 解析LLM响应 + if len(other_info) == 3: + reasoning_content, model_name, tool_calls = other_info + else: + reasoning_content, model_name = other_info + tool_calls = None + + # 执行工具调用 + tool_results, used_tools = await self._execute_tool_calls(tool_calls) + + # 缓存结果 + if tool_results: + self._set_cache(cache_key, tool_results) + + if used_tools: + logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}") + + if return_details: + return tool_results, used_tools, prompt + else: + return tool_results, [], "" + + async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: + """执行工具调用 + + Args: + tool_calls: LLM返回的工具调用列表 + + Returns: + Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) + """ + tool_results = [] + used_tools = [] + + if not tool_calls: + logger.debug(f"{self.log_prefix}无需执行工具") + return tool_results, used_tools + + logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}") + + # 处理工具调用 + success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls) + + if not success: + logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}") + return tool_results, used_tools + + if not valid_tool_calls: + logger.debug(f"{self.log_prefix}无有效工具调用") + return tool_results, used_tools + + # 执行每个工具调用 + for tool_call in valid_tool_calls: + try: + tool_name = tool_call.get("name", "unknown_tool") + used_tools.append(tool_name) + + logger.debug(f"{self.log_prefix}执行工具: {tool_name}") + + # 执行工具 + result = await self.tool_instance.execute_tool_call(tool_call) + + if result: + tool_info = { + "type": result.get("type", "unknown_type"), + "id": result.get("id", f"tool_exec_{time.time()}"), + "content": result.get("content", ""), + "tool_name": tool_name, + "timestamp": time.time(), + } + tool_results.append(tool_info) + + logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}") + content = tool_info["content"] + if not isinstance(content, (str, list, tuple)): + content = str(content) + preview = content[:200] + logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") + + except Exception as e: + logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") + # 添加错误信息到结果中 + error_info = { + "type": "tool_error", + "id": f"tool_error_{time.time()}", + "content": f"工具{tool_name}执行失败: {str(e)}", + "tool_name": tool_name, + "timestamp": time.time(), + } + tool_results.append(error_info) + + return tool_results, used_tools + + def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: + """生成缓存键 + + Args: + target_message: 目标消息内容 + chat_history: 聊天历史 + sender: 发送者 + + Returns: + str: 缓存键 + """ + import hashlib + + # 使用消息内容和群聊状态生成唯一缓存键 + content = f"{target_message}_{chat_history}_{sender}" + return hashlib.md5(content.encode()).hexdigest() + + def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]: + """从缓存获取结果 + + Args: + cache_key: 缓存键 + + Returns: + Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None + """ + if not self.enable_cache or cache_key not in self.tool_cache: + return None + + cache_item = self.tool_cache[cache_key] + if cache_item["ttl"] <= 0: + # 缓存过期,删除 + del self.tool_cache[cache_key] + logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}") + return None + + # 减少TTL + cache_item["ttl"] -= 1 + logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}") + return cache_item["result"] + + def _set_cache(self, cache_key: str, result: List[Dict]): + """设置缓存 + + Args: + cache_key: 缓存键 + result: 要缓存的结果 + """ + if not self.enable_cache: + return + + self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()} + logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}") + + def _cleanup_expired_cache(self): + """清理过期的缓存""" + if not self.enable_cache: + return + + expired_keys = [] + expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0) + for key in expired_keys: + del self.tool_cache[key] + + if expired_keys: + logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") + + def get_available_tools(self) -> List[str]: + """获取可用工具列表 + + Returns: + List[str]: 可用工具名称列表 + """ + tools = self.tool_instance._define_tools() + return [tool.get("function", {}).get("name", "unknown") for tool in tools] + + async def execute_specific_tool( + self, tool_name: str, tool_args: Dict, validate_args: bool = True + ) -> Optional[Dict]: + """直接执行指定工具 + + Args: + tool_name: 工具名称 + tool_args: 工具参数 + validate_args: 是否验证参数 + + Returns: + Optional[Dict]: 工具执行结果,失败时返回None + """ + try: + tool_call = {"name": tool_name, "arguments": tool_args} + + logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") + + result = await self.tool_instance.execute_tool_call(tool_call) + + if result: + tool_info = { + "type": result.get("type", "unknown_type"), + "id": result.get("id", f"direct_tool_{time.time()}"), + "content": result.get("content", ""), + "tool_name": tool_name, + "timestamp": time.time(), + } + logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}") + return tool_info + + except Exception as e: + logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}") + + return None + + def clear_cache(self): + """清空所有缓存""" + if self.enable_cache: + cache_count = len(self.tool_cache) + self.tool_cache.clear() + logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项") + + def get_cache_status(self) -> Dict: + """获取缓存状态信息 + + Returns: + Dict: 包含缓存统计信息的字典 + """ + if not self.enable_cache: + return {"enabled": False, "cache_count": 0} + + # 清理过期缓存 + self._cleanup_expired_cache() + + total_count = len(self.tool_cache) + ttl_distribution = {} + + for cache_item in self.tool_cache.values(): + ttl = cache_item["ttl"] + ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1 + + return { + "enabled": True, + "cache_count": total_count, + "cache_ttl": self.cache_ttl, + "ttl_distribution": ttl_distribution, + } + + def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): + """动态修改缓存配置 + + Args: + enable_cache: 是否启用缓存 + cache_ttl: 缓存TTL + """ + if enable_cache is not None: + self.enable_cache = enable_cache + logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}") + + if cache_ttl > 0: + self.cache_ttl = cache_ttl + logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") + +""" +ToolExecutor使用示例: + +# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3) +executor = ToolExecutor(executor_id="my_executor") +results, _, _ = await executor.execute_from_chat_message( + talking_message_str="今天天气怎么样?现在几点了?", + is_group_chat=False +) + +# 2. 禁用缓存的执行器 +no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False) + +# 3. 自定义缓存TTL +long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10) + +# 4. 获取详细信息 +results, used_tools, prompt = await executor.execute_from_chat_message( + talking_message_str="帮我查询Python相关知识", + is_group_chat=False, + return_details=True +) + +# 5. 直接执行特定工具 +result = await executor.execute_specific_tool( + tool_name="get_knowledge", + tool_args={"query": "机器学习"} +) + +# 6. 缓存管理 +available_tools = executor.get_available_tools() +cache_status = executor.get_cache_status() # 查看缓存状态 +executor.clear_cache() # 清空缓存 +executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置 +""" + + +class ToolUser: + @staticmethod + def _define_tools(): + """获取所有已注册工具的定义 + + Returns: + list: 工具定义列表 + """ + return get_llm_available_tool_definitions() + + @staticmethod + async def execute_tool_call(tool_call): + # sourcery skip: use-assigned-variable + """执行特定的工具调用 + + Args: + tool_call: 工具调用对象 + message_txt: 原始消息文本 + + Returns: + dict: 工具调用结果 + """ + try: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + function_args["llm_called"] = True # 标记为LLM调用 + + # 获取对应工具实例 + tool_instance = get_tool_instance(function_name) + if not tool_instance: + logger.warning(f"未知工具名称: {function_name}") + return None + + # 执行工具 + result = await tool_instance.execute(function_args) + if result: + # 直接使用 function_name 作为 tool_type + tool_type = function_name + + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": function_name, + "type": tool_type, + "content": result["content"], + } + return None + except Exception as e: + logger.error(f"执行工具调用时发生错误: {str(e)}") + return None + +tool_user = ToolUser() \ No newline at end of file