diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 8093bc885..2f278036b 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,5 +1,4 @@ from typing import List, Tuple, Type -from src.plugin_system.apis import tool_api from src.plugin_system import ( BasePlugin, register_plugin, @@ -58,10 +57,7 @@ class HelloAction(BaseAction): async def execute(self) -> Tuple[bool, str]: """执行问候动作 - 这是核心功能""" # 发送问候消息 - hello_tool = tool_api.get_tool_instance("hello_tool") - greeting_message = await hello_tool.execute({ - "greeting_message": self.action_data.get("greeting_message", "") - }) + greeting_message = self.action_data.get("greeting_message", "") base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊") message = base_message + greeting_message await self.send_text(message) diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index cd13bdbab..f8c71af42 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -51,7 +51,7 @@ from .apis import ( ) -__version__ = "1.0.0" +__version__ = "2.0.0" __all__ = [ # API 模块 diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index c9705c451..362c98581 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -17,6 +17,7 @@ from src.plugin_system.apis import ( person_api, plugin_manage_api, send_api, + tool_api, ) from .logging_api import get_logger from .plugin_register_api import register_plugin @@ -36,4 +37,5 @@ __all__ = [ "send_api", "get_logger", "register_plugin", + "tool_api", ] diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py index d9ea051d9..1ffa0833e 100644 --- a/src/plugin_system/apis/component_manage_api.py +++ b/src/plugin_system/apis/component_manage_api.py @@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import ( EventHandlerInfo, PluginInfo, ComponentType, + ToolInfo, ) @@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: return component_registry.get_registered_command_info(command_name) +def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: + """ + 获取指定 Tool 的注册信息。 + + Args: + tool_name (str): Tool 名称。 + + Returns: + ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_registered_tool_info(tool_name) + + # === EventHandler 特定查询方法 === def get_registered_event_handler_info( event_handler_name: str, @@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType, return global_announcement_manager.enable_specific_chat_action(stream_id, component_name) case ComponentType.COMMAND: return global_announcement_manager.enable_specific_chat_command(stream_id, component_name) + case ComponentType.TOOL: + return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name) case ComponentType.EVENT_HANDLER: return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name) case _: @@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType return global_announcement_manager.disable_specific_chat_action(stream_id, component_name) case ComponentType.COMMAND: return global_announcement_manager.disable_specific_chat_command(stream_id, component_name) + case ComponentType.TOOL: + return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name) case ComponentType.EVENT_HANDLER: return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name) case _: raise ValueError(f"未知 component type: {component_type}") + def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]: """ 获取指定消息流中禁用的组件列表。 @@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp return global_announcement_manager.get_disabled_chat_actions(stream_id) case ComponentType.COMMAND: return global_announcement_manager.get_disabled_chat_commands(stream_id) + case ComponentType.TOOL: + return global_announcement_manager.get_disabled_chat_tools(stream_id) case ComponentType.EVENT_HANDLER: return global_announcement_manager.get_disabled_chat_event_handlers(stream_id) case _: - raise ValueError(f"未知 component type: {component_type}") \ No newline at end of file + raise ValueError(f"未知 component type: {component_type}") diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index 09fee548e..a6704126d 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ComponentType @@ -6,20 +6,22 @@ from src.common.logger import get_logger logger = get_logger("tool_api") + def get_tool_instance(tool_name: str) -> Optional[BaseTool]: """获取公开工具实例""" from src.plugin_system.core import component_registry - tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL) - if not tool_class: - return None - - return tool_class() + tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore + return tool_class() if tool_class else None + def get_llm_available_tool_definitions(): - from src.plugin_system.core import component_registry + """获取LLM可用的工具定义列表 + Returns: + List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)] + """ + from src.plugin_system.core import component_registry + llm_available_tools = component_registry.get_llm_available_tools() - return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()] - - + return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index b2f219629..1c757180b 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -1,24 +1,27 @@ -from typing import List, Any, Optional, Type -from src.common.logger import get_logger +from abc import ABC, abstractmethod +from typing import Any, Dict from rich.traceback import install + +from src.common.logger import get_logger from src.plugin_system.base.component_types import ComponentType, ToolInfo + install(extra_lines=3) logger = get_logger("base_tool") -class BaseTool: +class BaseTool(ABC): """所有工具的基类""" - # 工具名称,子类必须重写 - name = None - # 工具描述,子类必须重写 - description = None - # 工具参数定义,子类必须重写 - parameters = None - # 是否可供LLM使用,默认为False - available_for_llm = False + name: str = "" + """工具的名称""" + description: str = "" + """工具的描述""" + parameters: Dict[str, Any] = {} + """工具的参数定义""" + available_for_llm: bool = False + """是否可供LLM使用""" @classmethod def get_tool_definition(cls) -> dict[str, Any]: @@ -38,18 +41,18 @@ class BaseTool: @classmethod def get_tool_info(cls) -> ToolInfo: """获取工具信息""" - if not cls.name or not cls.description: - raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性") - + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + return ToolInfo( name=cls.name, tool_description=cls.description, - available_for_llm=cls.available_for_llm, + enabled=cls.available_for_llm, tool_parameters=cls.parameters, component_type=ComponentType.TOOL, ) - # 工具参数定义,子类必须重写 + @abstractmethod async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行工具函数 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 3ecb15a0a..aeeccde5a 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -151,7 +151,6 @@ class ToolInfo(ComponentInfo): """工具组件信息""" tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义 - available_for_llm: bool = False # 是否可供LLM使用 tool_description: str = "" # 工具描述 def __post_init__(self): diff --git a/src/plugin_system/core/__init__.py b/src/plugin_system/core/__init__.py index 3eecad418..eb794a30b 100644 --- a/src/plugin_system/core/__init__.py +++ b/src/plugin_system/core/__init__.py @@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager -from src.plugin_system.core.tool_use import tool_user __all__ = [ "plugin_manager", "component_registry", "events_manager", "global_announcement_manager", - "tool_user", ] diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 832739f1d..616e5e463 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -85,7 +85,9 @@ class ComponentRegistry: return True def register_component( - self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]] + self, + component_info: ComponentInfo, + component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]], ) -> bool: """注册组件 @@ -190,17 +192,17 @@ class ComponentRegistry: return True - def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool): + def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool: """注册Tool组件到Tool特定注册表""" tool_name = tool_info.name self._tool_registry[tool_name] = tool_class - + # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 - if tool_info.available_for_llm and tool_info.enabled: + if tool_info.enabled: self._llm_available_tools[tool_name] = tool_class return True - + def _register_event_handler_component( self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] ) -> bool: @@ -243,6 +245,9 @@ class ComponentRegistry: keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name] for key in keys_to_remove: self._command_patterns.pop(key) + case ComponentType.TOOL: + self._tool_registry.pop(component_name) + self._llm_available_tools.pop(component_name) case ComponentType.EVENT_HANDLER: from .events_manager import events_manager # 延迟导入防止循环导入问题 @@ -255,13 +260,13 @@ class ComponentRegistry: self._components_classes.pop(namespaced_name) logger.info(f"组件 {component_name} 已移除") return True - except KeyError: - logger.warning(f"移除组件时未找到组件: {component_name}") + except KeyError as e: + logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}") return False except Exception as e: logger.error(f"移除组件 {component_name} 时发生错误: {e}") return False - + def remove_plugin_registry(self, plugin_name: str) -> bool: """移除插件注册信息 @@ -302,6 +307,10 @@ class ComponentRegistry: assert isinstance(target_component_info, CommandInfo) pattern = target_component_info.command_pattern self._command_patterns[re.compile(pattern)] = component_name + case ComponentType.TOOL: + assert isinstance(target_component_info, ToolInfo) + assert issubclass(target_component_class, BaseTool) + self._llm_available_tools[component_name] = target_component_class case ComponentType.EVENT_HANDLER: assert isinstance(target_component_info, EventHandlerInfo) assert issubclass(target_component_class, BaseEventHandler) @@ -329,20 +338,29 @@ class ComponentRegistry: logger.warning(f"组件 {component_name} 未注册,无法禁用") return False target_component_info.enabled = False - match component_type: - case ComponentType.ACTION: - self._default_actions.pop(component_name, None) - case ComponentType.COMMAND: - self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name} - case ComponentType.EVENT_HANDLER: - self._enabled_event_handlers.pop(component_name, None) - from .events_manager import events_manager # 延迟导入防止循环导入问题 + try: + match component_type: + case ComponentType.ACTION: + self._default_actions.pop(component_name) + case ComponentType.COMMAND: + self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name} + case ComponentType.TOOL: + self._llm_available_tools.pop(component_name) + case ComponentType.EVENT_HANDLER: + self._enabled_event_handlers.pop(component_name) + from .events_manager import events_manager # 延迟导入防止循环导入问题 - await events_manager.unregister_event_subscriber(component_name) - self._components[component_name].enabled = False - self._components_by_type[component_type][component_name].enabled = False - logger.info(f"组件 {component_name} 已禁用") - return True + await events_manager.unregister_event_subscriber(component_name) + self._components[component_name].enabled = False + self._components_by_type[component_type][component_name].enabled = False + logger.info(f"组件 {component_name} 已禁用") + return True + except KeyError as e: + logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}") + return False + except Exception as e: + logger.error(f"禁用组件 {component_name} 时发生错误: {e}") + return False # === 组件查询方法 === def get_component_info( @@ -392,7 +410,7 @@ class ComponentRegistry: self, component_name: str, component_type: Optional[ComponentType] = None, - ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]: + ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]: """获取组件类,支持自动命名空间解析 Args: @@ -496,13 +514,13 @@ class ComponentRegistry: candidates[0].match(text).groupdict(), # type: ignore command_info, ) - + # === Tool 特定查询方法 === def get_tool_registry(self) -> Dict[str, Type[BaseTool]]: """获取Tool注册表""" return self._tool_registry.copy() - - def get_llm_available_tools(self) -> Dict[str, str]: + + def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]: """获取LLM可用的Tool列表""" return self._llm_available_tools.copy() @@ -517,7 +535,7 @@ class ComponentRegistry: """ info = self.get_component_info(tool_name, ComponentType.TOOL) return info if isinstance(info, ToolInfo) else None - + # === EventHandler 特定查询方法 === def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: @@ -572,7 +590,7 @@ class ComponentRegistry: action_components: int = 0 command_components: int = 0 tool_components: int = 0 - events_handlers: int = 0 + events_handlers: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py index 9f7052f5d..bb6f06b4f 100644 --- a/src/plugin_system/core/global_announcement_manager.py +++ b/src/plugin_system/core/global_announcement_manager.py @@ -13,6 +13,8 @@ class GlobalAnnouncementManager: self._user_disabled_commands: Dict[str, List[str]] = {} # 用户禁用的事件处理器,chat_id -> [handler_name] self._user_disabled_event_handlers: Dict[str, List[str]] = {} + # 用户禁用的工具,chat_id -> [tool_name] + self._user_disabled_tools: Dict[str, List[str]] = {} def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool: """禁用特定聊天的某个动作""" @@ -77,6 +79,27 @@ class GlobalAnnouncementManager: return False return False + def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool: + """禁用特定聊天的某个工具""" + if chat_id not in self._user_disabled_tools: + self._user_disabled_tools[chat_id] = [] + if tool_name in self._user_disabled_tools[chat_id]: + logger.warning(f"工具 {tool_name} 已经被禁用") + return False + self._user_disabled_tools[chat_id].append(tool_name) + return True + + def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool: + """启用特定聊天的某个工具""" + if chat_id in self._user_disabled_tools: + try: + self._user_disabled_tools[chat_id].remove(tool_name) + return True + except ValueError: + logger.warning(f"工具 {tool_name} 不在禁用列表中") + return False + return False + def get_disabled_chat_actions(self, chat_id: str) -> List[str]: """获取特定聊天禁用的所有动作""" return self._user_disabled_actions.get(chat_id, []).copy() @@ -88,6 +111,10 @@ class GlobalAnnouncementManager: def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: """获取特定聊天禁用的所有事件处理器""" return self._user_disabled_event_handlers.get(chat_id, []).copy() + + def get_disabled_chat_tools(self, chat_id: str) -> List[str]: + """获取特定聊天禁用的所有工具""" + return self._user_disabled_tools.get(chat_id, []).copy() global_announcement_manager = GlobalAnnouncementManager() diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index bec600190..d7b86b8d6 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -1,7 +1,8 @@ import json import time -from typing import List, Dict, Tuple, Optional -from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions,get_tool_instance +from typing import List, Dict, Tuple, Optional, Any +from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance +from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.llm_models.utils_model import LLMRequest from src.config.config import global_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager @@ -11,6 +12,7 @@ from src.common.logger import get_logger logger = get_logger("tool_use") + def init_tool_executor_prompt(): """初始化工具执行器的提示词""" tool_executor_prompt = """ @@ -27,9 +29,11 @@ If you need to use a tool, please directly call the corresponding tool function. """ Prompt(tool_executor_prompt, "tool_executor_prompt") + # 初始化提示词 init_tool_executor_prompt() + class ToolExecutor: """独立的工具执行器组件 @@ -53,9 +57,6 @@ class ToolExecutor: request_type="tool_executor", ) - # 初始化工具实例 - self.tool_instance = ToolUser() - # 缓存配置 self.enable_cache = enable_cache self.cache_ttl = cache_ttl @@ -75,7 +76,7 @@ class ToolExecutor: return_details: 是否返回详细信息(使用的工具列表和提示词) Returns: - 如果return_details为False: List[Dict] - 工具执行结果列表 + 如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空) 如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词) """ @@ -84,15 +85,15 @@ class ToolExecutor: if cached_result := self._get_from_cache(cache_key): logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") if not return_details: - return cached_result, [], "使用缓存结果" + return cached_result, [], "" # 从缓存结果中提取工具名称 used_tools = [result.get("tool_name", "unknown") for result in cached_result] - return cached_result, used_tools, "使用缓存结果" + return cached_result, used_tools, "" # 缓存未命中,执行工具调用 # 获取可用工具 - tools = self.tool_instance._define_tools() + tools = self._get_tool_definitions() # 获取当前时间 time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) @@ -114,6 +115,7 @@ class ToolExecutor: # 调用LLM进行工具决策 response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools) + # TODO: 在APIADA加入后完全修复这里! # 解析LLM响应 if len(other_info) == 3: reasoning_content, model_name, tool_calls = other_info @@ -135,6 +137,11 @@ class ToolExecutor: return tool_results, used_tools, prompt else: return tool_results, [], "" + + def _get_tool_definitions(self) -> List[Dict[str, Any]]: + all_tools = get_llm_available_tool_definitions() + user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) + return [parameters for name, parameters in all_tools if name not in user_disabled_tools] async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: """执行工具调用 @@ -174,7 +181,7 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}执行工具: {tool_name}") # 执行工具 - result = await self.tool_instance.execute_tool_call(tool_call) + result = await self._execute_tool_call(tool_call) if result: tool_info = { @@ -207,6 +214,45 @@ class ToolExecutor: return tool_results, used_tools + async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]: + # sourcery skip: use-assigned-variable + """执行单个工具调用 + + Args: + tool_call: 工具调用对象 + + Returns: + Optional[Dict]: 工具调用结果,如果失败则返回None + """ + try: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + function_args["llm_called"] = True # 标记为LLM调用 + + # 获取对应工具实例 + tool_instance = get_tool_instance(function_name) + if not tool_instance: + logger.warning(f"未知工具名称: {function_name}") + return None + + # 执行工具 + result = await tool_instance.execute(function_args) + if result: + # 直接使用 function_name 作为 tool_type + tool_type = function_name + + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": function_name, + "type": tool_type, + "content": result["content"], + } + return None + except Exception as e: + logger.error(f"执行工具调用时发生错误: {str(e)}") + return None + def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: """生成缓存键 @@ -274,15 +320,6 @@ class ToolExecutor: if expired_keys: logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") - def get_available_tools(self) -> List[str]: - """获取可用工具列表 - - Returns: - List[str]: 可用工具名称列表 - """ - tools = self.tool_instance._define_tools() - return [tool.get("function", {}).get("name", "unknown") for tool in tools] - async def execute_specific_tool( self, tool_name: str, tool_args: Dict, validate_args: bool = True ) -> Optional[Dict]: @@ -301,7 +338,7 @@ class ToolExecutor: logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") - result = await self.tool_instance.execute_tool_call(tool_call) + result = await self._execute_tool_call(tool_call) if result: tool_info = { @@ -367,6 +404,7 @@ class ToolExecutor: self.cache_ttl = cache_ttl logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") + """ ToolExecutor使用示例: @@ -397,62 +435,7 @@ result = await executor.execute_specific_tool( ) # 6. 缓存管理 -available_tools = executor.get_available_tools() cache_status = executor.get_cache_status() # 查看缓存状态 executor.clear_cache() # 清空缓存 executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置 """ - - -class ToolUser: - @staticmethod - def _define_tools(): - """获取所有已注册工具的定义 - - Returns: - list: 工具定义列表 - """ - return get_llm_available_tool_definitions() - - @staticmethod - async def execute_tool_call(tool_call): - # sourcery skip: use-assigned-variable - """执行特定的工具调用 - - Args: - tool_call: 工具调用对象 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) - function_args["llm_called"] = True # 标记为LLM调用 - - # 获取对应工具实例 - tool_instance = get_tool_instance(function_name) - if not tool_instance: - logger.warning(f"未知工具名称: {function_name}") - return None - - # 执行工具 - result = await tool_instance.execute(function_args) - if result: - # 直接使用 function_name 作为 tool_type - tool_type = function_name - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": function_name, - "type": tool_type, - "content": result["content"], - } - return None - except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") - return None - -tool_user = ToolUser() \ No newline at end of file diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py deleted file mode 100644 index 0f50ca2ab..000000000 --- a/src/tools/tool_executor.py +++ /dev/null @@ -1,407 +0,0 @@ -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config -import time -from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.tools.tool_use import ToolUser -from src.chat.utils.json_utils import process_llm_tool_calls -from typing import List, Dict, Tuple, Optional -from src.chat.message_receive.chat_stream import get_chat_manager - -logger = get_logger("tool_executor") - - -def init_tool_executor_prompt(): - """初始化工具执行器的提示词""" - tool_executor_prompt = """ -你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。 -群里正在进行的聊天内容: -{chat_history} - -现在,{sender}发送了内容:{target_message},你想要回复ta。 -请仔细分析聊天内容,考虑以下几点: -1. 内容中是否包含需要查询信息的问题 -2. 是否有明确的工具使用指令 - -If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed". -""" - Prompt(tool_executor_prompt, "tool_executor_prompt") - - -class ToolExecutor: - """独立的工具执行器组件 - - 可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。 - """ - - def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3): - """初始化工具执行器 - - Args: - executor_id: 执行器标识符,用于日志记录 - enable_cache: 是否启用缓存机制 - cache_ttl: 缓存生存时间(周期数) - """ - self.chat_id = chat_id - self.chat_stream = get_chat_manager().get_stream(self.chat_id) - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" - - self.llm_model = LLMRequest( - model=global_config.model.tool_use, - request_type="tool_executor", - ) - - # 初始化工具实例 - self.tool_instance = ToolUser() - - # 缓存配置 - self.enable_cache = enable_cache - self.cache_ttl = cache_ttl - self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}} - - logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}") - - async def execute_from_chat_message( - self, target_message: str, chat_history: str, sender: str, return_details: bool = False - ) -> Tuple[List[Dict], List[str], str]: - """从聊天消息执行工具 - - Args: - target_message: 目标消息内容 - chat_history: 聊天历史 - sender: 发送者 - return_details: 是否返回详细信息(使用的工具列表和提示词) - - Returns: - 如果return_details为False: List[Dict] - 工具执行结果列表 - 如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词) - """ - - # 首先检查缓存 - cache_key = self._generate_cache_key(target_message, chat_history, sender) - if cached_result := self._get_from_cache(cache_key): - logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") - if not return_details: - return cached_result, [], "使用缓存结果" - - # 从缓存结果中提取工具名称 - used_tools = [result.get("tool_name", "unknown") for result in cached_result] - return cached_result, used_tools, "使用缓存结果" - - # 缓存未命中,执行工具调用 - # 获取可用工具 - tools = self.tool_instance._define_tools() - - # 获取当前时间 - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - - bot_name = global_config.bot.nickname - - # 构建工具调用提示词 - prompt = await global_prompt_manager.format_prompt( - "tool_executor_prompt", - target_message=target_message, - chat_history=chat_history, - sender=sender, - bot_name=bot_name, - time_now=time_now, - ) - - logger.debug(f"{self.log_prefix}开始LLM工具调用分析") - - # 调用LLM进行工具决策 - response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools) - - # 解析LLM响应 - if len(other_info) == 3: - reasoning_content, model_name, tool_calls = other_info - else: - reasoning_content, model_name = other_info - tool_calls = None - - # 执行工具调用 - tool_results, used_tools = await self._execute_tool_calls(tool_calls) - - # 缓存结果 - if tool_results: - self._set_cache(cache_key, tool_results) - - if used_tools: - logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}") - - if return_details: - return tool_results, used_tools, prompt - else: - return tool_results, [], "" - - async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: - """执行工具调用 - - Args: - tool_calls: LLM返回的工具调用列表 - - Returns: - Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) - """ - tool_results = [] - used_tools = [] - - if not tool_calls: - logger.debug(f"{self.log_prefix}无需执行工具") - return tool_results, used_tools - - logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}") - - # 处理工具调用 - success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls) - - if not success: - logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}") - return tool_results, used_tools - - if not valid_tool_calls: - logger.debug(f"{self.log_prefix}无有效工具调用") - return tool_results, used_tools - - # 执行每个工具调用 - for tool_call in valid_tool_calls: - try: - tool_name = tool_call.get("name", "unknown_tool") - used_tools.append(tool_name) - - logger.debug(f"{self.log_prefix}执行工具: {tool_name}") - - # 执行工具 - result = await self.tool_instance.execute_tool_call(tool_call) - - if result: - tool_info = { - "type": result.get("type", "unknown_type"), - "id": result.get("id", f"tool_exec_{time.time()}"), - "content": result.get("content", ""), - "tool_name": tool_name, - "timestamp": time.time(), - } - tool_results.append(tool_info) - - logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}") - content = tool_info["content"] - if not isinstance(content, (str, list, tuple)): - content = str(content) - preview = content[:200] - logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") - - except Exception as e: - logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") - # 添加错误信息到结果中 - error_info = { - "type": "tool_error", - "id": f"tool_error_{time.time()}", - "content": f"工具{tool_name}执行失败: {str(e)}", - "tool_name": tool_name, - "timestamp": time.time(), - } - tool_results.append(error_info) - - return tool_results, used_tools - - def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: - """生成缓存键 - - Args: - target_message: 目标消息内容 - chat_history: 聊天历史 - sender: 发送者 - - Returns: - str: 缓存键 - """ - import hashlib - - # 使用消息内容和群聊状态生成唯一缓存键 - content = f"{target_message}_{chat_history}_{sender}" - return hashlib.md5(content.encode()).hexdigest() - - def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]: - """从缓存获取结果 - - Args: - cache_key: 缓存键 - - Returns: - Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None - """ - if not self.enable_cache or cache_key not in self.tool_cache: - return None - - cache_item = self.tool_cache[cache_key] - if cache_item["ttl"] <= 0: - # 缓存过期,删除 - del self.tool_cache[cache_key] - logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}") - return None - - # 减少TTL - cache_item["ttl"] -= 1 - logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}") - return cache_item["result"] - - def _set_cache(self, cache_key: str, result: List[Dict]): - """设置缓存 - - Args: - cache_key: 缓存键 - result: 要缓存的结果 - """ - if not self.enable_cache: - return - - self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()} - logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}") - - def _cleanup_expired_cache(self): - """清理过期的缓存""" - if not self.enable_cache: - return - - expired_keys = [] - expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0) - for key in expired_keys: - del self.tool_cache[key] - - if expired_keys: - logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") - - def get_available_tools(self) -> List[str]: - """获取可用工具列表 - - Returns: - List[str]: 可用工具名称列表 - """ - tools = self.tool_instance._define_tools() - return [tool.get("function", {}).get("name", "unknown") for tool in tools] - - async def execute_specific_tool( - self, tool_name: str, tool_args: Dict, validate_args: bool = True - ) -> Optional[Dict]: - """直接执行指定工具 - - Args: - tool_name: 工具名称 - tool_args: 工具参数 - validate_args: 是否验证参数 - - Returns: - Optional[Dict]: 工具执行结果,失败时返回None - """ - try: - tool_call = {"name": tool_name, "arguments": tool_args} - - logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") - - result = await self.tool_instance.execute_tool_call(tool_call) - - if result: - tool_info = { - "type": result.get("type", "unknown_type"), - "id": result.get("id", f"direct_tool_{time.time()}"), - "content": result.get("content", ""), - "tool_name": tool_name, - "timestamp": time.time(), - } - logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}") - return tool_info - - except Exception as e: - logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}") - - return None - - def clear_cache(self): - """清空所有缓存""" - if self.enable_cache: - cache_count = len(self.tool_cache) - self.tool_cache.clear() - logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项") - - def get_cache_status(self) -> Dict: - """获取缓存状态信息 - - Returns: - Dict: 包含缓存统计信息的字典 - """ - if not self.enable_cache: - return {"enabled": False, "cache_count": 0} - - # 清理过期缓存 - self._cleanup_expired_cache() - - total_count = len(self.tool_cache) - ttl_distribution = {} - - for cache_item in self.tool_cache.values(): - ttl = cache_item["ttl"] - ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1 - - return { - "enabled": True, - "cache_count": total_count, - "cache_ttl": self.cache_ttl, - "ttl_distribution": ttl_distribution, - } - - def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1): - """动态修改缓存配置 - - Args: - enable_cache: 是否启用缓存 - cache_ttl: 缓存TTL - """ - if enable_cache is not None: - self.enable_cache = enable_cache - logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}") - - if cache_ttl > 0: - self.cache_ttl = cache_ttl - logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") - - -# 初始化提示词 -init_tool_executor_prompt() - - -""" -使用示例: - -# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3) -executor = ToolExecutor(executor_id="my_executor") -results, _, _ = await executor.execute_from_chat_message( - talking_message_str="今天天气怎么样?现在几点了?", - is_group_chat=False -) - -# 2. 禁用缓存的执行器 -no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False) - -# 3. 自定义缓存TTL -long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10) - -# 4. 获取详细信息 -results, used_tools, prompt = await executor.execute_from_chat_message( - talking_message_str="帮我查询Python相关知识", - is_group_chat=False, - return_details=True -) - -# 5. 直接执行特定工具 -result = await executor.execute_specific_tool( - tool_name="get_knowledge", - tool_args={"query": "机器学习"} -) - -# 6. 缓存管理 -available_tools = executor.get_available_tools() -cache_status = executor.get_cache_status() # 查看缓存状态 -executor.clear_cache() # 清空缓存 -executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置 -""" diff --git a/src/tools/tool_use.py b/src/tools/tool_use.py deleted file mode 100644 index 6a8cd48a6..000000000 --- a/src/tools/tool_use.py +++ /dev/null @@ -1,56 +0,0 @@ -import json -from src.common.logger import get_logger -from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance - -logger = get_logger("tool_use") - - -class ToolUser: - @staticmethod - def _define_tools(): - """获取所有已注册工具的定义 - - Returns: - list: 工具定义列表 - """ - return get_all_tool_definitions() - - @staticmethod - async def execute_tool_call(tool_call): - # sourcery skip: use-assigned-variable - """执行特定的工具调用 - - Args: - tool_call: 工具调用对象 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) - - # 获取对应工具实例 - tool_instance = get_tool_instance(function_name) - if not tool_instance: - logger.warning(f"未知工具名称: {function_name}") - return None - - # 执行工具 - result = await tool_instance.execute(function_args) - if result: - # 直接使用 function_name 作为 tool_type - tool_type = function_name - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": function_name, - "type": tool_type, - "content": result["content"], - } - return None - except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") - return None