From 51c0d2a1e86db39ac46e74288d3ceed665458514 Mon Sep 17 00:00:00 2001
From: minecraft1024a
Date: Sun, 17 Aug 2025 22:18:26 +0800
Subject: [PATCH] refactor(cache): restructure the cache system into a tiered
 semantic cache

Replace the file-based `ToolCache` with a new `CacheManager` that introduces
a more sophisticated and efficient tiered semantic caching mechanism.

Features of the new system:
- **Tiered caching**:
  - L1 cache: in-memory dict (KV) + FAISS (vector), for fast access.
  - L2 cache: SQLite (KV) + ChromaDB (vector), for persistent storage.
- **Semantic caching**: queries are vectorized with an embedding model, so
  cache hits can be based on semantic similarity, which significantly raises
  the cache hit rate.
- **Automatic invalidation**: cache keys include a hash of the tool's source
  code, so related entries are invalidated automatically when the tool code
  changes, preventing stale results.
- **Async support**: the cache's `get` and `set` methods are now async, to
  match the project's asynchronous tool-call flow.

`web_search_tool` has been updated to use the new `CacheManager`, passing
`tool_class` and `semantic_query` on cache calls to take full advantage of
the new features.

Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com>
---
 src/common/cache_manager.py                    | 509 +++++++-----------
 src/common/cache_manager_backup.py             | 344 ++++++++++++
 src/main.py                                    |   1 +
 .../built_in/web_search_tool/plugin.py         |  10 +-
 4 files changed, 546 insertions(+), 318 deletions(-)
 create mode 100644 src/common/cache_manager_backup.py
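
A usage sketch, placed in the note area that `git am` ignores: it shows the
call pattern this change expects from tools, assuming the `tool_cache`
singleton defined in the diff below. The `run_tool` wrapper and `_do_work`
helper are hypothetical; the `tool_cache.get`/`tool_cache.set` signatures are
taken from this patch.

    from src.common.cache_manager import tool_cache

    async def run_tool(self, function_args: dict):
        query = function_args.get("query")
        # Lookup order per this patch: L1 KV -> L1 vector -> L2 KV -> L2 vector.
        cached = await tool_cache.get(
            self.name,
            function_args,
            tool_class=self.__class__,  # source code is hashed into the key
            semantic_query=query,       # enables embedding-based matching
        )
        if cached is not None:
            return cached
        result = await self._do_work(function_args)  # hypothetical tool logic
        if "error" not in result:
            await tool_cache.set(
                self.name, function_args, self.__class__, result,
                semantic_query=query,
            )
        return result
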
diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py
index ecaff3458..c1141aa66 100644
--- a/src/common/cache_manager.py
+++ b/src/common/cache_manager.py
@@ -1,344 +1,225 @@
+import time
 import json
+import sqlite3
+import chromadb
 import hashlib
-import re
+import inspect
+import numpy as np
+import faiss
 from typing import Any, Dict, Optional
-from datetime import datetime, timedelta
-from pathlib import Path
-from difflib import SequenceMatcher
-
 from src.common.logger import get_logger
+from src.llm_models.utils_model import LLMRequest
+from src.config.api_ada_configs import ModelTaskConfig
+from src.config.config import global_config, model_config
 
 logger = get_logger("cache_manager")
 
+class CacheManager:
+    """
+    一个支持分层和语义缓存的通用工具缓存管理器。
+    采用单例模式,确保在整个应用中只有一个缓存实例。
+    L1缓存: 内存字典 (KV) + FAISS (Vector)。
+    L2缓存: SQLite (KV) + ChromaDB (Vector)。
+    """
+    _instance = None
 
-class ToolCache:
-    """工具缓存管理器,用于缓存工具调用结果,支持近似匹配"""
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(CacheManager, cls).__new__(cls)
+        return cls._instance
 
-    def __init__(
-        self,
-        cache_dir: str = "data/tool_cache",
-        max_age_hours: int = 24,
-        similarity_threshold: float = 0.65,
-    ):
+    def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
         """
-        初始化缓存管理器
-
-        Args:
-            cache_dir: 缓存目录路径
-            max_age_hours: 缓存最大存活时间(小时)
-            similarity_threshold: 近似匹配的相似度阈值 (0-1)
+        初始化缓存管理器。
         """
-        self.cache_dir = Path(cache_dir)
-        self.max_age = timedelta(hours=max_age_hours)
-        self.max_age_seconds = max_age_hours * 3600
-        self.similarity_threshold = similarity_threshold
-        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        if not hasattr(self, '_initialized'):
+            self.default_ttl = default_ttl
+
+            # L1 缓存 (内存)
+            self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
+            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
+            self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
+            self.l1_vector_id_to_key: Dict[int, str] = {}
+
+            # L2 缓存 (持久化)
+            self.db_path = db_path
+            self._init_sqlite()
+            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
+            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
+
+            # 嵌入模型
+            self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
 
-    @staticmethod
-    def _normalize_query(query: str) -> str:
+            self._initialized = True
+            logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)")
+
+    def _init_sqlite(self):
+        """初始化SQLite数据库和表结构。"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS cache (
+                    key TEXT PRIMARY KEY,
+                    value TEXT,
+                    expires_at REAL
+                )
+            """)
+            conn.commit()
+
+    def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
+        """生成确定性的缓存键,包含代码哈希以实现自动失效。"""
+        try:
+            source_code = inspect.getsource(tool_class)
+            code_hash = hashlib.md5(source_code.encode()).hexdigest()
+        except (TypeError, OSError) as e:
+            code_hash = "unknown"
+            logger.warning(f"无法获取 {tool_class.__name__} 的源代码,代码哈希将为 'unknown'。错误: {e}")
+        try:
+            sorted_args = json.dumps(function_args, sort_keys=True)
+        except TypeError:
+            sorted_args = repr(sorted(function_args.items()))
+        return f"{tool_name}::{sorted_args}::{code_hash}"
+
+    async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]:
         """
-        标准化查询文本,用于相似度比较
-
-        Args:
-            query: 原始查询文本
-
-        Returns:
-            标准化后的查询文本
+        从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
         """
-        if not query:
-            return ""
+        # 步骤 1: L1 精确缓存查询
+        key = self._generate_key(tool_name, function_args, tool_class)
+        logger.debug(f"生成的缓存键: {key}")
+        if semantic_query:
+            logger.debug(f"使用的语义查询: '{semantic_query}'")
 
-        # 纯 Python 实现
-        normalized = query.lower()
-        normalized = re.sub(r"[^\w\s]", " ", normalized)
-        normalized = " ".join(normalized.split())
-        return normalized
+        if key in self.l1_kv_cache:
+            entry = self.l1_kv_cache[key]
+            if time.time() < entry["expires_at"]:
+                logger.info(f"命中L1键值缓存: {key}")
+                return entry["data"]
+            else:
+                del self.l1_kv_cache[key]
 
-    def _calculate_similarity(self, text1: str, text2: str) -> float:
-        """
-        计算两个文本的相似度
+        # 步骤 2: L1/L2 语义和L2精确缓存查询
+        query_embedding = None
+        if semantic_query and self.embedding_model:
+            embedding_result = await self.embedding_model.get_embedding(semantic_query)
+            if embedding_result:
+                query_embedding = np.array([embedding_result], dtype='float32')
 
-        Args:
-            text1: 文本1
-            text2: 文本2
+        # 步骤 2a: L1 语义缓存 (FAISS)
+        if query_embedding is not None and self.l1_vector_index.ntotal > 0:
+            faiss.normalize_L2(query_embedding)
+            distances, indices = self.l1_vector_index.search(query_embedding, 1)
+            if indices.size > 0 and distances[0][0] > 0.75:  # IP 越大越相似
+                hit_index = indices[0][0]
+                l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
+                if l1_hit_key and l1_hit_key in self.l1_kv_cache:
+                    logger.info(f"命中L1语义缓存: {l1_hit_key}")
+                    return self.l1_kv_cache[l1_hit_key]["data"]
 
-        Returns:
-            相似度分数 (0-1)
-        """
-        if not text1 or not text2:
-            return 0.0
+        # 步骤 2b: L2 精确缓存 (SQLite)
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
+            row = cursor.fetchone()
+            if row:
+                value, expires_at = row
+                if time.time() < expires_at:
+                    logger.info(f"命中L2键值缓存: {key}")
+                    data = json.loads(value)
+                    # 回填 L1
+                    self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
+                    return data
+                else:
+                    cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
+                    conn.commit()
 
-        # 纯 Python 实现
-        norm_text1 = self._normalize_query(text1)
-        norm_text2 = self._normalize_query(text2)
+        # 步骤 2c: L2 语义缓存 (ChromaDB)
+        if query_embedding is not None:
+            results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
+            if results and results['ids'] and results['ids'][0]:
+                distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else None
+                logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0][0]}, 距离={distance}")
+                if distance is not None and distance < 0.75:
+                    l2_hit_key = results['ids'][0][0]
+                    logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
+                    with sqlite3.connect(self.db_path) as conn:
+                        cursor = conn.cursor()
+                        cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key,))
+                        row = cursor.fetchone()
+                        if row:
+                            value, expires_at = row
+                            if time.time() < expires_at:
+                                data = json.loads(value)
+                                logger.debug(f"L2语义缓存返回的数据: {data}")
+                                # 回填 L1 (KV + 向量;此分支中 query_embedding 必不为 None)
+                                self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
+                                new_id = self.l1_vector_index.ntotal
+                                faiss.normalize_L2(query_embedding)
+                                self.l1_vector_index.add(x=query_embedding)
+                                self.l1_vector_id_to_key[new_id] = key
+                                return data
 
-        if norm_text1 == norm_text2:
-            return 1.0
-
-        return SequenceMatcher(None, norm_text1, norm_text2).ratio()
-
-    @staticmethod
-    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
-        """
-        生成缓存键
-
-        Args:
-            tool_name: 工具名称
-            function_args: 函数参数
-
-        Returns:
-            缓存键字符串
-        """
-        # 将参数排序后序列化,确保相同参数产生相同的键
-        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)
-
-        # 纯 Python 实现
-        cache_string = f"{tool_name}:{sorted_args}"
-        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()
-
-    def _get_cache_file_path(self, cache_key: str) -> Path:
-        """获取缓存文件路径"""
-        return self.cache_dir / f"{cache_key}.json"
-
-    def _is_cache_expired(self, cached_time: datetime) -> bool:
-        """检查缓存是否过期"""
-        return datetime.now() - cached_time > self.max_age
-
-    def _find_similar_cache(
-        self, tool_name: str, function_args: Dict[str, Any]
-    ) -> Optional[Dict[str, Any]]:
-        """
-        查找相似的缓存条目
-
-        Args:
-            tool_name: 工具名称
-            function_args: 函数参数
-
-        Returns:
-            相似的缓存结果,如果不存在则返回None
-        """
-        query = function_args.get("query", "")
-        if not query:
-            return None
-
-        candidates = []
-        cache_data_list = []
-
-        # 遍历所有缓存文件,收集候选项
-        for cache_file in self.cache_dir.glob("*.json"):
-            try:
-                with open(cache_file, "r", encoding="utf-8") as f:
-                    cache_data = json.load(f)
-
-                # 检查是否是同一个工具
-                if cache_data.get("tool_name") != tool_name:
-                    continue
-
-                # 检查缓存是否过期
-                cached_time = datetime.fromisoformat(cache_data["timestamp"])
-                if self._is_cache_expired(cached_time):
-                    continue
-
-                # 检查其他参数是否匹配(除了query)
-                cached_args = cache_data.get("function_args", {})
-                args_match = True
-                for key, value in function_args.items():
-                    if key != "query" and cached_args.get(key) != value:
-                        args_match = False
-                        break
-
-                if not args_match:
-                    continue
-
-                # 收集候选项
-                cached_query = cached_args.get("query", "")
-                candidates.append((cached_query, len(cache_data_list)))
-                cache_data_list.append(cache_data)
-
-            except Exception as e:
-                logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}")
-                continue
-
-        if not candidates:
-            logger.debug(
-                f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
-            )
-            return None
-
-        # 纯 Python 实现
-        best_match = None
-        best_similarity = 0.0
-
-        for cached_query, index in candidates:
-            similarity = self._calculate_similarity(query, cached_query)
-            if similarity > best_similarity and similarity >= self.similarity_threshold:
-                best_similarity = similarity
-                best_match = cache_data_list[index]
-
-        if best_match is not None:
-            cached_query = best_match["function_args"].get("query",
"") - logger.info( - f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'" - ) - return best_match["result"] - - logger.debug( - f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" - ) + logger.debug(f"缓存未命中: {key}") return None - def get( - self, tool_name: str, function_args: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """ - 从缓存获取结果,支持精确匹配和近似匹配 + async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None): + """将结果存入所有缓存层。""" + if ttl is None: + ttl = self.default_ttl + if ttl <= 0: + return - Args: - tool_name: 工具名称 - function_args: 函数参数 + key = self._generate_key(tool_name, function_args, tool_class) + expires_at = time.time() + ttl + + # 写入 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} - Returns: - 缓存的结果,如果不存在或已过期则返回None - """ - # 首先尝试精确匹配 - cache_key = self._generate_cache_key(tool_name, function_args) - cache_file = self._get_cache_file_path(cache_key) + # 写入 L2 + value = json.dumps(data) + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at)) + conn.commit() - if cache_file.exists(): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) + # 写入语义缓存 + if semantic_query and self.embedding_model: + embedding_result = await self.embedding_model.get_embedding(semantic_query) + if embedding_result: + embedding = np.array([embedding_result], dtype='float32') + # 写入 L1 Vector + new_id = self.l1_vector_index.ntotal + faiss.normalize_L2(embedding) + self.l1_vector_index.add(x=embedding) + self.l1_vector_id_to_key[new_id] = key + # 写入 L2 Vector + self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) - # 检查缓存是否过期 - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - logger.debug(f"缓存已过期: {cache_key}") - cache_file.unlink() # 删除过期缓存 - else: - logger.debug(f"精确匹配缓存: {tool_name}") - return cache_data["result"] + logger.info(f"已缓存条目: {key}, TTL: {ttl}s") - except (json.JSONDecodeError, KeyError, ValueError) as e: - logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}") - # 删除损坏的缓存文件 - if cache_file.exists(): - cache_file.unlink() + def clear_l1(self): + """清空L1缓存。""" + self.l1_kv_cache.clear() + self.l1_vector_index.reset() + self.l1_vector_id_to_key.clear() + logger.info("L1 (内存+FAISS) 缓存已清空。") - # 如果精确匹配失败,尝试近似匹配 - return self._find_similar_cache(tool_name, function_args) + def clear_l2(self): + """清空L2缓存。""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM cache") + conn.commit() + self.chroma_client.delete_collection(name="semantic_cache") + self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") + logger.info("L2 (SQLite & ChromaDB) 缓存已清空。") - def set( - self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any] - ) -> None: - """ - 将结果保存到缓存 + def clear_all(self): + """清空所有缓存。""" + self.clear_l1() + self.clear_l2() + logger.info("所有缓存层级已清空。") - Args: - tool_name: 工具名称 - function_args: 函数参数 - result: 缓存结果 - """ - cache_key = self._generate_cache_key(tool_name, function_args) - cache_file = self._get_cache_file_path(cache_key) - - cache_data = { - "tool_name": tool_name, - "function_args": function_args, - "result": result, - "timestamp": datetime.now().isoformat(), - } - - try: - with open(cache_file, 
"w", encoding="utf-8") as f: - json.dump(cache_data, f, ensure_ascii=False, indent=2) - logger.debug(f"缓存已保存: {tool_name} -> {cache_key}") - except Exception as e: - logger.error(f"保存缓存失败: {cache_file}, 错误: {e}") - - def clear_expired(self) -> int: - """ - 清理过期缓存 - - Returns: - 删除的文件数量 - """ - removed_count = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - cache_file.unlink() - removed_count += 1 - logger.debug(f"删除过期缓存: {cache_file}") - - except Exception as e: - logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}") - # 删除损坏的文件 - try: - cache_file.unlink() - removed_count += 1 - except (OSError, json.JSONDecodeError, KeyError, ValueError): - logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}") - - logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件") - return removed_count - - def clear_all(self) -> int: - """ - 清空所有缓存 - - Returns: - 删除的文件数量 - """ - removed_count = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - cache_file.unlink() - removed_count += 1 - except Exception as e: - logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}") - - logger.info(f"清空缓存完成,删除了 {removed_count} 个文件") - return removed_count - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 缓存统计信息字典 - """ - total_files = 0 - expired_files = 0 - total_size = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - total_files += 1 - total_size += cache_file.stat().st_size - - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - expired_files += 1 - - except (OSError, json.JSONDecodeError, KeyError, ValueError): - expired_files += 1 # 损坏的文件也算作过期 - - return { - "total_files": total_files, - "expired_files": expired_files, - "total_size_bytes": total_size, - "cache_dir": str(self.cache_dir), - "max_age_hours": self.max_age.total_seconds() / 3600, - "similarity_threshold": self.similarity_threshold, - } - -tool_cache = ToolCache() \ No newline at end of file +# 全局实例 +tool_cache = CacheManager() \ No newline at end of file diff --git a/src/common/cache_manager_backup.py b/src/common/cache_manager_backup.py new file mode 100644 index 000000000..ecaff3458 --- /dev/null +++ b/src/common/cache_manager_backup.py @@ -0,0 +1,344 @@ +import json +import hashlib +import re +from typing import Any, Dict, Optional +from datetime import datetime, timedelta +from pathlib import Path +from difflib import SequenceMatcher + +from src.common.logger import get_logger + +logger = get_logger("cache_manager") + + +class ToolCache: + """工具缓存管理器,用于缓存工具调用结果,支持近似匹配""" + + def __init__( + self, + cache_dir: str = "data/tool_cache", + max_age_hours: int = 24, + similarity_threshold: float = 0.65, + ): + """ + 初始化缓存管理器 + + Args: + cache_dir: 缓存目录路径 + max_age_hours: 缓存最大存活时间(小时) + similarity_threshold: 近似匹配的相似度阈值 (0-1) + """ + self.cache_dir = Path(cache_dir) + self.max_age = timedelta(hours=max_age_hours) + self.max_age_seconds = max_age_hours * 3600 + self.similarity_threshold = similarity_threshold + self.cache_dir.mkdir(parents=True, exist_ok=True) + + @staticmethod + def _normalize_query(query: str) -> str: + """ + 标准化查询文本,用于相似度比较 + + Args: + query: 原始查询文本 + + Returns: + 标准化后的查询文本 + """ + if not query: + return "" + + # 纯 Python 实现 + normalized = query.lower() + normalized = 
re.sub(r"[^\w\s]", " ", normalized) + normalized = " ".join(normalized.split()) + return normalized + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """ + 计算两个文本的相似度 + + Args: + text1: 文本1 + text2: 文本2 + + Returns: + 相似度分数 (0-1) + """ + if not text1 or not text2: + return 0.0 + + # 纯 Python 实现 + norm_text1 = self._normalize_query(text1) + norm_text2 = self._normalize_query(text2) + + if norm_text1 == norm_text2: + return 1.0 + + return SequenceMatcher(None, norm_text1, norm_text2).ratio() + + @staticmethod + def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str: + """ + 生成缓存键 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + + Returns: + 缓存键字符串 + """ + # 将参数排序后序列化,确保相同参数产生相同的键 + sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False) + + # 纯 Python 实现 + cache_string = f"{tool_name}:{sorted_args}" + return hashlib.md5(cache_string.encode("utf-8")).hexdigest() + + def _get_cache_file_path(self, cache_key: str) -> Path: + """获取缓存文件路径""" + return self.cache_dir / f"{cache_key}.json" + + def _is_cache_expired(self, cached_time: datetime) -> bool: + """检查缓存是否过期""" + return datetime.now() - cached_time > self.max_age + + def _find_similar_cache( + self, tool_name: str, function_args: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + 查找相似的缓存条目 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + + Returns: + 相似的缓存结果,如果不存在则返回None + """ + query = function_args.get("query", "") + if not query: + return None + + candidates = [] + cache_data_list = [] + + # 遍历所有缓存文件,收集候选项 + for cache_file in self.cache_dir.glob("*.json"): + try: + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + # 检查是否是同一个工具 + if cache_data.get("tool_name") != tool_name: + continue + + # 检查缓存是否过期 + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + continue + + # 检查其他参数是否匹配(除了query) + cached_args = cache_data.get("function_args", {}) + args_match = True + for key, value in function_args.items(): + if key != "query" and cached_args.get(key) != value: + args_match = False + break + + if not args_match: + continue + + # 收集候选项 + cached_query = cached_args.get("query", "") + candidates.append((cached_query, len(cache_data_list))) + cache_data_list.append(cache_data) + + except Exception as e: + logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}") + continue + + if not candidates: + logger.debug( + f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" + ) + return None + + # 纯 Python 实现 + best_match = None + best_similarity = 0.0 + + for cached_query, index in candidates: + similarity = self._calculate_similarity(query, cached_query) + if similarity > best_similarity and similarity >= self.similarity_threshold: + best_similarity = similarity + best_match = cache_data_list[index] + + if best_match is not None: + cached_query = best_match["function_args"].get("query", "") + logger.info( + f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'" + ) + return best_match["result"] + + logger.debug( + f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" + ) + return None + + def get( + self, tool_name: str, function_args: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + 从缓存获取结果,支持精确匹配和近似匹配 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + + Returns: + 缓存的结果,如果不存在或已过期则返回None + """ + # 首先尝试精确匹配 + cache_key = self._generate_cache_key(tool_name, function_args) + cache_file = self._get_cache_file_path(cache_key) + + if 
cache_file.exists(): + try: + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + # 检查缓存是否过期 + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + logger.debug(f"缓存已过期: {cache_key}") + cache_file.unlink() # 删除过期缓存 + else: + logger.debug(f"精确匹配缓存: {tool_name}") + return cache_data["result"] + + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}") + # 删除损坏的缓存文件 + if cache_file.exists(): + cache_file.unlink() + + # 如果精确匹配失败,尝试近似匹配 + return self._find_similar_cache(tool_name, function_args) + + def set( + self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any] + ) -> None: + """ + 将结果保存到缓存 + + Args: + tool_name: 工具名称 + function_args: 函数参数 + result: 缓存结果 + """ + cache_key = self._generate_cache_key(tool_name, function_args) + cache_file = self._get_cache_file_path(cache_key) + + cache_data = { + "tool_name": tool_name, + "function_args": function_args, + "result": result, + "timestamp": datetime.now().isoformat(), + } + + try: + with open(cache_file, "w", encoding="utf-8") as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + logger.debug(f"缓存已保存: {tool_name} -> {cache_key}") + except Exception as e: + logger.error(f"保存缓存失败: {cache_file}, 错误: {e}") + + def clear_expired(self) -> int: + """ + 清理过期缓存 + + Returns: + 删除的文件数量 + """ + removed_count = 0 + + for cache_file in self.cache_dir.glob("*.json"): + try: + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + cache_file.unlink() + removed_count += 1 + logger.debug(f"删除过期缓存: {cache_file}") + + except Exception as e: + logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}") + # 删除损坏的文件 + try: + cache_file.unlink() + removed_count += 1 + except (OSError, json.JSONDecodeError, KeyError, ValueError): + logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}") + + logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件") + return removed_count + + def clear_all(self) -> int: + """ + 清空所有缓存 + + Returns: + 删除的文件数量 + """ + removed_count = 0 + + for cache_file in self.cache_dir.glob("*.json"): + try: + cache_file.unlink() + removed_count += 1 + except Exception as e: + logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}") + + logger.info(f"清空缓存完成,删除了 {removed_count} 个文件") + return removed_count + + def get_stats(self) -> Dict[str, Any]: + """ + 获取缓存统计信息 + + Returns: + 缓存统计信息字典 + """ + total_files = 0 + expired_files = 0 + total_size = 0 + + for cache_file in self.cache_dir.glob("*.json"): + try: + total_files += 1 + total_size += cache_file.stat().st_size + + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + cached_time = datetime.fromisoformat(cache_data["timestamp"]) + if self._is_cache_expired(cached_time): + expired_files += 1 + + except (OSError, json.JSONDecodeError, KeyError, ValueError): + expired_files += 1 # 损坏的文件也算作过期 + + return { + "total_files": total_files, + "expired_files": expired_files, + "total_size_bytes": total_size, + "cache_dir": str(self.cache_dir), + "max_age_hours": self.max_age.total_seconds() / 3600, + "similarity_threshold": self.similarity_threshold, + } + +tool_cache = ToolCache() \ No newline at end of file diff --git a/src/main.py b/src/main.py index 6e07ee692..475b92699 100644 --- a/src/main.py +++ b/src/main.py @@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server from 
src.mood.mood_manager import mood_manager from rich.traceback import install from src.manager.schedule_manager import schedule_manager +from src.common.cache_manager import tool_cache # from src.api.main import start_api_server # 导入新的插件管理器和热重载管理器 diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index b2212f836..c4415fd99 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -88,7 +88,8 @@ class WebSurfingTool(BaseTool): return {"error": "搜索查询不能为空。"} # 检查缓存 - cached_result = tool_cache.get(self.name, function_args) + query = function_args.get("query") + cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query) if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result @@ -109,7 +110,8 @@ class WebSurfingTool(BaseTool): # 保存到缓存 if "error" not in result: - tool_cache.set(self.name, function_args, result) + query = function_args.get("query") + await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query) return result @@ -463,7 +465,7 @@ class URLParserTool(BaseTool): 执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。 """ # 检查缓存 - cached_result = tool_cache.get(self.name, function_args) + cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__) if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result @@ -577,7 +579,7 @@ class URLParserTool(BaseTool): # 保存到缓存 if "error" not in result: - tool_cache.set(self.name, function_args, result) + await tool_cache.set(self.name, function_args, self.__class__, result) return result
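
---
Usage note (appended after the diff; not part of the patch): a minimal sketch
of the automatic invalidation property described in the commit message,
assuming this patch is applied and the embedding/config services are
available at import time. `DemoTool` and the argument values are
hypothetical; the `tool_cache` API is the one defined above.

    import asyncio
    from src.common.cache_manager import tool_cache

    class DemoTool:
        """Hypothetical tool; any edit to this class changes its source hash."""
        name = "demo_tool"

    async def main():
        args = {"query": "weather in Paris"}
        # The key is f"{tool_name}::{sorted_args}::{md5(source)}", so editing
        # DemoTool produces a new key and stale entries are never matched.
        await tool_cache.set(DemoTool.name, args, DemoTool, {"ok": True},
                             semantic_query=args["query"])
        hit = await tool_cache.get(DemoTool.name, args, DemoTool,
                                   semantic_query=args["query"])
        print(hit)  # {'ok': True}

    asyncio.run(main())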