tools system

This commit is contained in:
UnCLAS-Prommer
2025-07-31 11:41:15 +08:00
parent 483c8fb547
commit 37e52a1566
10 changed files with 95 additions and 323 deletions

View File

@@ -412,7 +412,7 @@ class DefaultReplyer:
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "info")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
@@ -848,7 +848,7 @@ class DefaultReplyer:
raw_reply: str,
reason: str,
reply_to: str,
- ) -> str: # sourcery skip: remove-redundant-if
+ ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)

View File

@@ -1,223 +0,0 @@
import ast
import json
import logging
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
# Type variable for generic type hints
T = TypeVar("T")
# Get the module logger
logger = logging.getLogger("json_utils")
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
"""
安全地解析JSON字符串出错时返回默认值
现在尝试处理单引号和标准JSON
参数:
json_str: 要解析的JSON字符串
default_value: 解析失败时返回的默认值
返回:
解析后的Python对象或在解析失败时返回default_value
"""
if not json_str or not isinstance(json_str, str):
logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
return default_value
try:
# Try standard JSON parsing first
return json.loads(json_str)
except json.JSONDecodeError:
# If standard parsing fails, fall back to ast.literal_eval
try:
# logger.debug(f"Standard JSON parsing failed, trying ast.literal_eval: {json_str[:100]}...")
result = ast.literal_eval(json_str)
if isinstance(result, dict):
return result
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
return default_value
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
logger.error(f"Parsing with ast.literal_eval failed: {ast_e}, string: {json_str[:100]}...")
return default_value
except Exception as e:
logger.error(f"Unexpected error while parsing with ast.literal_eval: {e}, string: {json_str[:100]}...")
return default_value
except Exception as e:
logger.error(f"Unexpected error during JSON parsing: {e}, string: {json_str[:100]}...")
return default_value
def extract_tool_call_arguments(
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
从LLM工具调用对象中提取参数
参数:
tool_call: 工具调用对象字典
default_value: 解析失败时返回的默认值
返回:
解析后的参数字典或在解析失败时返回default_value
"""
default_result = default_value or {}
if not tool_call or not isinstance(tool_call, dict):
logger.error(f"无效的工具调用对象: {tool_call}")
return default_result
try:
# Extract the function payload
function_data = tool_call.get("function", {})
if not function_data or not isinstance(function_data, dict):
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result
if arguments_str := function_data.get("arguments", "{}"):
# Parse the arguments JSON
return safe_json_loads(arguments_str, default_result)
else:
return default_result
except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}")
return default_result
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
"""
安全地将Python对象序列化为JSON字符串
参数:
obj: 要序列化的Python对象
default_value: 序列化失败时返回的默认值
ensure_ascii: 是否确保ASCII编码(默认False允许中文等非ASCII字符)
pretty: 是否美化输出JSON
返回:
序列化后的JSON字符串或在序列化失败时返回default_value
"""
try:
indent = 2 if pretty else None
return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
except TypeError as e:
logger.error(f"JSON序列化失败(类型错误): {e}")
return default_value
except Exception as e:
logger.error(f"JSON序列化过程中发生意外错误: {e}")
return default_value
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
"""
标准化LLM响应格式将各种格式如元组转换为统一的列表格式
参数:
response: 原始LLM响应
log_prefix: 日志前缀
返回:
元组 (成功标志, 标准化后的响应列表, 错误消息)
"""
logger.debug(f"{log_prefix}原始人 LLM响应: {response}")
# 检查是否为None
if response is None:
return False, [], "LLM响应为None"
# 记录原始类型
logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
# 将元组转换为列表
if isinstance(response, tuple):
logger.debug(f"{log_prefix}将元组响应转换为列表")
response = list(response)
# Ensure we have a list
if not isinstance(response, list):
return False, [], f"Unhandled LLM response type: {type(response).__name__}"
# Handle the tool-call portion, if present
if len(response) == 3:
content, reasoning, tool_calls = response
# Convert the tool-call portion to a list if it is a tuple
if isinstance(tool_calls, tuple):
logger.debug(f"{log_prefix}将工具调用元组转换为列表")
tool_calls = list(tool_calls)
response[2] = tool_calls
return True, response, ""
def process_llm_tool_calls(
tool_calls: List[Dict[str, Any]], log_prefix: str = ""
) -> Tuple[bool, List[Dict[str, Any]], str]:
"""
处理并验证LLM响应中的工具调用列表
参数:
tool_calls: 从LLM响应中直接获取的工具调用列表
log_prefix: 日志前缀
返回:
元组 (成功标志, 验证后的工具调用列表, 错误消息)
"""
# 如果列表为空,表示没有工具调用,这不是错误
if not tool_calls:
return True, [], "工具调用列表为空"
# 验证每个工具调用的格式
valid_tool_calls = []
for i, tool_call in enumerate(tool_calls):
if not isinstance(tool_call, dict):
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
continue
# Check the basic structure
if tool_call.get("type") != "function":
logger.warning(
f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
)
continue
if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
continue
func_details = tool_call["function"]
if "name" not in func_details or not isinstance(func_details.get("name"), str):
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
continue
# Validate the 'arguments' field
args_value = func_details.get("arguments")
# 1. Check that arguments exists and is a string
if args_value is None or not isinstance(args_value, str):
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
continue
# 2. Try to safely parse the arguments string
parsed_args = safe_json_loads(args_value, None)
# 3. Check that the parsed result is a dict
if parsed_args is None or not isinstance(parsed_args, dict):
logger.warning(
f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
)
continue
# Checks passed; add the original tool_call to the valid list
valid_tool_calls.append(tool_call)
if not valid_tool_calls and tool_calls: # Original list was non-empty, but nothing survived validation
return False, [], "All tool calls had an invalid format"
return True, valid_tool_calls, ""

View File

@@ -1,4 +1,4 @@
raise DeprecationWarning("Genimi Client is not fully available yet.")
raise DeprecationWarning("Genimi Client is not fully available yet. Please remove your Gemini API Provider")
import asyncio
import io
from collections.abc import Iterable

View File

@@ -0,0 +1,3 @@
from .tool_option import ToolCall
__all__ = ["ToolCall"]

View File

@@ -11,7 +11,7 @@ from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
- from .payload_content.tool_option import ToolOption, ToolCall
+ from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
from .model_client.base_client import BaseClient, APIResponse, client_registry
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
@@ -60,7 +60,7 @@ class LLMRequest:
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
- ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
+ ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
Generate a response for an image
Args:
@@ -68,7 +68,7 @@ class LLMRequest:
image_base64 (str): Base64-encoded image string
image_format (str): image format (e.g. 'png', 'jpeg')
Returns:
- (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool-call list
+ (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, tool-call list
"""
# Build the request payload
message_builder = MessageBuilder()
@@ -104,18 +104,18 @@ class LLMRequest:
request_type=self.request_type,
endpoint="/chat/completions",
)
- return content, (
- reasoning_content,
- model_info.name,
- self._convert_tool_calls(tool_calls) if tool_calls else None,
- )
+ return content, (reasoning_content, model_info.name, tool_calls)
async def generate_response_for_voice(self):
pass
async def generate_response_async(
- self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
- ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
+ self,
+ prompt: str,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
Generate a response asynchronously
Args:
@@ -123,13 +123,13 @@ class LLMRequest:
temperature (float, optional): temperature parameter
max_tokens (int, optional): maximum number of tokens
Returns:
- (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): response content, reasoning content, model name, tool-call list
+ (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, tool-call list
"""
# Build the request payload
message_builder = MessageBuilder()
message_builder.add_text_content(prompt)
messages = [message_builder.build()]
+ tool_built = self._build_tool_options(tools)
# Model selection
model_info, api_provider, client = self._select_model()
@@ -142,6 +142,7 @@ class LLMRequest:
message_list=messages,
temperature=temperature,
max_tokens=max_tokens,
+ tool_options=tool_built,
)
content = response.content
reasoning_content = response.reasoning_content or ""
@@ -161,11 +162,7 @@ class LLMRequest:
if not content:
raise RuntimeError("获取LLM生成内容失败")
- return content, (
- reasoning_content,
- model_info.name,
- self._convert_tool_calls(tool_calls) if tool_calls else None,
- )
+ return content, (reasoning_content, model_info.name, tool_calls)
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
@@ -214,10 +211,6 @@ class LLMRequest:
client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
return model_info, api_provider, client
- def _convert_tool_calls(self, tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
- """Convert ToolCall objects into a list of dicts"""
- pass
async def _execute_request(
self,
api_provider: APIProvider,
@@ -435,6 +428,35 @@ class LLMRequest:
)
return -1, None
+ def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
+ """Build the tool-option list"""
+ if not tools:
+ return None
+ tool_options: List[ToolOption] = []
+ for tool in tools:
+ tool_legal = True
+ tool_options_builder = ToolOptionBuilder()
+ tool_options_builder.set_name(tool.get("name", ""))
+ tool_options_builder.set_description(tool.get("description", ""))
+ parameters: List[Tuple[str, str, str, bool]] = tool.get("parameters", [])
+ for param in parameters:
+ try:
+ tool_options_builder.add_param(
+ name=param[0],
+ param_type=ToolParamType(param[1]),
+ description=param[2],
+ required=param[3],
+ )
+ except ValueError as ve:
+ tool_legal = False
+ logger.error(f"Invalid parameter type {param[1]}: {str(ve)}")
+ except Exception as e:
+ tool_legal = False
+ logger.error(f"Failed to build tool parameter: {str(e)}")
+ if tool_legal:
+ tool_options.append(tool_options_builder.build())
+ return tool_options or None
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
"""CoT思维链提取向后兼容"""

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
- from typing import Any, Dict
+ from typing import Any, List, Tuple
from rich.traceback import install
from src.common.logger import get_logger
@@ -17,8 +17,8 @@ class BaseTool(ABC):
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: Dict[str, Any] = {}
"""工具的参数定义"""
parameters: List[Tuple[str, str, str, bool]] = []
"""工具的参数定义,为[("param_name", "param_type", "description", required)]"""
available_for_llm: bool = False
"""是否可供LLM使用"""
@@ -32,10 +32,7 @@ class BaseTool(ABC):
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
- return {
- "type": "function",
- "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
- }
+ return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
@classmethod
def get_tool_info(cls) -> ToolInfo:
@@ -79,7 +76,9 @@ class BaseTool(ABC):
Returns:
dict: tool execution result
"""
- if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
- raise ValueError(f"Tool class {self.__class__.__name__} is missing required parameters: {', '.join(missing)}")
+ parameter_required = [param[0] for param in self.parameters if param[3]] # Collect all required parameter names
+ for param_name in parameter_required:
+ if param_name not in function_args:
+ raise ValueError(f"Tool class {self.__class__.__name__} is missing required parameter: {param_name}")
return await self.execute(function_args)
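
To illustrate the new scheme end to end, a minimal sketch of a tool subclass under the tuple-based parameters; the tool itself is hypothetical, while the attribute names and tuple layout follow the diff:

class EchoTool(BaseTool):
    """Hypothetical tool illustrating the tuple-based parameter scheme."""
    name = "echo"
    description = "Echo the given text back to the caller"
    # [(param_name, param_type, description, required)]
    parameters = [
        ("text", "string", "text to echo back", True),
    ]
    available_for_llm = True

    async def execute(self, function_args: dict) -> dict:
        # Required-parameter validation happens in the caller before execute()
        return {"content": function_args["text"]}

# get_tool_definition() now returns the flat form consumed by _build_tool_options:
# {"name": "echo", "description": "...", "parameters": [("text", "string", "text to echo back", True)]}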

View File

@@ -1,5 +1,5 @@
from enum import Enum
- from typing import Dict, Any, List, Optional
+ from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from maim_message import Seg
@@ -150,7 +150,7 @@ class CommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
tool_parameters: List[Tuple[str, str, str, bool]] = field(default_factory=list) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):

View File

@@ -1,12 +1,11 @@
import json
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
+ from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
- from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
@@ -63,7 +62,7 @@ class ToolExecutor:
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
- ) -> Tuple[List[Dict], List[str], str]:
+ ) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具
Args:
@@ -110,15 +109,9 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
- response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
- # TODO: fully fix this once APIADA lands
- # Parse the LLM response
- if len(other_info) == 3:
- reasoning_content, model_name, tool_calls = other_info
- else:
- reasoning_content, model_name = other_info
- tool_calls = None
+ response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
+ prompt=prompt, tools=tools
+ )
# Execute the tool calls
tool_results, used_tools = await self._execute_tool_calls(tool_calls)
@@ -138,9 +131,9 @@ class ToolExecutor:
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
- return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
+ return [definition for name, definition in all_tools if name not in user_disabled_tools]
- async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
+ async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
Args:
@@ -149,32 +142,19 @@ class ToolExecutor:
Returns:
Tuple[List[Dict], List[str]]: (tool execution results, names of tools used)
"""
- tool_results = []
+ tool_results: List[Dict[str, Any]] = []
used_tools = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
return tool_results, used_tools
return [], []
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
- # Process the tool calls
- success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
- if not success:
- logger.error(f"{self.log_prefix}Failed to parse tool calls: {error_msg}")
- return tool_results, used_tools
- if not valid_tool_calls:
- logger.debug(f"{self.log_prefix}No valid tool calls")
- return tool_results, used_tools
# Execute each tool call
- for tool_call in valid_tool_calls:
+ for tool_call in tool_calls:
try:
tool_name = tool_call.get("name", "unknown_tool")
used_tools.append(tool_name)
tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
@@ -188,15 +168,15 @@ class ToolExecutor:
"tool_name": tool_name,
"timestamp": time.time(),
}
- tool_results.append(tool_info)
- logger.info(f"{self.log_prefix}Tool {tool_name} executed successfully, type: {tool_info['type']}")
+ content = tool_info["content"]
+ if not isinstance(content, (str, list, tuple)):
+ content = str(content)
+ tool_info["content"] = str(content)
+ tool_results.append(tool_info)
+ used_tools.append(tool_name)
+ logger.info(f"{self.log_prefix}Tool {tool_name} executed successfully, type: {tool_info['type']}")
+ preview = content[:200]
+ logger.debug(f"{self.log_prefix}Tool {tool_name} result content: {preview}...")
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
@@ -211,7 +191,7 @@ class ToolExecutor:
return tool_results, used_tools
- async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
+ async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
@@ -222,8 +202,8 @@ class ToolExecutor:
Optional[Dict]: tool-call result; returns None on failure
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_name = tool_call.func_name
function_args = tool_call.args or {}
function_args["llm_called"] = True # 标记为LLM调用
# Get the corresponding tool instance
@@ -235,20 +215,17 @@ class ToolExecutor:
# Execute the tool
result = await tool_instance.execute(function_args)
if result:
- # Use function_name directly as the tool_type
- tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": tool_type,
"type": "function",
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
raise e
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
@@ -317,9 +294,7 @@ class ToolExecutor:
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
- async def execute_specific_tool(
- self, tool_name: str, tool_args: Dict, validate_args: bool = True
- ) -> Optional[Dict]:
+ async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具
Args:
@@ -331,7 +306,11 @@ class ToolExecutor:
Optional[Dict]: tool execution result; returns None on failure
"""
try:
tool_call = {"name": tool_name, "arguments": tool_args}
tool_call = ToolCall(
call_id=f"direct_tool_{time.time()}",
func_name=tool_name,
args=tool_args,
)
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")

View File

@@ -14,14 +14,10 @@ class SearchKnowledgeTool(BaseTool):
name = "search_knowledge"
description = "使用工具从知识库中搜索相关信息"
- parameters = {
- "type": "object",
- "properties": {
- "query": {"type": "string", "description": "search query keywords"},
- "threshold": {"type": "number", "description": "similarity threshold, between 0.0 and 1.0"},
- },
- "required": ["query"],
- }
+ parameters = [
+ ("query", "string", "search query keywords", True),
+ ("threshold", "float", "similarity threshold, between 0.0 and 1.0", False),
+ ]
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行知识库搜索

View File

@@ -14,14 +14,10 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
name = "lpmm_search_knowledge"
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
- parameters = {
- "type": "object",
- "properties": {
- "query": {"type": "string", "description": "search query keywords"},
- "threshold": {"type": "number", "description": "similarity threshold, between 0.0 and 1.0"},
- },
- "required": ["query"],
- }
+ parameters = [
+ ("query", "string", "search query keywords", True),
+ ("threshold", "float", "similarity threshold, between 0.0 and 1.0", False),
+ ]
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
"""执行知识库搜索