feat(tool_system): implement declarative caching for tools

This commit refactors the tool caching system to be more robust, configurable, and easier to use. The caching logic is centralized within the `wrap_tool_executor`, removing the need for boilerplate code within individual tool implementations.

Key changes:
- Adds `enable_cache`, `cache_ttl`, and `semantic_cache_query_key` attributes to `BaseTool` for declarative cache configuration.
- Moves caching logic from a simple history-based lookup and individual tools into a unified handling process in `wrap_tool_executor`.
- The new system leverages the central `tool_cache` manager for both exact and semantic caching based on tool configuration.
- Refactors `WebSurfingTool` and `URLParserTool` to utilize the new declarative caching mechanism, simplifying their code.
This commit is contained in:
minecraft1024a
2025-08-27 18:45:59 +08:00
parent 1fd1c76a84
commit af17290595
5 changed files with 208 additions and 70 deletions

View File

@@ -4,9 +4,11 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Union
import json
from pathlib import Path
import inspect
from .logger import get_logger
from src.config.config import global_config
from src.common.cache_manager import tool_cache
logger = get_logger("tool_history")
@@ -113,34 +115,6 @@ class ToolHistoryManager:
except Exception as e:
logger.error(f"记录工具调用时发生错误: {e}")
def find_cached_result(self, tool_name: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""查找匹配的缓存记录
Args:
tool_name: 工具名称
args: 工具调用参数
Returns:
Optional[Dict[str, Any]]: 如果找到匹配的缓存记录则返回结果否则返回None
"""
# 检查是否启用历史记录
if not global_config.tool.history.enable_history:
return None
# 清理输入参数中的敏感信息以便比较
sanitized_input_args = self._sanitize_args(args)
# 按时间倒序遍历历史记录
for record in reversed(self._history):
if (record["tool_name"] == tool_name and
record["status"] == "completed" and
record["ttl_count"] < record.get("ttl", 5)):
# 比较参数是否匹配
if self._sanitize_args(record["arguments"]) == sanitized_input_args:
logger.info(f"工具 {tool_name} 命中缓存记录")
return record["result"]
return None
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""清理参数中的敏感信息"""
sensitive_keys = ['api_key', 'token', 'password', 'secret']
@@ -327,27 +301,78 @@ class ToolHistoryManager:
def wrap_tool_executor():
"""
包装工具执行器以添加历史记录功能
包装工具执行器以添加历史记录和缓存功能
这个函数应该在系统启动时被调用一次
"""
from src.plugin_system.core.tool_use import ToolExecutor
from src.plugin_system.apis.tool_api import get_tool_instance
original_execute = ToolExecutor.execute_tool_call
history_manager = ToolHistoryManager()
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
start_time = time.time()
# 确保我们有 tool_instance
if not tool_instance:
tool_instance = get_tool_instance(tool_call.func_name)
# 首先检查缓存
if cached_result := history_manager.find_cached_result(tool_call.func_name, tool_call.args):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
# 如果没有 tool_instance就无法进行缓存检查直接执行
if not tool_instance:
result = await original_execute(self, tool_call, None)
execution_time = time.time() - start_time
history_manager.record_tool_call(
tool_name=tool_call.func_name,
args=tool_call.args,
result=result,
execution_time=execution_time,
status="completed",
chat_id=getattr(self, 'chat_id', None),
ttl=5 # Default TTL
)
return result
# 新的缓存逻辑
if tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
try:
result = await original_execute(self, tool_call, tool_instance)
execution_time = time.time() - start_time
# 获取工具的ttl值
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5
# 缓存结果
if tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# 记录成功的调用
history_manager.record_tool_call(
@@ -357,16 +382,13 @@ def wrap_tool_executor():
execution_time=execution_time,
status="completed",
chat_id=getattr(self, 'chat_id', None),
ttl=ttl
ttl=tool_instance.history_ttl
)
return result
except Exception as e:
execution_time = time.time() - start_time
# 获取工具的ttl值
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5
# 记录失败的调用
history_manager.record_tool_call(
tool_name=tool_call.func_name,
@@ -375,7 +397,7 @@ def wrap_tool_executor():
execution_time=execution_time,
status="error",
chat_id=getattr(self, 'chat_id', None),
ttl=ttl
ttl=tool_instance.history_ttl
)
raise