From 6b53560a7ed2e3ae9f80e4763f15959a3015675e Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 27 Aug 2025 18:45:59 +0800 Subject: [PATCH] feat(tool_system): implement declarative caching for tools This commit refactors the tool caching system to be more robust, configurable, and easier to use. The caching logic is centralized within the `wrap_tool_executor`, removing the need for boilerplate code within individual tool implementations. Key changes: - Adds `enable_cache`, `cache_ttl`, and `semantic_cache_query_key` attributes to `BaseTool` for declarative cache configuration. - Moves caching logic from a simple history-based lookup and individual tools into a unified handling process in `wrap_tool_executor`. - The new system leverages the central `tool_cache` manager for both exact and semantic caching based on tool configuration. - Refactors `WebSurfingTool` and `URLParserTool` to utilize the new declarative caching mechanism, simplifying their code. --- docs/plugins/tool_caching_guide.md | 124 ++++++++++++++++++ src/common/tool_history.py | 102 ++++++++------ src/plugin_system/base/base_tool.py | 7 + .../web_search_tool/tools/url_parser.py | 25 ++-- .../web_search_tool/tools/web_search.py | 20 +-- 5 files changed, 208 insertions(+), 70 deletions(-) create mode 100644 docs/plugins/tool_caching_guide.md diff --git a/docs/plugins/tool_caching_guide.md b/docs/plugins/tool_caching_guide.md new file mode 100644 index 000000000..d670a9f1a --- /dev/null +++ b/docs/plugins/tool_caching_guide.md @@ -0,0 +1,124 @@ +# 自动化工具缓存系统使用指南 + +为了提升性能并减少不必要的重复计算或API调用,MMC内置了一套强大且易于使用的自动化工具缓存系统。该系统同时支持传统的**精确缓存**和先进的**语义缓存**。工具开发者无需编写任何手动缓存逻辑,只需在工具类中设置几个属性,即可轻松启用和配置缓存行为。 + +## 核心概念 + +- **精确缓存 (KV Cache)**: 当一个工具被调用时,系统会根据工具名称和所有参数生成一个唯一的键。只有当**下一次调用的工具名和所有参数与之前完全一致**时,才会命中缓存。 +- **语义缓存 (Vector Cache)**: 它不要求参数完全一致,而是理解参数的**语义和意图**。例如,`"查询深圳今天的天气"` 和 `"今天深圳天气怎么样"` 这两个不同的查询,在语义上是高度相似的。如果启用了语义缓存,第二个查询就能成功命中由第一个查询产生的缓存结果。 + +## 如何为你的工具启用缓存 + +为你的工具(必须继承自 `BaseTool`)启用缓存非常简单,只需在你的工具类定义中添加以下一个或多个属性即可: + +### 1. `enable_cache: bool` + +这是启用缓存的总开关。 + +- **类型**: `bool` +- **默认值**: `False` +- **作用**: 设置为 `True` 即可为该工具启用缓存功能。如果为 `False`,后续的所有缓存配置都将无效。 + +**示例**: +```python +class MyAwesomeTool(BaseTool): + # ... 其他定义 ... + enable_cache: bool = True +``` + +### 2. `cache_ttl: int` + +设置缓存的生存时间(Time-To-Live)。 + +- **类型**: `int` +- **单位**: 秒 +- **默认值**: `3600` (1小时) +- **作用**: 定义缓存条目在被视为过期之前可以存活多长时间。 + +**示例**: +```python +class MyLongTermCacheTool(BaseTool): + # ... 其他定义 ... + enable_cache: bool = True + cache_ttl: int = 86400 # 缓存24小时 +``` + +### 3. `semantic_cache_query_key: Optional[str]` + +启用语义缓存的关键。 + +- **类型**: `Optional[str]` +- **默认值**: `None` +- **作用**: + - 将此属性的值设置为你工具的某个**参数的名称**(字符串)。 + - 自动化缓存系统在工作时,会提取该参数的值,将其转换为向量,并进行语义相似度搜索。 + - 如果该值为 `None`,则此工具**仅使用精确缓存**。 + +**示例**: +```python +class WebSurfingTool(BaseTool): + name: str = "web_search" + parameters = [ + ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), + # ... 其他参数 ... + ] + + # --- 缓存配置 --- + enable_cache: bool = True + cache_ttl: int = 7200 # 缓存2小时 + semantic_cache_query_key: str = "query" # <-- 关键! +``` +在上面的例子中,`web_search` 工具的 `"query"` 参数值(例如,用户输入的搜索词)将被用于语义缓存搜索。 + +## 完整示例 + +假设我们有一个调用外部API来获取股票价格的工具。由于股价在短时间内相对稳定,且查询意图可能相似(如 "苹果股价" vs "AAPL股价"),因此非常适合使用缓存。 + +```python +# in your_plugin/tools/stock_checker.py + +from src.plugin_system import BaseTool, ToolParamType + +class StockCheckerTool(BaseTool): + """ + 一个用于查询股票价格的工具。 + """ + name: str = "get_stock_price" + description: str = "获取指定公司或股票代码的最新价格。" + available_for_llm: bool = True + parameters = [ + ("symbol", ToolParamType.STRING, "公司名称或股票代码 (e.g., 'AAPL', '苹果')", True, None), + ] + + # --- 缓存配置 --- + # 1. 开启缓存 + enable_cache: bool = True + # 2. 股价信息缓存10分钟 + cache_ttl: int = 600 + # 3. 使用 "symbol" 参数进行语义搜索 + semantic_cache_query_key: str = "symbol" + # -------------------- + + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + symbol = function_args.get("symbol") + + # ... 这里是你调用外部API获取股票价格的逻辑 ... + # price = await some_stock_api.get_price(symbol) + price = 123.45 # 示例价格 + + return { + "type": "stock_price_result", + "content": f"{symbol} 的当前价格是 ${price}" + } + +``` + +通过以上简单的三行配置,`StockCheckerTool` 现在就拥有了强大的自动化缓存能力: + +- 当用户查询 `"苹果"` 时,工具会执行并缓存结果。 +- 在接下来的10分钟内,如果再次查询 `"苹果"`,将直接从精确缓存返回结果。 +- 更智能的是,如果另一个用户查询 `"AAPL"`,语义缓存系统会识别出 `"AAPL"` 和 `"苹果"` 在语义上高度相关,大概率也会直接返回缓存的结果,而无需再次调用API。 + +--- + +现在,你可以专注于实现工具的核心逻辑,把缓存的复杂性交给MMC的自动化系统来处理。 \ No newline at end of file diff --git a/src/common/tool_history.py b/src/common/tool_history.py index 0f76f1a68..b3edb12ce 100644 --- a/src/common/tool_history.py +++ b/src/common/tool_history.py @@ -4,9 +4,11 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union import json from pathlib import Path +import inspect from .logger import get_logger from src.config.config import global_config +from src.common.cache_manager import tool_cache logger = get_logger("tool_history") @@ -113,34 +115,6 @@ class ToolHistoryManager: except Exception as e: logger.error(f"记录工具调用时发生错误: {e}") - def find_cached_result(self, tool_name: str, args: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """查找匹配的缓存记录 - - Args: - tool_name: 工具名称 - args: 工具调用参数 - - Returns: - Optional[Dict[str, Any]]: 如果找到匹配的缓存记录则返回结果,否则返回None - """ - # 检查是否启用历史记录 - if not global_config.tool.history.enable_history: - return None - - # 清理输入参数中的敏感信息以便比较 - sanitized_input_args = self._sanitize_args(args) - - # 按时间倒序遍历历史记录 - for record in reversed(self._history): - if (record["tool_name"] == tool_name and - record["status"] == "completed" and - record["ttl_count"] < record.get("ttl", 5)): - # 比较参数是否匹配 - if self._sanitize_args(record["arguments"]) == sanitized_input_args: - logger.info(f"工具 {tool_name} 命中缓存记录") - return record["result"] - return None - def _sanitize_args(self, args: Dict[str, Any]) -> Dict[str, Any]: """清理参数中的敏感信息""" sensitive_keys = ['api_key', 'token', 'password', 'secret'] @@ -327,27 +301,78 @@ class ToolHistoryManager: def wrap_tool_executor(): """ - 包装工具执行器以添加历史记录功能 + 包装工具执行器以添加历史记录和缓存功能 这个函数应该在系统启动时被调用一次 """ from src.plugin_system.core.tool_use import ToolExecutor + from src.plugin_system.apis.tool_api import get_tool_instance original_execute = ToolExecutor.execute_tool_call history_manager = ToolHistoryManager() async def wrapped_execute_tool_call(self, tool_call, tool_instance=None): start_time = time.time() + + # 确保我们有 tool_instance + if not tool_instance: + tool_instance = get_tool_instance(tool_call.func_name) - # 首先检查缓存 - if cached_result := history_manager.find_cached_result(tool_call.func_name, tool_call.args): - logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") - return cached_result + # 如果没有 tool_instance,就无法进行缓存检查,直接执行 + if not tool_instance: + result = await original_execute(self, tool_call, None) + execution_time = time.time() - start_time + history_manager.record_tool_call( + tool_name=tool_call.func_name, + args=tool_call.args, + result=result, + execution_time=execution_time, + status="completed", + chat_id=getattr(self, 'chat_id', None), + ttl=5 # Default TTL + ) + return result + + # 新的缓存逻辑 + if tool_instance.enable_cache: + try: + tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) + + cached_result = await tool_cache.get( + tool_name=tool_call.func_name, + function_args=tool_call.args, + tool_file_path=tool_file_path, + semantic_query=semantic_query + ) + if cached_result: + logger.info(f"{self.log_prefix}使用缓存结果,跳过工具 {tool_call.func_name} 执行") + return cached_result + except Exception as e: + logger.error(f"{self.log_prefix}检查工具缓存时出错: {e}") try: result = await original_execute(self, tool_call, tool_instance) execution_time = time.time() - start_time - # 获取工具的ttl值 - ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5 + # 缓存结果 + if tool_instance.enable_cache: + try: + tool_file_path = inspect.getfile(tool_instance.__class__) + semantic_query = None + if tool_instance.semantic_cache_query_key: + semantic_query = tool_call.args.get(tool_instance.semantic_cache_query_key) + + await tool_cache.set( + tool_name=tool_call.func_name, + function_args=tool_call.args, + tool_file_path=tool_file_path, + data=result, + ttl=tool_instance.cache_ttl, + semantic_query=semantic_query + ) + except Exception as e: + logger.error(f"{self.log_prefix}设置工具缓存时出错: {e}") # 记录成功的调用 history_manager.record_tool_call( @@ -357,16 +382,13 @@ def wrap_tool_executor(): execution_time=execution_time, status="completed", chat_id=getattr(self, 'chat_id', None), - ttl=ttl + ttl=tool_instance.history_ttl ) return result except Exception as e: execution_time = time.time() - start_time - # 获取工具的ttl值 - ttl = getattr(tool_instance, 'history_ttl', 5) if tool_instance else 5 - # 记录失败的调用 history_manager.record_tool_call( tool_name=tool_call.func_name, @@ -375,7 +397,7 @@ def wrap_tool_executor(): execution_time=execution_time, status="error", chat_id=getattr(self, 'chat_id', None), - ttl=ttl + ttl=tool_instance.history_ttl ) raise diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 4d7e6280d..b5022ea2a 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -31,6 +31,13 @@ class BaseTool(ABC): history_ttl: int = 5 """工具调用历史记录的TTL值,默认为5。设为0表示不记录历史""" + enable_cache: bool = False + """是否为该工具启用缓存""" + cache_ttl: int = 3600 + """缓存的TTL值(秒),默认为3600秒(1小时)""" + semantic_cache_query_key: Optional[str] = None + """用于语义缓存的查询参数键名。如果设置,将使用此参数的值进行语义相似度搜索""" + def __init__(self, plugin_config: Optional[dict] = None): self.plugin_config = plugin_config or {} # 直接存储插件配置字典 diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 315e06271..c9a7e9fd9 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -30,6 +30,12 @@ class URLParserTool(BaseTool): parameters = [ ("urls", ToolParamType.STRING, "要理解的网站", True, None), ] + + # --- 新的缓存配置 --- + enable_cache: bool = True + cache_ttl: int = 86400 # 缓存24小时 + semantic_cache_query_key: str = "urls" + # -------------------- def __init__(self, plugin_config=None): super().__init__(plugin_config) @@ -42,10 +48,11 @@ class URLParserTool(BaseTool): if exa_api_keys is None: # 从插件配置文件读取 exa_api_keys = self.get_config("exa.api_keys", []) - + # 创建API密钥管理器 + from typing import cast, List self.api_manager = create_api_key_manager_from_config( - exa_api_keys, + cast(List[str], exa_api_keys), lambda key: Exa(api_key=key), "Exa URL Parser" ) @@ -135,16 +142,6 @@ class URLParserTool(BaseTool): """ 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ - # 获取当前文件路径用于缓存键 - import os - current_file_path = os.path.abspath(__file__) - - # 检查缓存 - cached_result = await tool_cache.get(self.name, function_args, current_file_path) - if cached_result: - logger.info(f"缓存命中: {self.name} -> {function_args}") - return cached_result - urls_input = function_args.get("urls") if not urls_input: return {"error": "URL列表不能为空。"} @@ -235,8 +232,4 @@ class URLParserTool(BaseTool): "errors": error_messages } - # 保存到缓存 - if "error" not in result: - await tool_cache.set(self.name, function_args, current_file_path, result) - return result diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py index c09ad5e92..149965d06 100644 --- a/src/plugins/built_in/web_search_tool/tools/web_search.py +++ b/src/plugins/built_in/web_search_tool/tools/web_search.py @@ -31,6 +31,12 @@ class WebSurfingTool(BaseTool): ("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"]) ] # type: ignore + # --- 新的缓存配置 --- + enable_cache: bool = True + cache_ttl: int = 7200 # 缓存2小时 + semantic_cache_query_key: str = "query" + # -------------------- + def __init__(self, plugin_config=None): super().__init__(plugin_config) # 初始化搜索引擎 @@ -46,16 +52,6 @@ class WebSurfingTool(BaseTool): if not query: return {"error": "搜索查询不能为空。"} - # 获取当前文件路径用于缓存键 - import os - current_file_path = os.path.abspath(__file__) - - # 检查缓存 - cached_result = await tool_cache.get(self.name, function_args, current_file_path, semantic_query=query) - if cached_result: - logger.info(f"缓存命中: {self.name} -> {function_args}") - return cached_result - # 读取搜索配置 enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) search_strategy = config_api.get_global_config("web_search.search_strategy", "single") @@ -69,10 +65,6 @@ class WebSurfingTool(BaseTool): result = await self._execute_fallback_search(function_args, enabled_engines) else: # single result = await self._execute_single_search(function_args, enabled_engines) - - # 保存到缓存 - if "error" not in result: - await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query) return result