tools系统
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from typing import Any, List, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -17,8 +17,8 @@ class BaseTool(ABC):
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: Dict[str, Any] = {}
|
||||
"""工具的参数定义"""
|
||||
parameters: List[Tuple[str, str, str, bool]] = []
|
||||
"""工具的参数定义,为[("param_name", "param_type", "description", required)]"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
@@ -32,10 +32,7 @@ class BaseTool(ABC):
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
@@ -79,7 +76,9 @@ class BaseTool(ABC):
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
|
||||
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
|
||||
for param_name in parameter_required:
|
||||
if param_name not in function_args:
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
|
||||
|
||||
return await self.execute(function_args)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import Seg
|
||||
|
||||
@@ -150,7 +150,7 @@ class CommandInfo(ComponentInfo):
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
|
||||
tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -63,7 +62,7 @@ class ToolExecutor:
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict], List[str], str]:
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
@@ -110,15 +109,9 @@ class ToolExecutor:
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
# 调用LLM进行工具决策
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
# TODO: 在APIADA加入后完全修复这里!
|
||||
# 解析LLM响应
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
else:
|
||||
reasoning_content, model_name = other_info
|
||||
tool_calls = None
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools
|
||||
)
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
|
||||
@@ -138,9 +131,9 @@ class ToolExecutor:
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
|
||||
return [definition for name, definition in all_tools if name not in user_disabled_tools]
|
||||
|
||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||
async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
@@ -149,32 +142,19 @@ class ToolExecutor:
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results = []
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return tool_results, used_tools
|
||||
return [], []
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
|
||||
|
||||
# 处理工具调用
|
||||
success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
|
||||
return tool_results, used_tools
|
||||
|
||||
if not valid_tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无有效工具调用")
|
||||
return tool_results, used_tools
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in valid_tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
try:
|
||||
tool_name = tool_call.get("name", "unknown_tool")
|
||||
used_tools.append(tool_name)
|
||||
|
||||
tool_name = tool_call.func_name
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
@@ -188,15 +168,15 @@ class ToolExecutor:
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(tool_info)
|
||||
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
content = str(content)
|
||||
tool_info["content"] = str(content)
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
preview = content[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
@@ -211,7 +191,7 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
|
||||
async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
@@ -222,8 +202,8 @@ class ToolExecutor:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
@@ -235,20 +215,17 @@ class ToolExecutor:
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
raise e
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
@@ -317,9 +294,7 @@ class ToolExecutor:
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
async def execute_specific_tool(
|
||||
self, tool_name: str, tool_args: Dict, validate_args: bool = True
|
||||
) -> Optional[Dict]:
|
||||
async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
@@ -331,7 +306,11 @@ class ToolExecutor:
|
||||
Optional[Dict]: 工具执行结果,失败时返回None
|
||||
"""
|
||||
try:
|
||||
tool_call = {"name": tool_name, "arguments": tool_args}
|
||||
tool_call = ToolCall(
|
||||
call_id=f"direct_tool_{time.time()}",
|
||||
func_name=tool_name,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user