refactor(cache): rework the cache system into a tiered semantic cache

Replaces the old file-based `ToolCache` with a new `CacheManager`, introducing a more capable and efficient tiered semantic caching scheme.

Features of the new system:
- **Tiered caching**:
  - L1 cache: in-memory dict (KV) + FAISS (vectors), for fast access.
  - L2 cache: SQLite (KV) + ChromaDB (vectors), for persistent storage.
- **Semantic caching**: queries are vectorized with an embedding model, so cache hits can be made on semantic similarity, which significantly raises the hit rate.
- **Automatic invalidation**: cache keys include a hash of the tool's source code, so when a tool's code changes its cached entries stop matching and effectively expire, avoiding stale results.
- **Async support**: the cache's `get` and `set` methods are now async, matching the project's asynchronous tool-call flow.

`web_search_tool` has been updated to use the new `CacheManager`, passing `tool_class` and `semantic_query` on cache calls to take full advantage of the new features (see the usage sketch below).
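For illustration, a minimal sketch of the call pattern from a tool's point of view. The `MyTool` class, its `name`, and the placeholder result are hypothetical; the `tool_cache` singleton and the `get`/`set` signatures are the ones from the diff below:

```python
from src.common.cache_manager import tool_cache

class MyTool:
    """Hypothetical tool; its source code is hashed into the cache key."""
    name = "my_tool"

    async def execute(self, function_args: dict) -> dict:
        query = function_args.get("query")
        # Lookup order: L1-KV -> L1-Vector -> L2-KV -> L2-Vector
        cached = await tool_cache.get(
            self.name, function_args, tool_class=self.__class__, semantic_query=query
        )
        if cached is not None:
            return cached
        result = {"answer": f"computed for {query}"}  # placeholder for real tool work
        # Write to all tiers; semantic_query enables similarity hits later
        await tool_cache.set(
            self.name, function_args, self.__class__, result, semantic_query=query
        )
        return result
```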

Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com>
minecraft1024a committed 2025-08-17 22:18:26 +08:00
parent b45a992484
commit 51c0d2a1e8
4 changed files with 546 additions and 318 deletions

View File

@@ -1,344 +1,225 @@
import time
import json
import sqlite3
import chromadb
import hashlib
import inspect
import numpy as np
import faiss
from typing import Any, Dict, Optional
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.api_ada_configs import ModelTaskConfig
from src.config.config import global_config, model_config

logger = get_logger("cache_manager")


class CacheManager:
    """
    A general-purpose tool-cache manager with tiered and semantic caching.
    Implemented as a singleton so the whole application shares one instance.
    L1 cache: in-memory dict (KV) + FAISS (vectors).
    L2 cache: SQLite (KV) + ChromaDB (vectors).
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(CacheManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
        """Initialize the cache manager."""
        if not hasattr(self, '_initialized'):
            self.default_ttl = default_ttl
            # L1 cache (in memory)
            self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
            self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
            self.l1_vector_id_to_key: Dict[int, str] = {}
            # L2 cache (persistent)
            self.db_path = db_path
            self._init_sqlite()
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
            # Embedding model
            self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
            self._initialized = True
            logger.info("Cache manager initialized: L1 (memory + FAISS), L2 (SQLite + ChromaDB)")

    def _init_sqlite(self):
        """Initialize the SQLite database and table schema."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expires_at REAL
                )
            """)
            conn.commit()

    def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
        """Build a deterministic cache key that embeds a source-code hash for automatic invalidation."""
        try:
            source_code = inspect.getsource(tool_class)
            code_hash = hashlib.md5(source_code.encode()).hexdigest()
        except (TypeError, OSError) as e:
            code_hash = "unknown"
            logger.warning(f"Could not read the source of {tool_class.__name__}; code hash falls back to 'unknown'. Error: {e}")
        try:
            sorted_args = json.dumps(function_args, sort_keys=True)
        except TypeError:
            sorted_args = repr(sorted(function_args.items()))
        return f"{tool_name}::{sorted_args}::{code_hash}"

    async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]:
        """
        Fetch a result from the cache. Lookup order: L1-KV -> L1-Vector -> L2-KV -> L2-Vector.
        """
        # Step 1: L1 exact-match lookup
        key = self._generate_key(tool_name, function_args, tool_class)
        logger.debug(f"Generated cache key: {key}")
        if semantic_query:
            logger.debug(f"Semantic query: '{semantic_query}'")
        if key in self.l1_kv_cache:
            entry = self.l1_kv_cache[key]
            if time.time() < entry["expires_at"]:
                logger.info(f"L1 KV cache hit: {key}")
                return entry["data"]
            else:
                del self.l1_kv_cache[key]

        # Step 2: semantic (L1/L2) and L2 exact lookups
        query_embedding = None
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                query_embedding = np.array([embedding_result], dtype='float32')

        # Step 2a: L1 semantic cache (FAISS)
        if query_embedding is not None and self.l1_vector_index.ntotal > 0:
            faiss.normalize_L2(query_embedding)
            distances, indices = self.l1_vector_index.search(query_embedding, 1)
            if indices.size > 0 and distances[0][0] > 0.75:  # inner product: larger means more similar
                hit_index = indices[0][0]
                l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
                if l1_hit_key and l1_hit_key in self.l1_kv_cache:
                    logger.info(f"L1 semantic cache hit: {l1_hit_key}")
                    return self.l1_kv_cache[l1_hit_key]["data"]

        # Step 2b: L2 exact cache (SQLite)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
            row = cursor.fetchone()
            if row:
                value, expires_at = row
                if time.time() < expires_at:
                    logger.info(f"L2 KV cache hit: {key}")
                    data = json.loads(value)
                    # Backfill L1
                    self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                    return data
                else:
                    cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
                    conn.commit()

        # Step 2c: L2 semantic cache (ChromaDB)
        if query_embedding is not None:
            results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
            if results and results['ids'] and results['ids'][0]:
                distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
                logger.debug(f"L2 semantic search nearest result: id={results['ids'][0]}, distance={distance}")
                if distance != 'N/A' and distance < 0.75:
                    l2_hit_key = results['ids'][0]
                    logger.info(f"L2 semantic cache hit: key='{l2_hit_key}', distance={distance:.4f}")
                    with sqlite3.connect(self.db_path) as conn:
                        cursor = conn.cursor()
                        cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],))
                        row = cursor.fetchone()
                        if row:
                            value, expires_at = row
                            if time.time() < expires_at:
                                data = json.loads(value)
                                logger.debug(f"L2 semantic cache returned: {data}")
                                # Backfill L1
                                self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                                if query_embedding is not None:
                                    new_id = self.l1_vector_index.ntotal
                                    faiss.normalize_L2(query_embedding)
                                    self.l1_vector_index.add(x=query_embedding)
                                    self.l1_vector_id_to_key[new_id] = key
                                return data

        logger.debug(f"Cache miss: {key}")
        return None

    async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
        """Store a result in every cache tier."""
        if ttl is None:
            ttl = self.default_ttl
        if ttl <= 0:
            return

        key = self._generate_key(tool_name, function_args, tool_class)
        expires_at = time.time() + ttl
        # Write L1
        self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}

        # Write L2
        value = json.dumps(data)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at))
            conn.commit()

        # Write the semantic caches
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                embedding = np.array([embedding_result], dtype='float32')
                # Write the L1 vector
                new_id = self.l1_vector_index.ntotal
                faiss.normalize_L2(embedding)
                self.l1_vector_index.add(x=embedding)
                self.l1_vector_id_to_key[new_id] = key
                # Write the L2 vector
                self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])

        logger.info(f"Cached entry: {key}, TTL: {ttl}s")

    def clear_l1(self):
        """Clear the L1 cache."""
        self.l1_kv_cache.clear()
        self.l1_vector_index.reset()
        self.l1_vector_id_to_key.clear()
        logger.info("L1 (memory + FAISS) cache cleared.")

    def clear_l2(self):
        """Clear the L2 cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM cache")
            conn.commit()
        self.chroma_client.delete_collection(name="semantic_cache")
        self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
        logger.info("L2 (SQLite & ChromaDB) cache cleared.")

    def clear_all(self):
        """Clear every cache tier."""
        self.clear_l1()
        self.clear_l2()
        logger.info("All cache tiers cleared.")


# Global instance
tool_cache = CacheManager()
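To make the automatic invalidation concrete, here is a small self-contained sketch of the key scheme used by `_generate_key` above (the `ToyTool` class and `toy_tool` name are hypothetical):

```python
import hashlib
import inspect
import json

class ToyTool:  # hypothetical tool class
    def run(self):
        return 1

args = {"query": "weather in Paris"}
code_hash = hashlib.md5(inspect.getsource(ToyTool).encode()).hexdigest()
key = f"toy_tool::{json.dumps(args, sort_keys=True)}::{code_hash}"
# Editing ToyTool.run changes inspect.getsource(ToyTool), hence code_hash,
# hence the key -- old entries are simply never looked up again.
print(key)
```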

View File

@@ -0,0 +1,344 @@
import json
import hashlib
import re
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from pathlib import Path
from difflib import SequenceMatcher
from src.common.logger import get_logger

logger = get_logger("cache_manager")


class ToolCache:
    """Tool cache manager that caches tool-call results and supports fuzzy matching."""

    def __init__(
        self,
        cache_dir: str = "data/tool_cache",
        max_age_hours: int = 24,
        similarity_threshold: float = 0.65,
    ):
        """
        Initialize the cache manager.

        Args:
            cache_dir: cache directory path
            max_age_hours: maximum cache lifetime in hours
            similarity_threshold: similarity threshold for fuzzy matching (0-1)
        """
        self.cache_dir = Path(cache_dir)
        self.max_age = timedelta(hours=max_age_hours)
        self.max_age_seconds = max_age_hours * 3600
        self.similarity_threshold = similarity_threshold
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _normalize_query(query: str) -> str:
        """
        Normalize query text for similarity comparison.

        Args:
            query: raw query text

        Returns:
            Normalized query text.
        """
        if not query:
            return ""
        # Pure-Python implementation
        normalized = query.lower()
        normalized = re.sub(r"[^\w\s]", " ", normalized)
        normalized = " ".join(normalized.split())
        return normalized

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """
        Compute the similarity between two texts.

        Args:
            text1: first text
            text2: second text

        Returns:
            Similarity score (0-1).
        """
        if not text1 or not text2:
            return 0.0
        # Pure-Python implementation
        norm_text1 = self._normalize_query(text1)
        norm_text2 = self._normalize_query(text2)
        if norm_text1 == norm_text2:
            return 1.0
        return SequenceMatcher(None, norm_text1, norm_text2).ratio()

    @staticmethod
    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
        """
        Generate a cache key.

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            Cache key string.
        """
        # Serialize the arguments in sorted order so identical arguments yield identical keys
        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)
        # Pure-Python implementation
        cache_string = f"{tool_name}:{sorted_args}"
        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """Return the cache file path for a key."""
        return self.cache_dir / f"{cache_key}.json"

    def _is_cache_expired(self, cached_time: datetime) -> bool:
        """Check whether a cache entry has expired."""
        return datetime.now() - cached_time > self.max_age

    def _find_similar_cache(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Find a similar cache entry.

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            The similar cached result, or None if none exists.
        """
        query = function_args.get("query", "")
        if not query:
            return None
        candidates = []
        cache_data_list = []
        # Scan all cache files and collect candidates
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)
                # Must be the same tool
                if cache_data.get("tool_name") != tool_name:
                    continue
                # Skip expired entries
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    continue
                # All arguments except "query" must match
                cached_args = cache_data.get("function_args", {})
                args_match = True
                for key, value in function_args.items():
                    if key != "query" and cached_args.get(key) != value:
                        args_match = False
                        break
                if not args_match:
                    continue
                # Collect the candidate
                cached_query = cached_args.get("query", "")
                candidates.append((cached_query, len(cache_data_list)))
                cache_data_list.append(cache_data)
            except Exception as e:
                logger.warning(f"Error while inspecting cache file: {cache_file}, error: {e}")
                continue
        if not candidates:
            logger.debug(
                f"No similar cache found: {tool_name}, query: '{query}', similarity threshold: {self.similarity_threshold}"
            )
            return None
        # Pure-Python implementation
        best_match = None
        best_similarity = 0.0
        for cached_query, index in candidates:
            similarity = self._calculate_similarity(query, cached_query)
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_similarity = similarity
                best_match = cache_data_list[index]
        if best_match is not None:
            cached_query = best_match["function_args"].get("query", "")
            logger.info(
                f"Similar cache hit, similarity: {best_similarity:.2f}, original query: '{cached_query}', current query: '{query}'"
            )
            return best_match["result"]
        logger.debug(
            f"No similar cache found: {tool_name}, query: '{query}', similarity threshold: {self.similarity_threshold}"
        )
        return None

    def get(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch a result from the cache, with exact and fuzzy matching.

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            The cached result, or None if absent or expired.
        """
        # Try an exact match first
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)
        if cache_file.exists():
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)
                # Check for expiry
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    logger.debug(f"Cache expired: {cache_key}")
                    cache_file.unlink()  # remove the expired entry
                else:
                    logger.debug(f"Exact cache hit: {tool_name}")
                    return cache_data["result"]
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f"Failed to read cache file: {cache_file}, error: {e}")
                # Delete the corrupted cache file
                if cache_file.exists():
                    cache_file.unlink()
        # Fall back to fuzzy matching when the exact match fails
        return self._find_similar_cache(tool_name, function_args)

    def set(
        self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
    ) -> None:
        """
        Save a result to the cache.

        Args:
            tool_name: tool name
            function_args: function arguments
            result: result to cache
        """
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)
        cache_data = {
            "tool_name": tool_name,
            "function_args": function_args,
            "result": result,
            "timestamp": datetime.now().isoformat(),
        }
        try:
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
            logger.debug(f"Cache saved: {tool_name} -> {cache_key}")
        except Exception as e:
            logger.error(f"Failed to save cache: {cache_file}, error: {e}")

    def clear_expired(self) -> int:
        """
        Remove expired cache entries.

        Returns:
            Number of files deleted.
        """
        removed_count = 0
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"Deleted expired cache: {cache_file}")
            except Exception as e:
                logger.warning(f"Error while cleaning cache file: {cache_file}, error: {e}")
                # Delete the corrupted file
                try:
                    cache_file.unlink()
                    removed_count += 1
                except OSError as unlink_error:
                    logger.warning(f"Failed to delete corrupted cache file: {cache_file}, error: {unlink_error}")
        logger.info(f"Cleanup finished, deleted {removed_count} expired cache files")
        return removed_count

    def clear_all(self) -> int:
        """
        Clear the entire cache.

        Returns:
            Number of files deleted.
        """
        removed_count = 0
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
                removed_count += 1
            except Exception as e:
                logger.warning(f"Failed to delete cache file: {cache_file}, error: {e}")
        logger.info(f"Cache cleared, deleted {removed_count} files")
        return removed_count

    def get_stats(self) -> Dict[str, Any]:
        """
        Collect cache statistics.

        Returns:
            A dict of cache statistics.
        """
        total_files = 0
        expired_files = 0
        total_size = 0
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                total_files += 1
                total_size += cache_file.stat().st_size
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    expired_files += 1
            except (OSError, json.JSONDecodeError, KeyError, ValueError):
                expired_files += 1  # corrupted files count as expired
        return {
            "total_files": total_files,
            "expired_files": expired_files,
            "total_size_bytes": total_size,
            "cache_dir": str(self.cache_dir),
            "max_age_hours": self.max_age.total_seconds() / 3600,
            "similarity_threshold": self.similarity_threshold,
        }


tool_cache = ToolCache()
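For comparison, a small sketch of the fuzzy matching the old `ToolCache` relied on, mirroring `_normalize_query` and `_calculate_similarity` above:

```python
import re
from difflib import SequenceMatcher

def normalize(q: str) -> str:
    # Same normalization as ToolCache._normalize_query:
    # lowercase, strip punctuation, collapse whitespace.
    q = re.sub(r"[^\w\s]", " ", q.lower())
    return " ".join(q.split())

a = normalize("Weather in Paris?")   # -> "weather in paris"
b = normalize("weather  in paris")   # -> "weather in paris"
print(SequenceMatcher(None, a, b).ratio())  # 1.0, since the normalized strings are identical
```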

View File

@@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server
 from src.mood.mood_manager import mood_manager
 from rich.traceback import install
 from src.manager.schedule_manager import schedule_manager
+from src.common.cache_manager import tool_cache
 # from src.api.main import start_api_server
 # Import the new plugin manager and hot-reload manager

View File

@@ -88,7 +88,8 @@ class WebSurfingTool(BaseTool):
             return {"error": "Search query cannot be empty."}
         # Check the cache
-        cached_result = tool_cache.get(self.name, function_args)
+        query = function_args.get("query")
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query)
         if cached_result:
             logger.info(f"Cache hit: {self.name} -> {function_args}")
             return cached_result
@@ -109,7 +110,8 @@ class WebSurfingTool(BaseTool):
         # Save to the cache
         if "error" not in result:
-            tool_cache.set(self.name, function_args, result)
+            query = function_args.get("query")
+            await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)
         return result
@@ -463,7 +465,7 @@ class URLParserTool(BaseTool):
         Extract and summarize URL content. Prefer Exa; fall back to local parsing on failure.
         """
         # Check the cache
-        cached_result = tool_cache.get(self.name, function_args)
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
         if cached_result:
             logger.info(f"Cache hit: {self.name} -> {function_args}")
             return cached_result
@@ -577,7 +579,7 @@ class URLParserTool(BaseTool):
         # Save to the cache
         if "error" not in result:
-            tool_cache.set(self.name, function_args, result)
+            await tool_cache.set(self.name, function_args, self.__class__, result)
         return result