tools系统
This commit is contained in:
@@ -412,7 +412,7 @@ class DefaultReplyer:
|
||||
for tool_result in tool_results:
|
||||
tool_name = tool_result.get("tool_name", "unknown")
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "info")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||
|
||||
@@ -848,7 +848,7 @@ class DefaultReplyer:
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
) -> str: # sourcery skip: remove-redundant-if
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
|
||||
|
||||
# 定义类型变量用于泛型类型提示
|
||||
T = TypeVar("T")
|
||||
|
||||
# 获取logger
|
||||
logger = logging.getLogger("json_utils")
|
||||
|
||||
|
||||
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
||||
"""
|
||||
安全地解析JSON字符串,出错时返回默认值
|
||||
现在尝试处理单引号和标准JSON
|
||||
|
||||
参数:
|
||||
json_str: 要解析的JSON字符串
|
||||
default_value: 解析失败时返回的默认值
|
||||
|
||||
返回:
|
||||
解析后的Python对象,或在解析失败时返回default_value
|
||||
"""
|
||||
if not json_str or not isinstance(json_str, str):
|
||||
logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
|
||||
return default_value
|
||||
|
||||
try:
|
||||
# 尝试标准的 JSON 解析
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果标准解析失败,尝试用 ast.literal_eval 解析
|
||||
try:
|
||||
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
|
||||
result = ast.literal_eval(json_str)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
||||
return default_value
|
||||
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
|
||||
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
|
||||
except Exception as e:
|
||||
logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
|
||||
except Exception as e:
|
||||
logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
|
||||
return default_value
|
||||
|
||||
|
||||
def extract_tool_call_arguments(
|
||||
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从LLM工具调用对象中提取参数
|
||||
|
||||
参数:
|
||||
tool_call: 工具调用对象字典
|
||||
default_value: 解析失败时返回的默认值
|
||||
|
||||
返回:
|
||||
解析后的参数字典,或在解析失败时返回default_value
|
||||
"""
|
||||
default_result = default_value or {}
|
||||
|
||||
if not tool_call or not isinstance(tool_call, dict):
|
||||
logger.error(f"无效的工具调用对象: {tool_call}")
|
||||
return default_result
|
||||
|
||||
try:
|
||||
# 提取function参数
|
||||
function_data = tool_call.get("function", {})
|
||||
if not function_data or not isinstance(function_data, dict):
|
||||
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
|
||||
return default_result
|
||||
|
||||
if arguments_str := function_data.get("arguments", "{}"):
|
||||
# 解析JSON
|
||||
return safe_json_loads(arguments_str, default_result)
|
||||
else:
|
||||
return default_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取工具调用参数时出错: {e}")
|
||||
return default_result
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
|
||||
"""
|
||||
安全地将Python对象序列化为JSON字符串
|
||||
|
||||
参数:
|
||||
obj: 要序列化的Python对象
|
||||
default_value: 序列化失败时返回的默认值
|
||||
ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符)
|
||||
pretty: 是否美化输出JSON
|
||||
|
||||
返回:
|
||||
序列化后的JSON字符串,或在序列化失败时返回default_value
|
||||
"""
|
||||
try:
|
||||
indent = 2 if pretty else None
|
||||
return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
|
||||
except TypeError as e:
|
||||
logger.error(f"JSON序列化失败(类型错误): {e}")
|
||||
return default_value
|
||||
except Exception as e:
|
||||
logger.error(f"JSON序列化过程中发生意外错误: {e}")
|
||||
return default_value
|
||||
|
||||
|
||||
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
|
||||
"""
|
||||
标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式
|
||||
|
||||
参数:
|
||||
response: 原始LLM响应
|
||||
log_prefix: 日志前缀
|
||||
|
||||
返回:
|
||||
元组 (成功标志, 标准化后的响应列表, 错误消息)
|
||||
"""
|
||||
|
||||
logger.debug(f"{log_prefix}原始人 LLM响应: {response}")
|
||||
|
||||
# 检查是否为None
|
||||
if response is None:
|
||||
return False, [], "LLM响应为None"
|
||||
|
||||
# 记录原始类型
|
||||
logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
|
||||
|
||||
# 将元组转换为列表
|
||||
if isinstance(response, tuple):
|
||||
logger.debug(f"{log_prefix}将元组响应转换为列表")
|
||||
response = list(response)
|
||||
|
||||
# 确保是列表类型
|
||||
if not isinstance(response, list):
|
||||
return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"
|
||||
|
||||
# 处理工具调用部分(如果存在)
|
||||
if len(response) == 3:
|
||||
content, reasoning, tool_calls = response
|
||||
|
||||
# 将工具调用部分转换为列表(如果是元组)
|
||||
if isinstance(tool_calls, tuple):
|
||||
logger.debug(f"{log_prefix}将工具调用元组转换为列表")
|
||||
tool_calls = list(tool_calls)
|
||||
response[2] = tool_calls
|
||||
|
||||
return True, response, ""
|
||||
|
||||
|
||||
def process_llm_tool_calls(
|
||||
tool_calls: List[Dict[str, Any]], log_prefix: str = ""
|
||||
) -> Tuple[bool, List[Dict[str, Any]], str]:
|
||||
"""
|
||||
处理并验证LLM响应中的工具调用列表
|
||||
|
||||
参数:
|
||||
tool_calls: 从LLM响应中直接获取的工具调用列表
|
||||
log_prefix: 日志前缀
|
||||
|
||||
返回:
|
||||
元组 (成功标志, 验证后的工具调用列表, 错误消息)
|
||||
"""
|
||||
|
||||
# 如果列表为空,表示没有工具调用,这不是错误
|
||||
if not tool_calls:
|
||||
return True, [], "工具调用列表为空"
|
||||
|
||||
# 验证每个工具调用的格式
|
||||
valid_tool_calls = []
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
if not isinstance(tool_call, dict):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
|
||||
continue
|
||||
|
||||
# 检查基本结构
|
||||
if tool_call.get("type") != "function":
|
||||
logger.warning(
|
||||
f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
|
||||
)
|
||||
continue
|
||||
|
||||
if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
|
||||
continue
|
||||
|
||||
func_details = tool_call["function"]
|
||||
if "name" not in func_details or not isinstance(func_details.get("name"), str):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
|
||||
continue
|
||||
|
||||
# 验证参数 'arguments'
|
||||
args_value = func_details.get("arguments")
|
||||
|
||||
# 1. 检查 arguments 是否存在且是字符串
|
||||
if args_value is None or not isinstance(args_value, str):
|
||||
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
|
||||
continue
|
||||
|
||||
# 2. 尝试安全地解析 arguments 字符串
|
||||
parsed_args = safe_json_loads(args_value, None)
|
||||
|
||||
# 3. 检查解析结果是否为字典
|
||||
if parsed_args is None or not isinstance(parsed_args, dict):
|
||||
logger.warning(
|
||||
f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
|
||||
f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 如果检查通过,将原始的 tool_call 加入有效列表
|
||||
valid_tool_calls.append(tool_call)
|
||||
|
||||
if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空
|
||||
return False, [], "所有工具调用格式均无效"
|
||||
|
||||
return True, valid_tool_calls, ""
|
||||
@@ -1,4 +1,4 @@
|
||||
raise DeprecationWarning("Genimi Client is not fully available yet.")
|
||||
raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider")
|
||||
import asyncio
|
||||
import io
|
||||
from collections.abc import Iterable
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tool_option import ToolCall
|
||||
|
||||
__all__ = ["ToolCall"]
|
||||
@@ -11,7 +11,7 @@ from src.config.config import model_config
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
|
||||
from .payload_content.message import MessageBuilder, Message
|
||||
from .payload_content.resp_format import RespFormat
|
||||
from .payload_content.tool_option import ToolOption, ToolCall
|
||||
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
|
||||
from .model_client.base_client import BaseClient, APIResponse, client_registry
|
||||
from .utils import compress_messages, llm_usage_recorder
|
||||
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
|
||||
@@ -60,7 +60,7 @@ class LLMRequest:
|
||||
image_format: str,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||
"""
|
||||
为图像生成响应
|
||||
Args:
|
||||
@@ -68,7 +68,7 @@ class LLMRequest:
|
||||
image_base64 (str): 图像的Base64编码字符串
|
||||
image_format (str): 图像格式(如 'png', 'jpeg' 等)
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
# 请求体构建
|
||||
message_builder = MessageBuilder()
|
||||
@@ -104,18 +104,18 @@ class LLMRequest:
|
||||
request_type=self.request_type,
|
||||
endpoint="/chat/completions",
|
||||
)
|
||||
return content, (
|
||||
reasoning_content,
|
||||
model_info.name,
|
||||
self._convert_tool_calls(tool_calls) if tool_calls else None,
|
||||
)
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
async def generate_response_for_voice(self):
|
||||
pass
|
||||
|
||||
async def generate_response_async(
|
||||
self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
|
||||
self,
|
||||
prompt: str,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||
"""
|
||||
异步生成响应
|
||||
Args:
|
||||
@@ -123,13 +123,13 @@ class LLMRequest:
|
||||
temperature (float, optional): 温度参数
|
||||
max_tokens (int, optional): 最大token数
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
# 请求体构建
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
messages = [message_builder.build()]
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
# 模型选择
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
@@ -142,6 +142,7 @@ class LLMRequest:
|
||||
message_list=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tool_options=tool_built,
|
||||
)
|
||||
content = response.content
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
@@ -161,11 +162,7 @@ class LLMRequest:
|
||||
if not content:
|
||||
raise RuntimeError("获取LLM生成内容失败")
|
||||
|
||||
return content, (
|
||||
reasoning_content,
|
||||
model_info.name,
|
||||
self._convert_tool_calls(tool_calls) if tool_calls else None,
|
||||
)
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||
"""获取嵌入向量
|
||||
@@ -214,10 +211,6 @@ class LLMRequest:
|
||||
client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
|
||||
return model_info, api_provider, client
|
||||
|
||||
def _convert_tool_calls(self, tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
|
||||
"""将ToolCall对象转换为Dict列表"""
|
||||
pass
|
||||
|
||||
async def _execute_request(
|
||||
self,
|
||||
api_provider: APIProvider,
|
||||
@@ -435,6 +428,35 @@ class LLMRequest:
|
||||
)
|
||||
return -1, None
|
||||
|
||||
def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
|
||||
"""构建工具选项列表"""
|
||||
if not tools:
|
||||
return None
|
||||
tool_options: List[ToolOption] = []
|
||||
for tool in tools:
|
||||
tool_legal = True
|
||||
tool_options_builder = ToolOptionBuilder()
|
||||
tool_options_builder.set_name(tool.get("name", ""))
|
||||
tool_options_builder.set_description(tool.get("description", ""))
|
||||
parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", [])
|
||||
for param in parameters:
|
||||
try:
|
||||
tool_options_builder.add_param(
|
||||
name=param[0],
|
||||
param_type=ToolParamType(param[1]),
|
||||
description=param[2],
|
||||
required=param[3],
|
||||
)
|
||||
except ValueError as ve:
|
||||
tool_legal = False
|
||||
logger.error(f"{param[1]} 参数类型错误: {str(ve)}")
|
||||
except Exception as e:
|
||||
tool_legal = False
|
||||
logger.error(f"构建工具参数失败: {str(e)}")
|
||||
if tool_legal:
|
||||
tool_options.append(tool_options_builder.build())
|
||||
return tool_options or None
|
||||
|
||||
@staticmethod
|
||||
def _extract_reasoning(content: str) -> Tuple[str, str]:
|
||||
"""CoT思维链提取,向后兼容"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from typing import Any, List, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -17,8 +17,8 @@ class BaseTool(ABC):
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: Dict[str, Any] = {}
|
||||
"""工具的参数定义"""
|
||||
parameters: List[Tuple[str, str, str, bool]] = []
|
||||
"""工具的参数定义,为[("param_name", "param_type", "description", required)]"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
@@ -32,10 +32,7 @@ class BaseTool(ABC):
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
@@ -79,7 +76,9 @@ class BaseTool(ABC):
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
|
||||
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
|
||||
for param_name in parameter_required:
|
||||
if param_name not in function_args:
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
|
||||
|
||||
return await self.execute(function_args)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import Seg
|
||||
|
||||
@@ -150,7 +150,7 @@ class CommandInfo(ComponentInfo):
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
|
||||
tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -63,7 +62,7 @@ class ToolExecutor:
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict], List[str], str]:
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
@@ -110,15 +109,9 @@ class ToolExecutor:
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
# 调用LLM进行工具决策
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
# TODO: 在APIADA加入后完全修复这里!
|
||||
# 解析LLM响应
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
else:
|
||||
reasoning_content, model_name = other_info
|
||||
tool_calls = None
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools
|
||||
)
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
|
||||
@@ -138,9 +131,9 @@ class ToolExecutor:
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
|
||||
return [definition for name, definition in all_tools if name not in user_disabled_tools]
|
||||
|
||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||
async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
@@ -149,32 +142,19 @@ class ToolExecutor:
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results = []
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return tool_results, used_tools
|
||||
return [], []
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
|
||||
|
||||
# 处理工具调用
|
||||
success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
|
||||
return tool_results, used_tools
|
||||
|
||||
if not valid_tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无有效工具调用")
|
||||
return tool_results, used_tools
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in valid_tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
try:
|
||||
tool_name = tool_call.get("name", "unknown_tool")
|
||||
used_tools.append(tool_name)
|
||||
|
||||
tool_name = tool_call.func_name
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
@@ -188,15 +168,15 @@ class ToolExecutor:
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(tool_info)
|
||||
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
content = str(content)
|
||||
tool_info["content"] = str(content)
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
preview = content[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
@@ -211,7 +191,7 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
|
||||
async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
@@ -222,8 +202,8 @@ class ToolExecutor:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
@@ -235,20 +215,17 @@ class ToolExecutor:
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
raise e
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
@@ -317,9 +294,7 @@ class ToolExecutor:
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
async def execute_specific_tool(
|
||||
self, tool_name: str, tool_args: Dict, validate_args: bool = True
|
||||
) -> Optional[Dict]:
|
||||
async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
@@ -331,7 +306,11 @@ class ToolExecutor:
|
||||
Optional[Dict]: 工具执行结果,失败时返回None
|
||||
"""
|
||||
try:
|
||||
tool_call = {"name": tool_name, "arguments": tool_args}
|
||||
tool_call = ToolCall(
|
||||
call_id=f"direct_tool_{time.time()}",
|
||||
func_name=tool_name,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
|
||||
@@ -14,14 +14,10 @@ class SearchKnowledgeTool(BaseTool):
|
||||
|
||||
name = "search_knowledge"
|
||||
description = "使用工具从知识库中搜索相关信息"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
parameters = [
|
||||
("query", "string", "搜索查询关键词", True),
|
||||
("threshold", "float", "相似度阈值,0.0到1.0之间", False),
|
||||
]
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
@@ -14,14 +14,10 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
|
||||
name = "lpmm_search_knowledge"
|
||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
parameters = [
|
||||
("query", "string", "搜索查询关键词", True),
|
||||
("threshold", "float", "相似度阈值,0.0到1.0之间", False),
|
||||
]
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Reference in New Issue
Block a user