From 74ae472005745abb222802f8cd8dd8b0a2e393b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com>
Date: Sun, 17 Aug 2025 21:06:25 +0800
Subject: [PATCH] Add ToolCache class for tool result caching

Introduce a ToolCache class that caches tool invocation results, with
support for both exact and approximate (similarity-based) query matching.
It provides methods for cache retrieval, storage, expiration, cleanup,
and statistics, improving efficiency by reusing previous results and
avoiding redundant tool executions.

Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com>
---
 src/common/cache_manager.py | 344 ++++++++++++++++++++++++++++++++++++
 1 file changed, 344 insertions(+)
 create mode 100644 src/common/cache_manager.py

diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py
new file mode 100644
index 000000000..ecaff3458
--- /dev/null
+++ b/src/common/cache_manager.py
@@ -0,0 +1,344 @@
+import json
+import hashlib
+import re
+from typing import Any, Dict, Optional
+from datetime import datetime, timedelta
+from pathlib import Path
+from difflib import SequenceMatcher
+
+from src.common.logger import get_logger
+
+logger = get_logger("cache_manager")
+
+
+class ToolCache:
+    """Tool cache manager: caches tool invocation results, with support for approximate matching."""
+
+    def __init__(
+        self,
+        cache_dir: str = "data/tool_cache",
+        max_age_hours: int = 24,
+        similarity_threshold: float = 0.65,
+    ):
+        """
+        Initialize the cache manager.
+
+        Args:
+            cache_dir: Path of the cache directory.
+            max_age_hours: Maximum cache lifetime, in hours.
+            similarity_threshold: Similarity threshold for approximate matching (0-1).
+        """
+        self.cache_dir = Path(cache_dir)
+        self.max_age = timedelta(hours=max_age_hours)
+        self.max_age_seconds = max_age_hours * 3600
+        self.similarity_threshold = similarity_threshold
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+    @staticmethod
+    def _normalize_query(query: str) -> str:
+        """
+        Normalize query text for similarity comparison.
+
+        Args:
+            query: Raw query text.
+
+        Returns:
+            Normalized query text.
+        """
+        if not query:
+            return ""
+
+        # Pure-Python implementation
+        normalized = query.lower()
+        normalized = re.sub(r"[^\w\s]", " ", normalized)
+        normalized = " ".join(normalized.split())
+        return normalized
+
+    def _calculate_similarity(self, text1: str, text2: str) -> float:
+        """
+        Compute the similarity of two texts.
+
+        Args:
+            text1: First text.
+            text2: Second text.
+
+        Returns:
+            Similarity score (0-1).
+        """
+        if not text1 or not text2:
+            return 0.0
+
+        # Pure-Python implementation
+        norm_text1 = self._normalize_query(text1)
+        norm_text2 = self._normalize_query(text2)
+
+        if norm_text1 == norm_text2:
+            return 1.0
+
+        return SequenceMatcher(None, norm_text1, norm_text2).ratio()
+
+    @staticmethod
+    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
+        """
+        Generate a cache key.
+
+        Args:
+            tool_name: Tool name.
+            function_args: Function arguments.
+
+        Returns:
+            Cache key string.
+        """
+        # Serialize the arguments with sorted keys so that identical
+        # arguments always produce the same key.
+        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)
+
+        # Pure-Python implementation
+        cache_string = f"{tool_name}:{sorted_args}"
+        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()
+
+    def _get_cache_file_path(self, cache_key: str) -> Path:
+        """Return the cache file path for a key."""
+        return self.cache_dir / f"{cache_key}.json"
+
+    def _is_cache_expired(self, cached_time: datetime) -> bool:
+        """Check whether a cache entry has expired."""
+        return datetime.now() - cached_time > self.max_age
+
+    def _find_similar_cache(
+        self, tool_name: str, function_args: Dict[str, Any]
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Find a similar cache entry.
+
+        Args:
+            tool_name: Tool name.
+            function_args: Function arguments.
+
+        Returns:
+            The cached result of a similar entry, or None if none exists.
+        """
+        query = function_args.get("query", "")
+        if not query:
+            return None
+
+        candidates = []
+        cache_data_list = []
+
+        # Walk all cache files and collect candidate entries.
+        for cache_file in self.cache_dir.glob("*.json"):
+            try:
+                with open(cache_file, "r", encoding="utf-8") as f:
+                    cache_data = json.load(f)
+
+                # Skip entries that belong to a different tool.
+                if cache_data.get("tool_name") != tool_name:
+                    continue
+
+                # Skip expired entries.
+                cached_time = datetime.fromisoformat(cache_data["timestamp"])
+                if self._is_cache_expired(cached_time):
+                    continue
+
+                # All arguments other than "query" must match exactly.
+                cached_args = cache_data.get("function_args", {})
+                args_match = True
+                for key, value in function_args.items():
+                    if key != "query" and cached_args.get(key) != value:
+                        args_match = False
+                        break
+
+                if not args_match:
+                    continue
+
+                # Collect the candidate.
+                cached_query = cached_args.get("query", "")
+                candidates.append((cached_query, len(cache_data_list)))
+                cache_data_list.append(cache_data)
+
+            except Exception as e:
+                logger.warning(f"Error while checking cache file: {cache_file}, error: {e}")
+                continue
+
+        if not candidates:
+            logger.debug(
+                f"No similar cache found: {tool_name}, query: '{query}', similarity threshold: {self.similarity_threshold}"
+            )
+            return None
+
+        # Pure-Python implementation
+        best_match = None
+        best_similarity = 0.0
+
+        for cached_query, index in candidates:
+            similarity = self._calculate_similarity(query, cached_query)
+            if similarity > best_similarity and similarity >= self.similarity_threshold:
+                best_similarity = similarity
+                best_match = cache_data_list[index]
+
+        if best_match is not None:
+            cached_query = best_match["function_args"].get("query", "")
+            logger.info(
+                f"Similar cache hit, similarity: {best_similarity:.2f}, cached query: '{cached_query}', current query: '{query}'"
+            )
+            return best_match["result"]
+
+        logger.debug(
+            f"No similar cache found: {tool_name}, query: '{query}', similarity threshold: {self.similarity_threshold}"
+        )
+        return None
+
+    def get(
+        self, tool_name: str, function_args: Dict[str, Any]
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Get a result from the cache, using exact and then approximate matching.
+
+        Args:
+            tool_name: Tool name.
+            function_args: Function arguments.
+
+        Returns:
+            The cached result, or None if no entry exists or it has expired.
+        """
+        # Try an exact match first.
+        cache_key = self._generate_cache_key(tool_name, function_args)
+        cache_file = self._get_cache_file_path(cache_key)
+
+        if cache_file.exists():
+            try:
+                with open(cache_file, "r", encoding="utf-8") as f:
+                    cache_data = json.load(f)
+
+                # Check whether the entry has expired.
+                cached_time = datetime.fromisoformat(cache_data["timestamp"])
+                if self._is_cache_expired(cached_time):
+                    logger.debug(f"Cache expired: {cache_key}")
+                    cache_file.unlink()  # Remove the expired entry.
+                else:
+                    logger.debug(f"Exact cache hit: {tool_name}")
+                    return cache_data["result"]
+
+            except (json.JSONDecodeError, KeyError, ValueError) as e:
+                logger.warning(f"Failed to read cache file: {cache_file}, error: {e}")
+                # Remove the corrupted cache file.
+                if cache_file.exists():
+                    cache_file.unlink()
+
+        # Fall back to approximate matching when the exact match fails.
+        return self._find_similar_cache(tool_name, function_args)
+
+    def set(
+        self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
+    ) -> None:
+        """
+        Save a result to the cache.
+
+        Args:
+            tool_name: Tool name.
+            function_args: Function arguments.
+            result: Result to cache.
+        """
+        cache_key = self._generate_cache_key(tool_name, function_args)
+        cache_file = self._get_cache_file_path(cache_key)
+
+        cache_data = {
+            "tool_name": tool_name,
+            "function_args": function_args,
+            "result": result,
+            "timestamp": datetime.now().isoformat(),
+        }
+
+        try:
+            with open(cache_file, "w", encoding="utf-8") as f:
+                json.dump(cache_data, f, ensure_ascii=False, indent=2)
+            logger.debug(f"Cache saved: {tool_name} -> {cache_key}")
+        except Exception as e:
+            logger.error(f"Failed to save cache: {cache_file}, error: {e}")
+
+    def clear_expired(self) -> int:
+        """
+        Remove expired cache entries.
+
+        Returns:
+            Number of files removed.
+        """
+        removed_count = 0
+
+        for cache_file in self.cache_dir.glob("*.json"):
+            try:
+                with open(cache_file, "r", encoding="utf-8") as f:
+                    cache_data = json.load(f)
+
+                cached_time = datetime.fromisoformat(cache_data["timestamp"])
+                if self._is_cache_expired(cached_time):
+                    cache_file.unlink()
+                    removed_count += 1
+                    logger.debug(f"Removed expired cache: {cache_file}")
+
+            except Exception as e:
+                logger.warning(f"Error while cleaning cache file: {cache_file}, error: {e}")
+                # Remove the corrupted file.
+                try:
+                    cache_file.unlink()
+                    removed_count += 1
+                except OSError as unlink_error:
+                    logger.warning(f"Failed to remove corrupted cache file: {cache_file}, error: {unlink_error}")
+
+        logger.info(f"Cleanup finished, removed {removed_count} expired cache files")
+        return removed_count
+
+    def clear_all(self) -> int:
+        """
+        Remove all cache entries.
+
+        Returns:
+            Number of files removed.
+        """
+        removed_count = 0
+
+        for cache_file in self.cache_dir.glob("*.json"):
+            try:
+                cache_file.unlink()
+                removed_count += 1
+            except Exception as e:
+                logger.warning(f"Failed to remove cache file: {cache_file}, error: {e}")
+
+        logger.info(f"Cache cleared, removed {removed_count} files")
+        return removed_count
+
+    def get_stats(self) -> Dict[str, Any]:
+        """
+        Return cache statistics.
+
+        Returns:
+            Dictionary of cache statistics.
+        """
+        total_files = 0
+        expired_files = 0
+        total_size = 0
+
+        for cache_file in self.cache_dir.glob("*.json"):
+            try:
+                total_files += 1
+                total_size += cache_file.stat().st_size
+
+                with open(cache_file, "r", encoding="utf-8") as f:
+                    cache_data = json.load(f)
+
+                cached_time = datetime.fromisoformat(cache_data["timestamp"])
+                if self._is_cache_expired(cached_time):
+                    expired_files += 1
+
+            except (OSError, json.JSONDecodeError, KeyError, ValueError):
+                expired_files += 1  # Count corrupted files as expired, too.
+
+        return {
+            "total_files": total_files,
+            "expired_files": expired_files,
+            "total_size_bytes": total_size,
+            "cache_dir": str(self.cache_dir),
+            "max_age_hours": self.max_age.total_seconds() / 3600,
+            "similarity_threshold": self.similarity_threshold,
+        }
+
+tool_cache = ToolCache()
\ No newline at end of file
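
A minimal usage sketch of the module above (not part of the patch). The tool
name "web_search", its arguments, and the result payload are illustrative
placeholders, and the approximate-match behavior shown assumes the default
0.65 similarity threshold:

    from src.common.cache_manager import tool_cache

    args = {"query": "python json tutorial", "max_results": 5}

    # First call: no cached entry yet, so run the tool and store the result.
    result = tool_cache.get("web_search", args)
    if result is None:
        result = {"links": ["https://docs.python.org/3/library/json.html"]}  # stand-in for the real tool call
        tool_cache.set("web_search", args, result)

    # Second call with a near-identical query: the exact MD5 key differs, but
    # "Python JSON tutorial!" normalizes to the same text as the cached query,
    # so the similarity fallback returns the stored result. All non-"query"
    # arguments (here max_results) still have to match exactly.
    cached = tool_cache.get("web_search", {"query": "Python JSON tutorial!", "max_results": 5})
    assert cached == result

    tool_cache.clear_expired()     # drop entries older than max_age_hours
    print(tool_cache.get_stats())  # e.g. {'total_files': 1, 'expired_files': 0, ...}

Note that the similarity fallback in get() scans every file in the cache
directory, so callers with large caches may want to invoke clear_expired()
periodically to keep that scan short.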