Replace the old file-based `ToolCache` with a brand-new `CacheManager`, introducing a more sophisticated and efficient tiered semantic caching mechanism.

Features of the new system:

- **Tiered caching**:
  - L1 cache: in-memory dict (KV) + FAISS (vectors), for very fast access.
  - L2 cache: SQLite (KV) + ChromaDB (vectors), for persistent storage.
- **Semantic caching**: queries are vectorized with an embedding model so that cache hits can be matched by semantic similarity, significantly improving the hit rate.
- **Automatic invalidation**: cache keys include a hash of the tool's source code, so related entries are invalidated automatically when the tool code changes, preventing stale data.
- **Async support**: the cache's `get` and `set` methods are now asynchronous, matching the project's async tool-call flow.

`web_search_tool` has been updated to use the new `CacheManager`, passing `tool_class` and `semantic_query` on cache calls to take full advantage of the new features; a hypothetical sketch of this interface follows below.

Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com>
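The `CacheManager` implementation itself is not shown in this file, so the following is a minimal sketch, assuming an interface consistent with the description above. `CacheManagerSketch`, `_make_key`, and the exact signatures of `get`/`set` are assumptions for illustration; the FAISS/ChromaDB vector tiers and SQLite persistence are reduced to comments.

```python
import asyncio
import hashlib
import inspect
from typing import Any, Dict, Optional


class CacheManagerSketch:
    """Illustrative shape only; not the shipped CacheManager."""

    def __init__(self) -> None:
        # L1 KV tier as a plain dict; the FAISS/ChromaDB vector tiers and the
        # SQLite L2 tier are omitted from this sketch.
        self._l1_kv: Dict[str, Any] = {}

    @staticmethod
    def _make_key(
        tool_name: str,
        function_args: Dict[str, Any],
        tool_class: Optional[type],
    ) -> str:
        # Embed a hash of the tool's source code so that editing the tool
        # automatically invalidates its old cache entries.
        source_hash = ""
        if tool_class is not None:
            source = inspect.getsource(tool_class)
            source_hash = hashlib.md5(source.encode("utf-8")).hexdigest()
        return f"{tool_name}:{source_hash}:{sorted(function_args.items())!r}"

    async def get(
        self,
        tool_name: str,
        function_args: Dict[str, Any],
        tool_class: Optional[type] = None,
        semantic_query: Optional[str] = None,
    ) -> Optional[Any]:
        # Exact L1 lookup; a full implementation would fall through to the
        # SQLite L2 tier and then to vector search over an embedding of
        # semantic_query.
        return self._l1_kv.get(self._make_key(tool_name, function_args, tool_class))

    async def set(
        self,
        tool_name: str,
        function_args: Dict[str, Any],
        result: Any,
        tool_class: Optional[type] = None,
        semantic_query: Optional[str] = None,
    ) -> None:
        self._l1_kv[self._make_key(tool_name, function_args, tool_class)] = result


async def _demo() -> None:
    cache = CacheManagerSketch()
    await cache.set("web_search", {"query": "python"}, {"results": []},
                    tool_class=CacheManagerSketch, semantic_query="python")
    print(await cache.get("web_search", {"query": "python"},
                          tool_class=CacheManagerSketch, semantic_query="python"))


if __name__ == "__main__":
    asyncio.run(_demo())
```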
import json
import hashlib
import re
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from pathlib import Path
from difflib import SequenceMatcher

from src.common.logger import get_logger

logger = get_logger("cache_manager")


class ToolCache:
    """Tool cache manager: caches tool-call results and supports fuzzy matching."""

    def __init__(
        self,
        cache_dir: str = "data/tool_cache",
        max_age_hours: int = 24,
        similarity_threshold: float = 0.65,
    ):
        """
        Initialize the cache manager.

        Args:
            cache_dir: Path of the cache directory.
            max_age_hours: Maximum cache lifetime, in hours.
            similarity_threshold: Similarity threshold for fuzzy matching (0-1).
        """
        self.cache_dir = Path(cache_dir)
        self.max_age = timedelta(hours=max_age_hours)
        self.max_age_seconds = max_age_hours * 3600
        self.similarity_threshold = similarity_threshold
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _normalize_query(query: str) -> str:
        """
        Normalize query text for similarity comparison.

        Args:
            query: Raw query text.

        Returns:
            The normalized query text.
        """
        if not query:
            return ""

        # Pure-Python implementation: lowercase, replace punctuation with
        # spaces, then collapse runs of whitespace.
        normalized = query.lower()
        normalized = re.sub(r"[^\w\s]", " ", normalized)
        normalized = " ".join(normalized.split())
        return normalized

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """
        Compute the similarity of two texts.

        Args:
            text1: First text.
            text2: Second text.

        Returns:
            A similarity score in [0, 1].
        """
        if not text1 or not text2:
            return 0.0

        # Pure-Python implementation
        norm_text1 = self._normalize_query(text1)
        norm_text2 = self._normalize_query(text2)

        if norm_text1 == norm_text2:
            return 1.0

        return SequenceMatcher(None, norm_text1, norm_text2).ratio()
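
    # Illustrative note (added): _normalize_query("What's the weather?")
    # returns "what s the weather", since punctuation becomes spaces and runs
    # of whitespace collapse; two queries with the same normalized form
    # short-circuit to a similarity of 1.0 without calling SequenceMatcher.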

    @staticmethod
    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
        """
        Generate a cache key.

        Args:
            tool_name: Tool name.
            function_args: Function arguments.

        Returns:
            The cache key string.
        """
        # Serialize the arguments with sorted keys so that identical
        # arguments always produce the same key.
        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)

        # Pure-Python implementation
        cache_string = f"{tool_name}:{sorted_args}"
        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()
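
    # Illustrative note (added): because json.dumps(..., sort_keys=True) is
    # order-insensitive, _generate_cache_key("web_search", {"query": "a", "page": 1})
    # and _generate_cache_key("web_search", {"page": 1, "query": "a"}) yield
    # the same MD5 digest.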

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """Return the cache file path for a given key."""
        return self.cache_dir / f"{cache_key}.json"

    def _is_cache_expired(self, cached_time: datetime) -> bool:
        """Check whether a cache entry has expired."""
        return datetime.now() - cached_time > self.max_age

    def _find_similar_cache(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Look up a similar cache entry.

        Args:
            tool_name: Tool name.
            function_args: Function arguments.

        Returns:
            The similar cached result, or None if none exists.
        """
        query = function_args.get("query", "")
        if not query:
            return None

        candidates = []
        cache_data_list = []

        # Walk every cache file and collect candidate entries.
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Must belong to the same tool.
                if cache_data.get("tool_name") != tool_name:
                    continue

                # Skip expired entries.
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    continue

                # All arguments other than "query" must match exactly.
                cached_args = cache_data.get("function_args", {})
                args_match = True
                for key, value in function_args.items():
                    if key != "query" and cached_args.get(key) != value:
                        args_match = False
                        break

                if not args_match:
                    continue

                # Collect the candidate.
                cached_query = cached_args.get("query", "")
                candidates.append((cached_query, len(cache_data_list)))
                cache_data_list.append(cache_data)

            except Exception as e:
                logger.warning(f"Error while inspecting cache file: {cache_file}, error: {e}")
                continue

        if not candidates:
            logger.debug(
                f"No similar cache entry found: {tool_name}, query: '{query}', "
                f"similarity threshold: {self.similarity_threshold}"
            )
            return None

        # Pure-Python implementation: pick the best-scoring candidate at or
        # above the threshold.
        best_match = None
        best_similarity = 0.0

        for cached_query, index in candidates:
            similarity = self._calculate_similarity(query, cached_query)
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_similarity = similarity
                best_match = cache_data_list[index]

        if best_match is not None:
            cached_query = best_match["function_args"].get("query", "")
            logger.info(
                f"Similar cache hit, similarity: {best_similarity:.2f}, "
                f"cached query: '{cached_query}', current query: '{query}'"
            )
            return best_match["result"]

        logger.debug(
            f"No similar cache entry found: {tool_name}, query: '{query}', "
            f"similarity threshold: {self.similarity_threshold}"
        )
        return None
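
    # Note (added): this fuzzy lookup re-reads and re-scores every JSON file
    # in the cache directory on each miss, i.e. it is O(n) in the number of
    # cached entries; the vector-indexed CacheManager described in the commit
    # message is meant to replace this linear scan.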

    def get(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch a result from the cache, using exact matching first and fuzzy
        matching as a fallback.

        Args:
            tool_name: Tool name.
            function_args: Function arguments.

        Returns:
            The cached result, or None if it is missing or expired.
        """
        # Try an exact match first.
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        if cache_file.exists():
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Skip (and delete) expired entries.
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    logger.debug(f"Cache entry expired: {cache_key}")
                    cache_file.unlink()  # remove the expired entry
                else:
                    logger.debug(f"Exact cache hit: {tool_name}")
                    return cache_data["result"]

            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f"Failed to read cache file: {cache_file}, error: {e}")
                # Delete the corrupted cache file.
                if cache_file.exists():
                    cache_file.unlink()

        # Fall back to fuzzy matching when the exact lookup fails.
        return self._find_similar_cache(tool_name, function_args)

    def set(
        self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
    ) -> None:
        """
        Save a result to the cache.

        Args:
            tool_name: Tool name.
            function_args: Function arguments.
            result: The result to cache.
        """
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        cache_data = {
            "tool_name": tool_name,
            "function_args": function_args,
            "result": result,
            "timestamp": datetime.now().isoformat(),
        }

        try:
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
            logger.debug(f"Cache entry saved: {tool_name} -> {cache_key}")
        except Exception as e:
            logger.error(f"Failed to save cache entry: {cache_file}, error: {e}")

    def clear_expired(self) -> int:
        """
        Remove expired cache entries.

        Returns:
            The number of files removed.
        """
        removed_count = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"Removed expired cache entry: {cache_file}")

            except Exception as e:
                logger.warning(f"Error while cleaning cache file: {cache_file}, error: {e}")
                # Delete the corrupted file as well; unlink() only raises
                # OSError, so catch that and log the actual failure.
                try:
                    cache_file.unlink()
                    removed_count += 1
                except OSError as unlink_error:
                    logger.warning(
                        f"Failed to delete corrupted cache file: {cache_file}, error: {unlink_error}"
                    )

        logger.info(f"Cleanup finished, removed {removed_count} expired cache files")
        return removed_count

    def clear_all(self) -> int:
        """
        Remove all cache entries.

        Returns:
            The number of files removed.
        """
        removed_count = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
                removed_count += 1
            except Exception as e:
                logger.warning(f"Failed to delete cache file: {cache_file}, error: {e}")

        logger.info(f"Cache cleared, removed {removed_count} files")
        return removed_count

    def get_stats(self) -> Dict[str, Any]:
        """
        Collect cache statistics.

        Returns:
            A dictionary of cache statistics.
        """
        total_files = 0
        expired_files = 0
        total_size = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                total_files += 1
                total_size += cache_file.stat().st_size

                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    expired_files += 1

            except (OSError, json.JSONDecodeError, KeyError, ValueError):
                expired_files += 1  # count corrupted files as expired too

        return {
            "total_files": total_files,
            "expired_files": expired_files,
            "total_size_bytes": total_size,
            "cache_dir": str(self.cache_dir),
            "max_age_hours": self.max_age.total_seconds() / 3600,
            "similarity_threshold": self.similarity_threshold,
        }


tool_cache = ToolCache()
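
For comparison, a minimal round-trip through the legacy `ToolCache` defined above, run in the context of this module; the tool name and arguments are invented for illustration:

```python
cache = ToolCache(cache_dir="data/tool_cache_demo", max_age_hours=1)

args = {"query": "python asyncio tutorial"}
cache.set("web_search", args, {"results": []})

# Exact hit: identical arguments serialize to the same MD5 key.
assert cache.get("web_search", args) == {"results": []}

# Fuzzy hit: "Python asyncio tutorial!" normalizes to the same string as the
# cached query, so it matches even though its exact key differs.
print(cache.get("web_search", {"query": "Python asyncio tutorial!"}))
```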