tools整合彻底完成
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from typing import List, Tuple, Type
|
||||
from src.plugin_system.apis import tool_api
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
@@ -58,10 +57,7 @@ class HelloAction(BaseAction):
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行问候动作 - 这是核心功能"""
|
||||
# 发送问候消息
|
||||
hello_tool = tool_api.get_tool_instance("hello_tool")
|
||||
greeting_message = await hello_tool.execute({
|
||||
"greeting_message": self.action_data.get("greeting_message", "")
|
||||
})
|
||||
greeting_message = self.action_data.get("greeting_message", "")
|
||||
base_message = self.get_config("greeting.message", "嗨!很开心见到你!😊")
|
||||
message = base_message + greeting_message
|
||||
await self.send_text(message)
|
||||
|
||||
@@ -51,7 +51,7 @@ from .apis import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__version__ = "2.0.0"
|
||||
|
||||
__all__ = [
|
||||
# API 模块
|
||||
|
||||
@@ -17,6 +17,7 @@ from src.plugin_system.apis import (
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
tool_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
@@ -36,4 +37,5 @@ __all__ = [
|
||||
"send_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
]
|
||||
|
||||
@@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import (
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
"""
|
||||
获取指定 Tool 的注册信息。
|
||||
|
||||
Args:
|
||||
tool_name (str): Tool 名称。
|
||||
|
||||
Returns:
|
||||
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_tool_info(tool_name)
|
||||
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
@@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType,
|
||||
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
@@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType
|
||||
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||
"""
|
||||
获取指定消息流中禁用的组件列表。
|
||||
@@ -239,6 +260,8 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp
|
||||
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.get_disabled_chat_tools(stream_id)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
||||
case _:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
@@ -6,20 +6,22 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
tool_class = component_registry.get_component_class(tool_name, ComponentType.TOOL)
|
||||
if not tool_class:
|
||||
return None
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class() if tool_class else None
|
||||
|
||||
return tool_class()
|
||||
|
||||
def get_llm_available_tool_definitions():
|
||||
"""获取LLM可用的工具定义列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [tool_class().get_tool_definition() for tool_class in llm_available_tools.values()]
|
||||
|
||||
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
from typing import List, Any, Optional, Type
|
||||
from src.common.logger import get_logger
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
|
||||
|
||||
class BaseTool:
|
||||
class BaseTool(ABC):
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
# 是否可供LLM使用,默认为False
|
||||
available_for_llm = False
|
||||
name: str = ""
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: Dict[str, Any] = {}
|
||||
"""工具的参数定义"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
@@ -38,18 +41,18 @@ class BaseTool:
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
if not cls.name or not cls.description:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name 和 description 属性")
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return ToolInfo(
|
||||
name=cls.name,
|
||||
tool_description=cls.description,
|
||||
available_for_llm=cls.available_for_llm,
|
||||
enabled=cls.available_for_llm,
|
||||
tool_parameters=cls.parameters,
|
||||
component_type=ComponentType.TOOL,
|
||||
)
|
||||
|
||||
# 工具参数定义,子类必须重写
|
||||
@abstractmethod
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
|
||||
@@ -151,7 +151,6 @@ class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
|
||||
available_for_llm: bool = False # 是否可供LLM使用
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -8,12 +8,10 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.tool_use import tool_user
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
"tool_user",
|
||||
]
|
||||
|
||||
@@ -85,7 +85,9 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def register_component(
|
||||
self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
@@ -190,13 +192,13 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: BaseTool):
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
if tool_info.available_for_llm and tool_info.enabled:
|
||||
if tool_info.enabled:
|
||||
self._llm_available_tools[tool_name] = tool_class
|
||||
|
||||
return True
|
||||
@@ -243,6 +245,9 @@ class ComponentRegistry:
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key)
|
||||
case ComponentType.TOOL:
|
||||
self._tool_registry.pop(component_name)
|
||||
self._llm_available_tools.pop(component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
@@ -255,8 +260,8 @@ class ComponentRegistry:
|
||||
self._components_classes.pop(namespaced_name)
|
||||
logger.info(f"组件 {component_name} 已移除")
|
||||
return True
|
||||
except KeyError:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}")
|
||||
except KeyError as e:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||
@@ -302,6 +307,10 @@ class ComponentRegistry:
|
||||
assert isinstance(target_component_info, CommandInfo)
|
||||
pattern = target_component_info.command_pattern
|
||||
self._command_patterns[re.compile(pattern)] = component_name
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(target_component_info, ToolInfo)
|
||||
assert issubclass(target_component_class, BaseTool)
|
||||
self._llm_available_tools[component_name] = target_component_class
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(target_component_info, EventHandlerInfo)
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
@@ -329,13 +338,16 @@ class ComponentRegistry:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||
return False
|
||||
target_component_info.enabled = False
|
||||
try:
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name, None)
|
||||
self._default_actions.pop(component_name)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.TOOL:
|
||||
self._llm_available_tools.pop(component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
@@ -343,6 +355,12 @@ class ComponentRegistry:
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
except KeyError as e:
|
||||
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
@@ -392,7 +410,7 @@ class ComponentRegistry:
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -502,7 +520,7 @@ class ComponentRegistry:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, str]:
|
||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ class GlobalAnnouncementManager:
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
# 用户禁用的工具,chat_id -> [tool_name]
|
||||
self._user_disabled_tools: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
@@ -77,6 +79,27 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""禁用特定聊天的某个工具"""
|
||||
if chat_id not in self._user_disabled_tools:
|
||||
self._user_disabled_tools[chat_id] = []
|
||||
if tool_name in self._user_disabled_tools[chat_id]:
|
||||
logger.warning(f"工具 {tool_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
try:
|
||||
self._user_disabled_tools[chat_id].remove(tool_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"工具 {tool_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
@@ -89,5 +112,9 @@ class GlobalAnnouncementManager:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
@@ -11,6 +12,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
"""初始化工具执行器的提示词"""
|
||||
tool_executor_prompt = """
|
||||
@@ -27,9 +29,11 @@ If you need to use a tool, please directly call the corresponding tool function.
|
||||
"""
|
||||
Prompt(tool_executor_prompt, "tool_executor_prompt")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
@@ -53,9 +57,6 @@ class ToolExecutor:
|
||||
request_type="tool_executor",
|
||||
)
|
||||
|
||||
# 初始化工具实例
|
||||
self.tool_instance = ToolUser()
|
||||
|
||||
# 缓存配置
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
@@ -75,7 +76,7 @@ class ToolExecutor:
|
||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
||||
|
||||
Returns:
|
||||
如果return_details为False: List[Dict] - 工具执行结果列表
|
||||
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
|
||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
||||
"""
|
||||
|
||||
@@ -84,15 +85,15 @@ class ToolExecutor:
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], "使用缓存结果"
|
||||
return cached_result, [], ""
|
||||
|
||||
# 从缓存结果中提取工具名称
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, "使用缓存结果"
|
||||
return cached_result, used_tools, ""
|
||||
|
||||
# 缓存未命中,执行工具调用
|
||||
# 获取可用工具
|
||||
tools = self.tool_instance._define_tools()
|
||||
tools = self._get_tool_definitions()
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
@@ -114,6 +115,7 @@ class ToolExecutor:
|
||||
# 调用LLM进行工具决策
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
# TODO: 在APIADA加入后完全修复这里!
|
||||
# 解析LLM响应
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
@@ -136,6 +138,11 @@ class ToolExecutor:
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
|
||||
|
||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
@@ -174,7 +181,7 @@ class ToolExecutor:
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
result = await self.tool_instance.execute_tool_call(tool_call)
|
||||
result = await self._execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
@@ -207,6 +214,45 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
@@ -274,15 +320,6 @@ class ToolExecutor:
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
def get_available_tools(self) -> List[str]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Returns:
|
||||
List[str]: 可用工具名称列表
|
||||
"""
|
||||
tools = self.tool_instance._define_tools()
|
||||
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
|
||||
|
||||
async def execute_specific_tool(
|
||||
self, tool_name: str, tool_args: Dict, validate_args: bool = True
|
||||
) -> Optional[Dict]:
|
||||
@@ -301,7 +338,7 @@ class ToolExecutor:
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
result = await self.tool_instance.execute_tool_call(tool_call)
|
||||
result = await self._execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
@@ -367,6 +404,7 @@ class ToolExecutor:
|
||||
self.cache_ttl = cache_ttl
|
||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||
|
||||
|
||||
"""
|
||||
ToolExecutor使用示例:
|
||||
|
||||
@@ -397,62 +435,7 @@ result = await executor.execute_specific_tool(
|
||||
)
|
||||
|
||||
# 6. 缓存管理
|
||||
available_tools = executor.get_available_tools()
|
||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
||||
executor.clear_cache() # 清空缓存
|
||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
||||
"""
|
||||
|
||||
|
||||
class ToolUser:
|
||||
@staticmethod
|
||||
def _define_tools():
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
list: 工具定义列表
|
||||
"""
|
||||
return get_llm_available_tool_definitions()
|
||||
|
||||
@staticmethod
|
||||
async def execute_tool_call(tool_call):
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行特定的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具调用结果
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
tool_user = ToolUser()
|
||||
@@ -1,407 +0,0 @@
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.tools.tool_use import ToolUser
|
||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("tool_executor")
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
"""初始化工具执行器的提示词"""
|
||||
tool_executor_prompt = """
|
||||
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的工具使用指令
|
||||
|
||||
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
|
||||
"""
|
||||
Prompt(tool_executor_prompt, "tool_executor_prompt")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
executor_id: 执行器标识符,用于日志记录
|
||||
enable_cache: 是否启用缓存机制
|
||||
cache_ttl: 缓存生存时间(周期数)
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.tool_use,
|
||||
request_type="tool_executor",
|
||||
)
|
||||
|
||||
# 初始化工具实例
|
||||
self.tool_instance = ToolUser()
|
||||
|
||||
# 缓存配置
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict], List[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
||||
|
||||
Returns:
|
||||
如果return_details为False: List[Dict] - 工具执行结果列表
|
||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
||||
"""
|
||||
|
||||
# 首先检查缓存
|
||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], "使用缓存结果"
|
||||
|
||||
# 从缓存结果中提取工具名称
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, "使用缓存结果"
|
||||
|
||||
# 缓存未命中,执行工具调用
|
||||
# 获取可用工具
|
||||
tools = self.tool_instance._define_tools()
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
# 构建工具调用提示词
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"tool_executor_prompt",
|
||||
target_message=target_message,
|
||||
chat_history=chat_history,
|
||||
sender=sender,
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
# 调用LLM进行工具决策
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
# 解析LLM响应
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
else:
|
||||
reasoning_content, model_name = other_info
|
||||
tool_calls = None
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
|
||||
|
||||
# 缓存结果
|
||||
if tool_results:
|
||||
self._set_cache(cache_key, tool_results)
|
||||
|
||||
if used_tools:
|
||||
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
|
||||
|
||||
if return_details:
|
||||
return tool_results, used_tools, prompt
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: LLM返回的工具调用列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return tool_results, used_tools
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
|
||||
|
||||
# 处理工具调用
|
||||
success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
|
||||
return tool_results, used_tools
|
||||
|
||||
if not valid_tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无有效工具调用")
|
||||
return tool_results, used_tools
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in valid_tool_calls:
|
||||
try:
|
||||
tool_name = tool_call.get("name", "unknown_tool")
|
||||
used_tools.append(tool_name)
|
||||
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
result = await self.tool_instance.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(tool_info)
|
||||
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
content = str(content)
|
||||
preview = content[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(error_info)
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# 使用消息内容和群聊状态生成唯一缓存键
|
||||
content = f"{target_message}_{chat_history}_{sender}"
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||
"""从缓存获取结果
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
|
||||
Returns:
|
||||
Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None
|
||||
"""
|
||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||
return None
|
||||
|
||||
cache_item = self.tool_cache[cache_key]
|
||||
if cache_item["ttl"] <= 0:
|
||||
# 缓存过期,删除
|
||||
del self.tool_cache[cache_key]
|
||||
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
|
||||
return None
|
||||
|
||||
# 减少TTL
|
||||
cache_item["ttl"] -= 1
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}")
|
||||
return cache_item["result"]
|
||||
|
||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||
"""设置缓存
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
result: 要缓存的结果
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||
logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}")
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
"""清理过期的缓存"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
expired_keys = []
|
||||
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
|
||||
for key in expired_keys:
|
||||
del self.tool_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
def get_available_tools(self) -> List[str]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Returns:
|
||||
List[str]: 可用工具名称列表
|
||||
"""
|
||||
tools = self.tool_instance._define_tools()
|
||||
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
|
||||
|
||||
async def execute_specific_tool(
|
||||
self, tool_name: str, tool_args: Dict, validate_args: bool = True
|
||||
) -> Optional[Dict]:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_args: 工具参数
|
||||
validate_args: 是否验证参数
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具执行结果,失败时返回None
|
||||
"""
|
||||
try:
|
||||
tool_call = {"name": tool_name, "arguments": tool_args}
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
result = await self.tool_instance.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"direct_tool_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空所有缓存"""
|
||||
if self.enable_cache:
|
||||
cache_count = len(self.tool_cache)
|
||||
self.tool_cache.clear()
|
||||
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
|
||||
|
||||
def get_cache_status(self) -> Dict:
|
||||
"""获取缓存状态信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含缓存统计信息的字典
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return {"enabled": False, "cache_count": 0}
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
total_count = len(self.tool_cache)
|
||||
ttl_distribution = {}
|
||||
|
||||
for cache_item in self.tool_cache.values():
|
||||
ttl = cache_item["ttl"]
|
||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"cache_count": total_count,
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"ttl_distribution": ttl_distribution,
|
||||
}
|
||||
|
||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||
"""动态修改缓存配置
|
||||
|
||||
Args:
|
||||
enable_cache: 是否启用缓存
|
||||
cache_ttl: 缓存TTL
|
||||
"""
|
||||
if enable_cache is not None:
|
||||
self.enable_cache = enable_cache
|
||||
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
||||
|
||||
if cache_ttl > 0:
|
||||
self.cache_ttl = cache_ttl
|
||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
"""
|
||||
使用示例:
|
||||
|
||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
||||
executor = ToolExecutor(executor_id="my_executor")
|
||||
results, _, _ = await executor.execute_from_chat_message(
|
||||
talking_message_str="今天天气怎么样?现在几点了?",
|
||||
is_group_chat=False
|
||||
)
|
||||
|
||||
# 2. 禁用缓存的执行器
|
||||
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
|
||||
|
||||
# 3. 自定义缓存TTL
|
||||
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
|
||||
|
||||
# 4. 获取详细信息
|
||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||
talking_message_str="帮我查询Python相关知识",
|
||||
is_group_chat=False,
|
||||
return_details=True
|
||||
)
|
||||
|
||||
# 5. 直接执行特定工具
|
||||
result = await executor.execute_specific_tool(
|
||||
tool_name="get_knowledge",
|
||||
tool_args={"query": "机器学习"}
|
||||
)
|
||||
|
||||
# 6. 缓存管理
|
||||
available_tools = executor.get_available_tools()
|
||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
||||
executor.clear_cache() # 清空缓存
|
||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
||||
"""
|
||||
@@ -1,56 +0,0 @@
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
class ToolUser:
|
||||
@staticmethod
|
||||
def _define_tools():
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
list: 工具定义列表
|
||||
"""
|
||||
return get_all_tool_definitions()
|
||||
|
||||
@staticmethod
|
||||
async def execute_tool_call(tool_call):
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行特定的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具调用结果
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
Reference in New Issue
Block a user