feat(tool_system): implement declarative caching for tools

This commit refactors the tool caching system to be more robust, configurable, and easier to use. The caching logic is centralized within the `wrap_tool_executor`, removing the need for boilerplate code within individual tool implementations.

Key changes:
- Adds `enable_cache`, `cache_ttl`, and `semantic_cache_query_key` attributes to `BaseTool` for declarative cache configuration.
- Moves caching logic from a simple history-based lookup and individual tools into a unified handling process in `wrap_tool_executor`.
- The new system leverages the central `tool_cache` manager for both exact and semantic caching based on tool configuration.
- Refactors `WebSurfingTool` and `URLParserTool` to utilize the new declarative caching mechanism, simplifying their code.
This commit is contained in:
minecraft1024a
2025-08-27 18:45:59 +08:00
committed by Windpicker-owo
parent 12bcde800e
commit 6b53560a7e
5 changed files with 208 additions and 70 deletions

View File

@@ -0,0 +1,124 @@
# 自动化工具缓存系统使用指南
为了提升性能并减少不必要的重复计算或API调用MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。
## 核心概念
- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。
- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"``"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。
## 如何为你的工具启用缓存
为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可:
### 1. `enable_cache: bool`
这是启用缓存的总开关。
- **类型**: `bool`
- **默认值**: `False`
- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。
**示例**:
```python
class MyAwesomeTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
```
### 2. `cache_ttl: int`
设置缓存的生存时间Time-To-Live
- **类型**: `int`
- **单位**: 秒
- **默认值**: `3600` (1小时)
- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。
**示例**:
```python
class MyLongTermCacheTool(BaseTool):
# ... 其他定义 ...
enable_cache: bool = True
cache_ttl: int = 86400 # 缓存24小时
```
### 3. `semantic_cache_query_key: Optional[str]`
启用语义缓存的关键。
- **类型**: `Optional[str]`
- **默认值**: `None`
- **作用**:
- 将此属性的值设置为你工具的某个**参数的名称**(字符串)。
- 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。
- 如果该值为 `None`,则此工具**仅使用精确缓存**。
**示例**:
```python
class WebSurfingTool(BaseTool):
name: str = "web_search"
parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
# ... 其他参数 ...
]
# --- 缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 7200 # 缓存2小时
semantic_cache_query_key: str = "query" # <-- 关键!
```
在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。
## 完整示例
假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定且查询意图可能相似如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。
```python
# in your_plugin/tools/stock_checker.py
from src.plugin_system import BaseTool, ToolParamType
class StockCheckerTool(BaseTool):
"""
一个用于查询股票价格的工具。
"""
name: str = "get_stock_price"
description: str = "获取指定公司或股票代码的最新价格。"
available_for_llm: bool = True
parameters = [
("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None),
]
# --- 缓存配置 ---
# 1. 开启缓存
enable_cache: bool = True
# 2. 股价信息缓存10分钟
cache_ttl: int = 600
# 3. 使用 "symbol" 参数进行语义搜索
semantic_cache_query_key: str = "symbol"
# --------------------
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
symbol = function_args.get("symbol")
# ... 这里是你调用外部API获取股票价格的逻辑 ...
# price = await some_stock_api.get_price(symbol)
price = 123.45 # 示例价格
return {
"type": "stock_price_result",
"content": f"{symbol} 的当前价格是 ${price}"
}
```
通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力:
- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。
- 在接下来的10分钟内如果再次查询 `"苹果"`,将直接从精确缓存返回结果。
- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"``"苹果"` 在语义上高度相关大概率也会直接返回缓存的结果而无需再次调用API。
---
现在你可以专注于实现工具的核心逻辑把缓存的复杂性交给MMC的自动化系统来处理。

View File

@@ -4,9 +4,11 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import json import json
from pathlib import Path from pathlib import Path
import inspect
from .logger import get_logger from .logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.common.cache_manager import tool_cache
logger = get_logger("tool_history") logger = get_logger("tool_history")
@@ -113,34 +115,6 @@ class ToolHistoryManager:
except Exception as e: except Exception as e:
logger.error(f"记录工具调用时发生错误: {e}") logger.error(f"记录工具调用时发生错误: {e}")
def find_cached_result(self, tool_name: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""查找匹配的缓存记录
Args:
tool_name: 工具名称
args: 工具调用参数
Returns:
Optional[Dict[str, Any]]: 如果找到匹配的缓存记录则返回结果否则返回None
"""
# 检查是否启用历史记录
if not global_config.tool.history.enable_history:
return None
# 清理输入参数中的敏感信息以便比较
sanitized_input_args = self._sanitize_args(args)
# 按时间倒序遍历历史记录
for record in reversed(self._history):
if (record["tool_name"] == tool_name and
record["status"] == "completed" and
record["ttl_count"] < record.get("ttl", 5)):
# 比较参数是否匹配
if self._sanitize_args(record["arguments"]) == sanitized_input_args:
logger.info(f"工具 {tool_name} 命中缓存记录")
return record["result"]
return None
def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]: def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""清理参数中的敏感信息""" """清理参数中的敏感信息"""
sensitive_keys = ['api_key', 'token', 'password', 'secret'] sensitive_keys = ['api_key', 'token', 'password', 'secret']
@@ -327,27 +301,78 @@ class ToolHistoryManager:
def wrap_tool_executor(): def wrap_tool_executor():
""" """
包装工具执行器以添加历史记录功能 包装工具执行器以添加历史记录和缓存功能
这个函数应该在系统启动时被调用一次 这个函数应该在系统启动时被调用一次
""" """
from src.plugin_system.core.tool_use import ToolExecutor from src.plugin_system.core.tool_use import ToolExecutor
from src.plugin_system.apis.tool_api import get_tool_instance
original_execute = ToolExecutor.execute_tool_call original_execute = ToolExecutor.execute_tool_call
history_manager = ToolHistoryManager() history_manager = ToolHistoryManager()
async def wrapped_execute_tool_call(self, tool_call, tool_instance=None): async def wrapped_execute_tool_call(self, tool_call, tool_instance=None):
start_time = time.time() start_time = time.time()
# 首先检查缓存 # 确保我们有 tool_instance
if cached_result := history_manager.find_cached_result(tool_call.func_name, tool_call.args): if not tool_instance:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") tool_instance = get_tool_instance(tool_call.func_name)
return cached_result
# 如果没有 tool_instance就无法进行缓存检查直接执行
if not tool_instance:
result = await original_execute(self, tool_call, None)
execution_time = time.time() - start_time
history_manager.record_tool_call(
tool_name=tool_call.func_name,
args=tool_call.args,
result=result,
execution_time=execution_time,
status="completed",
chat_id=getattr(self, 'chat_id', None),
ttl=5 # Default TTL
)
return result
# 新的缓存逻辑
if tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
cached_result = await tool_cache.get(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
semantic_query=semantic_query
)
if cached_result:
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行")
return cached_result
except Exception as e:
logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}")
try: try:
result = await original_execute(self, tool_call, tool_instance) result = await original_execute(self, tool_call, tool_instance)
execution_time = time.time() - start_time execution_time = time.time() - start_time
# 获取工具的ttl值 # 缓存结果
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5 if tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key)
await tool_cache.set(
tool_name=tool_call.func_name,
function_args=tool_call.args,
tool_file_path=tool_file_path,
data=result,
ttl=tool_instance.cache_ttl,
semantic_query=semantic_query
)
except Exception as e:
logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}")
# 记录成功的调用 # 记录成功的调用
history_manager.record_tool_call( history_manager.record_tool_call(
@@ -357,16 +382,13 @@ def wrap_tool_executor():
execution_time=execution_time, execution_time=execution_time,
status="completed", status="completed",
chat_id=getattr(self, 'chat_id', None), chat_id=getattr(self, 'chat_id', None),
ttl=ttl ttl=tool_instance.history_ttl
) )
return result return result
except Exception as e: except Exception as e:
execution_time = time.time() - start_time execution_time = time.time() - start_time
# 获取工具的ttl值
ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5
# 记录失败的调用 # 记录失败的调用
history_manager.record_tool_call( history_manager.record_tool_call(
tool_name=tool_call.func_name, tool_name=tool_call.func_name,
@@ -375,7 +397,7 @@ def wrap_tool_executor():
execution_time=execution_time, execution_time=execution_time,
status="error", status="error",
chat_id=getattr(self, 'chat_id', None), chat_id=getattr(self, 'chat_id', None),
ttl=ttl ttl=tool_instance.history_ttl
) )
raise raise

View File

@@ -31,6 +31,13 @@ class BaseTool(ABC):
history_ttl: int = 5 history_ttl: int = 5
"""工具调用历史记录的TTL值默认为5。设为0表示不记录历史""" """工具调用历史记录的TTL值默认为5。设为0表示不记录历史"""
enable_cache: bool = False
"""是否为该工具启用缓存"""
cache_ttl: int = 3600
"""缓存的TTL值默认为3600秒1小时"""
semantic_cache_query_key: Optional[str] = None
"""用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索"""
def __init__(self, plugin_config: Optional[dict] = None): def __init__(self, plugin_config: Optional[dict] = None):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.plugin_config = plugin_config or {} # 直接存储插件配置字典

View File

@@ -31,6 +31,12 @@ class URLParserTool(BaseTool):
("urls", ToolParamType.STRING, "要理解的网站", True, None), ("urls", ToolParamType.STRING, "要理解的网站", True, None),
] ]
# --- 新的缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 86400 # 缓存24小时
semantic_cache_query_key: str = "urls"
# --------------------
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
self._initialize_exa_clients() self._initialize_exa_clients()
@@ -44,8 +50,9 @@ class URLParserTool(BaseTool):
exa_api_keys = self.get_config("exa.api_keys", []) exa_api_keys = self.get_config("exa.api_keys", [])
# 创建API密钥管理器 # 创建API密钥管理器
from typing import cast, List
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(
exa_api_keys, cast(List[str], exa_api_keys),
lambda key: Exa(api_key=key), lambda key: Exa(api_key=key),
"Exa URL Parser" "Exa URL Parser"
) )
@@ -135,16 +142,6 @@ class URLParserTool(BaseTool):
""" """
执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。 执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。
""" """
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
cached_result = await tool_cache.get(self.name, function_args, current_file_path)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
urls_input = function_args.get("urls") urls_input = function_args.get("urls")
if not urls_input: if not urls_input:
return {"error": "URL列表不能为空。"} return {"error": "URL列表不能为空。"}
@@ -235,8 +232,4 @@ class URLParserTool(BaseTool):
"errors": error_messages "errors": error_messages
} }
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result)
return result return result

View File

@@ -31,6 +31,12 @@ class WebSurfingTool(BaseTool):
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"]) ("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"])
] # type: ignore ] # type: ignore
# --- 新的缓存配置 ---
enable_cache: bool = True
cache_ttl: int = 7200 # 缓存2小时
semantic_cache_query_key: str = "query"
# --------------------
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
# 初始化搜索引擎 # 初始化搜索引擎
@@ -46,16 +52,6 @@ class WebSurfingTool(BaseTool):
if not query: if not query:
return {"error": "搜索查询不能为空。"} return {"error": "搜索查询不能为空。"}
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
cached_result = await tool_cache.get(self.name, function_args, current_file_path, semantic_query=query)
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
# 读取搜索配置 # 读取搜索配置
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
search_strategy = config_api.get_global_config("web_search.search_strategy", "single") search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
@@ -70,10 +66,6 @@ class WebSurfingTool(BaseTool):
else: # single else: # single
result = await self._execute_single_search(function_args, enabled_engines) result = await self._execute_single_search(function_args, enabled_engines)
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
return result return result
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: