tools系统

This commit is contained in:
UnCLAS-Prommer
2025-07-31 11:41:15 +08:00
parent 483c8fb547
commit 37e52a1566
10 changed files with 95 additions and 323 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, List, Tuple
from rich.traceback import install
from src.common.logger import get_logger
@@ -17,8 +17,8 @@ class BaseTool(ABC):
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: Dict[str, Any] = {}
"""工具的参数定义"""
parameters: List[Tuple[str, str, str, bool]] = []
"""工具的参数定义,为[("param_name", "param_type", "description", required)]"""
available_for_llm: bool = False
"""是否可供LLM使用"""
@@ -32,10 +32,7 @@ class BaseTool(ABC):
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
@classmethod
def get_tool_info(cls) -> ToolInfo:
@@ -79,7 +76,9 @@ class BaseTool(ABC):
Returns:
dict: 工具执行结果
"""
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
for param_name in parameter_required:
if param_name not in function_args:
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
return await self.execute(function_args)

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, Any, List, Optional
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from maim_message import Seg
@@ -150,7 +150,7 @@ class CommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):

View File

@@ -1,12 +1,11 @@
import json
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
@@ -63,7 +62,7 @@ class ToolExecutor:
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
) -> Tuple[List[Dict], List[str], str]:
) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具
Args:
@@ -110,15 +109,9 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
# TODO: 在APIADA加入后完全修复这里
# 解析LLM响应
if len(other_info) == 3:
reasoning_content, model_name, tool_calls = other_info
else:
reasoning_content, model_name = other_info
tool_calls = None
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools
)
# 执行工具调用
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
@@ -138,9 +131,9 @@ class ToolExecutor:
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
return [definition for name, definition in all_tools if name not in user_disabled_tools]
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
Args:
@@ -149,32 +142,19 @@ class ToolExecutor:
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
tool_results = []
tool_results: List[Dict[str, Any]] = []
used_tools = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return tool_results, used_tools
return [], []
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
# 处理工具调用
success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
if not success:
logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
return tool_results, used_tools
if not valid_tool_calls:
logger.debug(f"{self.log_prefix}无有效工具调用")
return tool_results, used_tools
# 执行每个工具调用
for tool_call in valid_tool_calls:
for tool_call in tool_calls:
try:
tool_name = tool_call.get("name", "unknown_tool")
used_tools.append(tool_name)
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
@@ -188,15 +168,15 @@ class ToolExecutor:
"tool_name": tool_name,
"timestamp": time.time(),
}
tool_results.append(tool_info)
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
content = tool_info["content"]
if not isinstance(content, (str, list, tuple)):
content = str(content)
tool_info["content"] = str(content)
tool_results.append(tool_info)
used_tools.append(tool_name)
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
preview = content[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
@@ -211,7 +191,7 @@ class ToolExecutor:
return tool_results, used_tools
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
@@ -222,8 +202,8 @@ class ToolExecutor:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_name = tool_call.func_name
function_args = tool_call.args or {}
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
@@ -235,20 +215,17 @@ class ToolExecutor:
# 执行工具
result = await tool_instance.execute(function_args)
if result:
# 直接使用 function_name 作为 tool_type
tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": tool_type,
"type": "function",
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
raise e
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
@@ -317,9 +294,7 @@ class ToolExecutor:
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
async def execute_specific_tool(
self, tool_name: str, tool_args: Dict, validate_args: bool = True
) -> Optional[Dict]:
async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具
Args:
@@ -331,7 +306,11 @@ class ToolExecutor:
Optional[Dict]: 工具执行结果失败时返回None
"""
try:
tool_call = {"name": tool_name, "arguments": tool_args}
tool_call = ToolCall(
call_id=f"direct_tool_{time.time()}",
func_name=tool_name,
args=tool_args,
)
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")