From 37e52a1566437cad9366adcf4ac16156554d4488 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 31 Jul 2025 11:41:15 +0800 Subject: [PATCH] =?UTF-8?q?tools=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 4 +- src/chat/utils/json_utils.py | 223 ------------------ src/llm_models/model_client/gemini_client.py | 2 +- src/llm_models/payload_content/__init__.py | 3 + src/llm_models/utils_model.py | 64 +++-- src/plugin_system/base/base_tool.py | 17 +- src/plugin_system/base/component_types.py | 4 +- src/plugin_system/core/tool_use.py | 77 +++--- .../built_in/knowledge/get_knowledge.py | 12 +- .../built_in/knowledge/lpmm_get_knowledge.py | 12 +- 10 files changed, 95 insertions(+), 323 deletions(-) delete mode 100644 src/chat/utils/json_utils.py diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 9aacb1ae1..3c8a54922 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -412,7 +412,7 @@ class DefaultReplyer: for tool_result in tool_results: tool_name = tool_result.get("tool_name", "unknown") content = tool_result.get("content", "") - result_type = tool_result.get("type", "info") + result_type = tool_result.get("type", "tool_result") tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n" @@ -848,7 +848,7 @@ class DefaultReplyer: raw_reply: str, reason: str, reply_to: str, - ) -> str: # sourcery skip: remove-redundant-if + ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id is_group_chat = bool(chat_stream.group_info) diff --git a/src/chat/utils/json_utils.py b/src/chat/utils/json_utils.py deleted file mode 100644 index 892deac4f..000000000 --- a/src/chat/utils/json_utils.py +++ /dev/null @@ -1,223 +0,0 @@ -import ast -import json -import logging - -from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional - -# 定义类型变量用于泛型类型提示 -T = TypeVar("T") - -# 获取logger -logger = logging.getLogger("json_utils") - - -def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: - """ - 安全地解析JSON字符串,出错时返回默认值 - 现在尝试处理单引号和标准JSON - - 参数: - json_str: 要解析的JSON字符串 - default_value: 解析失败时返回的默认值 - - 返回: - 解析后的Python对象,或在解析失败时返回default_value - """ - if not json_str or not isinstance(json_str, str): - logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}") - return default_value - - try: - # 尝试标准的 JSON 解析 - return json.loads(json_str) - except json.JSONDecodeError: - # 如果标准解析失败,尝试用 ast.literal_eval 解析 - try: - # logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...") - result = ast.literal_eval(json_str) - if isinstance(result, dict): - return result - logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}") - return default_value - except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e: - logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...") - return default_value - except Exception as e: - logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...") - return default_value - except Exception as e: - logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...") - return default_value - - -def extract_tool_call_arguments( - tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - """ - 从LLM工具调用对象中提取参数 - - 参数: - tool_call: 工具调用对象字典 - default_value: 解析失败时返回的默认值 - - 返回: - 解析后的参数字典,或在解析失败时返回default_value - """ - default_result = default_value or {} - - if not tool_call or not isinstance(tool_call, dict): - logger.error(f"无效的工具调用对象: {tool_call}") - return default_result - - try: - # 提取function参数 - function_data = tool_call.get("function", {}) - if not function_data or not isinstance(function_data, dict): - logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}") - return default_result - - if arguments_str := function_data.get("arguments", "{}"): - # 解析JSON - return safe_json_loads(arguments_str, default_result) - else: - return default_result - - except Exception as e: - logger.error(f"提取工具调用参数时出错: {e}") - return default_result - - -def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str: - """ - 安全地将Python对象序列化为JSON字符串 - - 参数: - obj: 要序列化的Python对象 - default_value: 序列化失败时返回的默认值 - ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符) - pretty: 是否美化输出JSON - - 返回: - 序列化后的JSON字符串,或在序列化失败时返回default_value - """ - try: - indent = 2 if pretty else None - return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent) - except TypeError as e: - logger.error(f"JSON序列化失败(类型错误): {e}") - return default_value - except Exception as e: - logger.error(f"JSON序列化过程中发生意外错误: {e}") - return default_value - - -def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]: - """ - 标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式 - - 参数: - response: 原始LLM响应 - log_prefix: 日志前缀 - - 返回: - 元组 (成功标志, 标准化后的响应列表, 错误消息) - """ - - logger.debug(f"{log_prefix}原始人 LLM响应: {response}") - - # 检查是否为None - if response is None: - return False, [], "LLM响应为None" - - # 记录原始类型 - logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}") - - # 将元组转换为列表 - if isinstance(response, tuple): - logger.debug(f"{log_prefix}将元组响应转换为列表") - response = list(response) - - # 确保是列表类型 - if not isinstance(response, list): - return False, [], f"无法处理的LLM响应类型: {type(response).__name__}" - - # 处理工具调用部分(如果存在) - if len(response) == 3: - content, reasoning, tool_calls = response - - # 将工具调用部分转换为列表(如果是元组) - if isinstance(tool_calls, tuple): - logger.debug(f"{log_prefix}将工具调用元组转换为列表") - tool_calls = list(tool_calls) - response[2] = tool_calls - - return True, response, "" - - -def process_llm_tool_calls( - tool_calls: List[Dict[str, Any]], log_prefix: str = "" -) -> Tuple[bool, List[Dict[str, Any]], str]: - """ - 处理并验证LLM响应中的工具调用列表 - - 参数: - tool_calls: 从LLM响应中直接获取的工具调用列表 - log_prefix: 日志前缀 - - 返回: - 元组 (成功标志, 验证后的工具调用列表, 错误消息) - """ - - # 如果列表为空,表示没有工具调用,这不是错误 - if not tool_calls: - return True, [], "工具调用列表为空" - - # 验证每个工具调用的格式 - valid_tool_calls = [] - for i, tool_call in enumerate(tool_calls): - if not isinstance(tool_call, dict): - logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}") - continue - - # 检查基本结构 - if tool_call.get("type") != "function": - logger.warning( - f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}" - ) - continue - - if "function" not in tool_call or not isinstance(tool_call.get("function"), dict): - logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}") - continue - - func_details = tool_call["function"] - if "name" not in func_details or not isinstance(func_details.get("name"), str): - logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}") - continue - - # 验证参数 'arguments' - args_value = func_details.get("arguments") - - # 1. 检查 arguments 是否存在且是字符串 - if args_value is None or not isinstance(args_value, str): - logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}") - continue - - # 2. 尝试安全地解析 arguments 字符串 - parsed_args = safe_json_loads(args_value, None) - - # 3. 检查解析结果是否为字典 - if parsed_args is None or not isinstance(parsed_args, dict): - logger.warning( - f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, " - f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}" - ) - continue - - # 如果检查通过,将原始的 tool_call 加入有效列表 - valid_tool_calls.append(tool_call) - - if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空 - return False, [], "所有工具调用格式均无效" - - return True, valid_tool_calls, "" diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 0377fb118..e04a327df 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -1,4 +1,4 @@ -raise DeprecationWarning("Genimi Client is not fully available yet.") +raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider") import asyncio import io from collections.abc import Iterable diff --git a/src/llm_models/payload_content/__init__.py b/src/llm_models/payload_content/__init__.py index e69de29bb..33e43c5ee 100644 --- a/src/llm_models/payload_content/__init__.py +++ b/src/llm_models/payload_content/__init__.py @@ -0,0 +1,3 @@ +from .tool_option import ToolCall + +__all__ = ["ToolCall"] \ No newline at end of file diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index ab3251509..679d1149f 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -11,7 +11,7 @@ from src.config.config import model_config from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message from .payload_content.resp_format import RespFormat -from .payload_content.tool_option import ToolOption, ToolCall +from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType from .model_client.base_client import BaseClient, APIResponse, client_registry from .utils import compress_messages, llm_usage_recorder from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException @@ -60,7 +60,7 @@ class LLMRequest: image_format: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]: + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ 为图像生成响应 Args: @@ -68,7 +68,7 @@ class LLMRequest: image_base64 (str): 图像的Base64编码字符串 image_format (str): 图像格式(如 'png', 'jpeg' 等) Returns: - (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表 + (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ # 请求体构建 message_builder = MessageBuilder() @@ -104,18 +104,18 @@ class LLMRequest: request_type=self.request_type, endpoint="/chat/completions", ) - return content, ( - reasoning_content, - model_info.name, - self._convert_tool_calls(tool_calls) if tool_calls else None, - ) + return content, (reasoning_content, model_info.name, tool_calls) async def generate_response_for_voice(self): pass async def generate_response_async( - self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None - ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]: + self, + prompt: str, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ 异步生成响应 Args: @@ -123,13 +123,13 @@ class LLMRequest: temperature (float, optional): 温度参数 max_tokens (int, optional): 最大token数 Returns: - (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表 + (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ # 请求体构建 message_builder = MessageBuilder() message_builder.add_text_content(prompt) messages = [message_builder.build()] - + tool_built = self._build_tool_options(tools) # 模型选择 model_info, api_provider, client = self._select_model() @@ -142,6 +142,7 @@ class LLMRequest: message_list=messages, temperature=temperature, max_tokens=max_tokens, + tool_options=tool_built, ) content = response.content reasoning_content = response.reasoning_content or "" @@ -161,11 +162,7 @@ class LLMRequest: if not content: raise RuntimeError("获取LLM生成内容失败") - return content, ( - reasoning_content, - model_info.name, - self._convert_tool_calls(tool_calls) if tool_calls else None, - ) + return content, (reasoning_content, model_info.name, tool_calls) async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: """获取嵌入向量 @@ -214,10 +211,6 @@ class LLMRequest: client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) return model_info, api_provider, client - def _convert_tool_calls(self, tool_calls: List[ToolCall]) -> List[Dict[str, Any]]: - """将ToolCall对象转换为Dict列表""" - pass - async def _execute_request( self, api_provider: APIProvider, @@ -435,6 +428,35 @@ class LLMRequest: ) return -1, None + def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + """构建工具选项列表""" + if not tools: + return None + tool_options: List[ToolOption] = [] + for tool in tools: + tool_legal = True + tool_options_builder = ToolOptionBuilder() + tool_options_builder.set_name(tool.get("name", "")) + tool_options_builder.set_description(tool.get("description", "")) + parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", []) + for param in parameters: + try: + tool_options_builder.add_param( + name=param[0], + param_type=ToolParamType(param[1]), + description=param[2], + required=param[3], + ) + except ValueError as ve: + tool_legal = False + logger.error(f"{param[1]} 参数类型错误: {str(ve)}") + except Exception as e: + tool_legal = False + logger.error(f"构建工具参数失败: {str(e)}") + if tool_legal: + tool_options.append(tool_options_builder.build()) + return tool_options or None + @staticmethod def _extract_reasoning(content: str) -> Tuple[str, str]: """CoT思维链提取,向后兼容""" diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 3e21e25a6..5b996d375 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, List, Tuple from rich.traceback import install from src.common.logger import get_logger @@ -17,8 +17,8 @@ class BaseTool(ABC): """工具的名称""" description: str = "" """工具的描述""" - parameters: Dict[str, Any] = {} - """工具的参数定义""" + parameters: List[Tuple[str, str, str, bool]] = [] + """工具的参数定义,为[("param_name", "param_type", "description", required)]""" available_for_llm: bool = False """是否可供LLM使用""" @@ -32,10 +32,7 @@ class BaseTool(ABC): if not cls.name or not cls.description or not cls.parameters: raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") - return { - "type": "function", - "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, - } + return {"name": cls.name, "description": cls.description, "parameters": cls.parameters} @classmethod def get_tool_info(cls) -> ToolInfo: @@ -79,7 +76,9 @@ class BaseTool(ABC): Returns: dict: 工具执行结果 """ - if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]): - raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}") + parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名 + for param_name in parameter_required: + if param_name not in function_args: + raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}") return await self.execute(function_args) diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index aeeccde5a..5ed75a7bb 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass, field from maim_message import Seg @@ -150,7 +150,7 @@ class CommandInfo(ComponentInfo): class ToolInfo(ComponentInfo): """工具组件信息""" - tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义 + tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义 tool_description: str = "" # 工具描述 def __post_init__(self): diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index a220161db..65cceb006 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -1,12 +1,11 @@ -import json import time from typing import List, Dict, Tuple, Optional, Any from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.llm_models.utils_model import LLMRequest +from src.llm_models.payload_content import ToolCall from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.chat.utils.json_utils import process_llm_tool_calls from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger @@ -63,7 +62,7 @@ class ToolExecutor: async def execute_from_chat_message( self, target_message: str, chat_history: str, sender: str, return_details: bool = False - ) -> Tuple[List[Dict], List[str], str]: + ) -> Tuple[List[Dict[str, Any]], List[str], str]: """从聊天消息执行工具 Args: @@ -110,15 +109,9 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}开始LLM工具调用分析") # 调用LLM进行工具决策 - response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools) - - # TODO: 在APIADA加入后完全修复这里! - # 解析LLM响应 - if len(other_info) == 3: - reasoning_content, model_name, tool_calls = other_info - else: - reasoning_content, model_name = other_info - tool_calls = None + response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( + prompt=prompt, tools=tools + ) # 执行工具调用 tool_results, used_tools = await self._execute_tool_calls(tool_calls) @@ -138,9 +131,9 @@ class ToolExecutor: def _get_tool_definitions(self) -> List[Dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) - return [parameters for name, parameters in all_tools if name not in user_disabled_tools] + return [definition for name, definition in all_tools if name not in user_disabled_tools] - async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: + async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: """执行工具调用 Args: @@ -149,32 +142,19 @@ class ToolExecutor: Returns: Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) """ - tool_results = [] + tool_results: List[Dict[str, Any]] = [] used_tools = [] if not tool_calls: logger.debug(f"{self.log_prefix}无需执行工具") - return tool_results, used_tools + return [], [] logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}") - # 处理工具调用 - success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls) - - if not success: - logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}") - return tool_results, used_tools - - if not valid_tool_calls: - logger.debug(f"{self.log_prefix}无有效工具调用") - return tool_results, used_tools - # 执行每个工具调用 - for tool_call in valid_tool_calls: + for tool_call in tool_calls: try: - tool_name = tool_call.get("name", "unknown_tool") - used_tools.append(tool_name) - + tool_name = tool_call.func_name logger.debug(f"{self.log_prefix}执行工具: {tool_name}") # 执行工具 @@ -188,15 +168,15 @@ class ToolExecutor: "tool_name": tool_name, "timestamp": time.time(), } - tool_results.append(tool_info) - - logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}") content = tool_info["content"] if not isinstance(content, (str, list, tuple)): - content = str(content) + tool_info["content"] = str(content) + + tool_results.append(tool_info) + used_tools.append(tool_name) + logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}") preview = content[:200] logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") - except Exception as e: logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") # 添加错误信息到结果中 @@ -211,7 +191,7 @@ class ToolExecutor: return tool_results, used_tools - async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]: + async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]: # sourcery skip: use-assigned-variable """执行单个工具调用 @@ -222,8 +202,8 @@ class ToolExecutor: Optional[Dict]: 工具调用结果,如果失败则返回None """ try: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) + function_name = tool_call.func_name + function_args = tool_call.args or {} function_args["llm_called"] = True # 标记为LLM调用 # 获取对应工具实例 @@ -235,20 +215,17 @@ class ToolExecutor: # 执行工具 result = await tool_instance.execute(function_args) if result: - # 直接使用 function_name 作为 tool_type - tool_type = function_name - return { - "tool_call_id": tool_call["id"], + "tool_call_id": tool_call.call_id, "role": "tool", "name": function_name, - "type": tool_type, + "type": "function", "content": result["content"], } return None except Exception as e: logger.error(f"执行工具调用时发生错误: {str(e)}") - return None + raise e def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: """生成缓存键 @@ -317,9 +294,7 @@ class ToolExecutor: if expired_keys: logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") - async def execute_specific_tool( - self, tool_name: str, tool_args: Dict, validate_args: bool = True - ) -> Optional[Dict]: + async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: """直接执行指定工具 Args: @@ -331,7 +306,11 @@ class ToolExecutor: Optional[Dict]: 工具执行结果,失败时返回None """ try: - tool_call = {"name": tool_name, "arguments": tool_args} + tool_call = ToolCall( + call_id=f"direct_tool_{time.time()}", + func_name=tool_name, + args=tool_args, + ) logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") diff --git a/src/plugins/built_in/knowledge/get_knowledge.py b/src/plugins/built_in/knowledge/get_knowledge.py index 4e662235a..54f93cddf 100644 --- a/src/plugins/built_in/knowledge/get_knowledge.py +++ b/src/plugins/built_in/knowledge/get_knowledge.py @@ -14,14 +14,10 @@ class SearchKnowledgeTool(BaseTool): name = "search_knowledge" description = "使用工具从知识库中搜索相关信息" - parameters = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索查询关键词"}, - "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, - }, - "required": ["query"], - } + parameters = [ + ("query", "string", "搜索查询关键词", True), + ("threshold", "float", "相似度阈值,0.0到1.0之间", False), + ] async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: """执行知识库搜索 diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index 0c8a32d78..ef74add92 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -14,14 +14,10 @@ class SearchKnowledgeFromLPMMTool(BaseTool): name = "lpmm_search_knowledge" description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具" - parameters = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索查询关键词"}, - "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, - }, - "required": ["query"], - } + parameters = [ + ("query", "string", "搜索查询关键词", True), + ("threshold", "float", "相似度阈值,0.0到1.0之间", False), + ] async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: """执行知识库搜索