refactor(cache): rework the caching system into a tiered semantic cache
Replaces the original file-based `ToolCache` with a new `CacheManager`, introducing a more sophisticated and efficient tiered semantic caching mechanism.

New system features:

- **Tiered caching**:
  - L1 cache: in-memory dict (KV) + FAISS (vector), for very fast access.
  - L2 cache: SQLite (KV) + ChromaDB (vector), for persistent storage.
- **Semantic caching**: queries are vectorized with an embedding model, so cache hits can be made on semantic similarity, significantly improving the hit rate.
- **Automatic invalidation**: cache keys include a hash of the tool's source code, so related entries are invalidated automatically when the tool code changes, avoiding stale data.
- **Async support**: the cache's `get` and `set` methods are now asynchronous, matching the project's async tool-call flow.

`web_search_tool` has been updated to use the new `CacheManager`, passing `tool_class` and `semantic_query` on cache calls to take full advantage of the new features.

Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com>
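For reference, a minimal sketch of the call pattern a tool is expected to use, mirroring the `WebSurfingTool` changes in the diff below. `MyTool` and `_fetch_result` are hypothetical placeholders; `BaseTool`, `self.name`, and `function_args` follow the existing tool conventions visible in this commit:

from src.common.cache_manager import tool_cache

class MyTool(BaseTool):  # hypothetical tool, for illustration only
    async def execute(self, function_args: dict):
        query = function_args.get("query")
        # Exact and semantic lookups across L1/L2; returns None on a miss.
        cached = await tool_cache.get(
            self.name,
            function_args,
            tool_class=self.__class__,   # source is hashed for auto-invalidation
            semantic_query=query,        # enables the vector-cache path
        )
        if cached is not None:
            return cached
        result = await self._fetch_result(query)  # hypothetical expensive call
        if "error" not in result:
            await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)
        return result

Passing `semantic_query` is optional; `URLParserTool` below omits it and still gets exact-match caching with code-hash invalidation.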
@@ -1,344 +1,225 @@
import time
import json
import sqlite3
import chromadb
import hashlib
import inspect
import numpy as np
import faiss
from typing import Any, Dict, Optional

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.api_ada_configs import ModelTaskConfig
from src.config.config import global_config, model_config

logger = get_logger("cache_manager")


class CacheManager:
    """
    A general-purpose tool cache manager supporting tiered and semantic caching.
    Implemented as a singleton so the whole application shares one cache instance.
    L1 cache: in-memory dict (KV) + FAISS (vector).
    L2 cache: SQLite (KV) + ChromaDB (vector).
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(CacheManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
        """Initialize the cache manager."""
        if not hasattr(self, '_initialized'):
            self.default_ttl = default_ttl

            # L1 cache (in memory)
            self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
            self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
            self.l1_vector_id_to_key: Dict[int, str] = {}

            # L2 cache (persistent)
            self.db_path = db_path
            self._init_sqlite()
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")

            # Embedding model
            self.embedding_model = LLMRequest(model_config.model_task_config.embedding)

            self._initialized = True
            logger.info("Cache manager initialized: L1 (memory+FAISS), L2 (SQLite+ChromaDB)")

    def _init_sqlite(self):
        """Initialize the SQLite database and table schema."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expires_at REAL
                )
            """)
            conn.commit()

    def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
        """Generate a deterministic cache key that embeds a code hash for automatic invalidation."""
        try:
            source_code = inspect.getsource(tool_class)
            code_hash = hashlib.md5(source_code.encode()).hexdigest()
        except (TypeError, OSError) as e:
            code_hash = "unknown"
            logger.warning(f"Could not read source of {tool_class.__name__}; code hash will be 'unknown'. Error: {e}")
        try:
            sorted_args = json.dumps(function_args, sort_keys=True)
        except TypeError:
            sorted_args = repr(sorted(function_args.items()))
        return f"{tool_name}::{sorted_args}::{code_hash}"

    async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]:
        """
        Fetch a result from the cache. Lookup order: L1-KV -> L1-Vector -> L2-KV -> L2-Vector.
        """
        # Step 1: L1 exact lookup
        key = self._generate_key(tool_name, function_args, tool_class)
        logger.debug(f"Generated cache key: {key}")
        if semantic_query:
            logger.debug(f"Semantic query in use: '{semantic_query}'")

        if key in self.l1_kv_cache:
            entry = self.l1_kv_cache[key]
            if time.time() < entry["expires_at"]:
                logger.info(f"L1 KV cache hit: {key}")
                return entry["data"]
            else:
                del self.l1_kv_cache[key]

        # Step 2: L1/L2 semantic and L2 exact lookups
        query_embedding = None
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                query_embedding = np.array([embedding_result], dtype='float32')

        # Step 2a: L1 semantic cache (FAISS)
        if query_embedding is not None and self.l1_vector_index.ntotal > 0:
            faiss.normalize_L2(query_embedding)
            distances, indices = self.l1_vector_index.search(query_embedding, 1)
            if indices.size > 0 and distances[0][0] > 0.75:  # higher inner product = more similar
                hit_index = indices[0][0]
                l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
                if l1_hit_key and l1_hit_key in self.l1_kv_cache:
                    logger.info(f"L1 semantic cache hit: {l1_hit_key}")
                    return self.l1_kv_cache[l1_hit_key]["data"]

        # Step 2b: L2 exact cache (SQLite)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
            row = cursor.fetchone()
            if row:
                value, expires_at = row
                if time.time() < expires_at:
                    logger.info(f"L2 KV cache hit: {key}")
                    data = json.loads(value)
                    # Backfill L1
                    self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                    return data
                else:
                    cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
                    conn.commit()

        # Step 2c: L2 semantic cache (ChromaDB)
        if query_embedding is not None:
            results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
            if results and results['ids'] and results['ids'][0]:
                distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
                logger.debug(f"L2 semantic search best match: id={results['ids'][0]}, distance={distance}")
                if distance != 'N/A' and distance < 0.75:
                    l2_hit_key = results['ids'][0]
                    logger.info(f"L2 semantic cache hit: key='{l2_hit_key}', distance={distance:.4f}")
                    with sqlite3.connect(self.db_path) as conn:
                        cursor = conn.cursor()
                        cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],))
                        row = cursor.fetchone()
                        if row:
                            value, expires_at = row
                            if time.time() < expires_at:
                                data = json.loads(value)
                                logger.debug(f"Data returned by L2 semantic cache: {data}")
                                # Backfill L1
                                self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                                if query_embedding is not None:
                                    new_id = self.l1_vector_index.ntotal
                                    faiss.normalize_L2(query_embedding)
                                    self.l1_vector_index.add(x=query_embedding)
                                    self.l1_vector_id_to_key[new_id] = key
                                return data

        logger.debug(f"Cache miss: {key}")
        return None

    async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
        """Store a result in every cache tier."""
        if ttl is None:
            ttl = self.default_ttl
        if ttl <= 0:
            return

        key = self._generate_key(tool_name, function_args, tool_class)
        expires_at = time.time() + ttl

        # Write to L1
        self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}

        # Write to L2
        value = json.dumps(data)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at))
            conn.commit()

        # Write to the semantic caches
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                embedding = np.array([embedding_result], dtype='float32')
                # Write to the L1 vector index
                new_id = self.l1_vector_index.ntotal
                faiss.normalize_L2(embedding)
                self.l1_vector_index.add(x=embedding)
                self.l1_vector_id_to_key[new_id] = key
                # Write to the L2 vector store
                self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])

        logger.info(f"Cached entry: {key}, TTL: {ttl}s")

    def clear_l1(self):
        """Clear the L1 cache."""
        self.l1_kv_cache.clear()
        self.l1_vector_index.reset()
        self.l1_vector_id_to_key.clear()
        logger.info("L1 (memory+FAISS) cache cleared.")

    def clear_l2(self):
        """Clear the L2 cache."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM cache")
            conn.commit()
        self.chroma_client.delete_collection(name="semantic_cache")
        self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
        logger.info("L2 (SQLite & ChromaDB) cache cleared.")

    def clear_all(self):
        """Clear all caches."""
        self.clear_l1()
        self.clear_l2()
        logger.info("All cache tiers cleared.")


# Global instance
tool_cache = CacheManager()
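To illustrate the auto-invalidation behavior of `_generate_key` above: the key embeds an MD5 of the tool class's source, so any code change orphans old entries. A self-contained sketch; `DummyTool` is invented for the demonstration:

import hashlib
import inspect

class DummyTool:  # stand-in for a real tool class
    def run(self):
        return 42

code_hash = hashlib.md5(inspect.getsource(DummyTool).encode()).hexdigest()
key = f"dummy_tool::{{}}::{code_hash}"  # same format as _generate_key with empty args
# Editing DummyTool changes code_hash, hence the key; entries cached
# under the old hash can no longer be matched and simply expire.
print(key)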
src/common/cache_manager_backup.py (new file, 344 lines)
@@ -0,0 +1,344 @@
import json
import hashlib
import re
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from pathlib import Path
from difflib import SequenceMatcher

from src.common.logger import get_logger

logger = get_logger("cache_manager")


class ToolCache:
    """Tool cache manager that caches tool call results and supports approximate matching."""

    def __init__(
        self,
        cache_dir: str = "data/tool_cache",
        max_age_hours: int = 24,
        similarity_threshold: float = 0.65,
    ):
        """
        Initialize the cache manager.

        Args:
            cache_dir: cache directory path
            max_age_hours: maximum cache lifetime in hours
            similarity_threshold: similarity threshold for approximate matching (0-1)
        """
        self.cache_dir = Path(cache_dir)
        self.max_age = timedelta(hours=max_age_hours)
        self.max_age_seconds = max_age_hours * 3600
        self.similarity_threshold = similarity_threshold
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _normalize_query(query: str) -> str:
        """
        Normalize query text for similarity comparison.

        Args:
            query: raw query text

        Returns:
            normalized query text
        """
        if not query:
            return ""

        # Pure Python implementation
        normalized = query.lower()
        normalized = re.sub(r"[^\w\s]", " ", normalized)
        normalized = " ".join(normalized.split())
        return normalized

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """
        Compute the similarity of two texts.

        Args:
            text1: first text
            text2: second text

        Returns:
            similarity score (0-1)
        """
        if not text1 or not text2:
            return 0.0

        # Pure Python implementation
        norm_text1 = self._normalize_query(text1)
        norm_text2 = self._normalize_query(text2)

        if norm_text1 == norm_text2:
            return 1.0

        return SequenceMatcher(None, norm_text1, norm_text2).ratio()

    @staticmethod
    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
        """
        Generate a cache key.

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            cache key string
        """
        # Serialize the arguments in sorted order so identical arguments yield identical keys
        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)

        # Pure Python implementation
        cache_string = f"{tool_name}:{sorted_args}"
        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """Get the cache file path."""
        return self.cache_dir / f"{cache_key}.json"

    def _is_cache_expired(self, cached_time: datetime) -> bool:
        """Check whether a cache entry has expired."""
        return datetime.now() - cached_time > self.max_age

    def _find_similar_cache(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Find a similar cache entry.

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            the similar cached result, or None if none exists
        """
        query = function_args.get("query", "")
        if not query:
            return None

        candidates = []
        cache_data_list = []

        # Walk all cache files and collect candidates
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Check that it is the same tool
                if cache_data.get("tool_name") != tool_name:
                    continue

                # Check whether the entry has expired
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    continue

                # Check that the other arguments (besides query) match
                cached_args = cache_data.get("function_args", {})
                args_match = True
                for key, value in function_args.items():
                    if key != "query" and cached_args.get(key) != value:
                        args_match = False
                        break

                if not args_match:
                    continue

                # Collect the candidate
                cached_query = cached_args.get("query", "")
                candidates.append((cached_query, len(cache_data_list)))
                cache_data_list.append(cache_data)

            except Exception as e:
                logger.warning(f"Error while checking cache file: {cache_file}, error: {e}")
                continue

        if not candidates:
            logger.debug(
                f"No similar cache found: {tool_name}, query: '{query}', similarity threshold: {self.similarity_threshold}"
            )
            return None

        # Pure Python implementation
        best_match = None
        best_similarity = 0.0

        for cached_query, index in candidates:
            similarity = self._calculate_similarity(query, cached_query)
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_similarity = similarity
                best_match = cache_data_list[index]

        if best_match is not None:
            cached_query = best_match["function_args"].get("query", "")
            logger.info(
                f"Similar cache hit, similarity: {best_similarity:.2f}, original query: '{cached_query}', current query: '{query}'"
            )
            return best_match["result"]

        logger.debug(
            f"No similar cache found: {tool_name}, query: '{query}', similarity threshold: {self.similarity_threshold}"
        )
        return None

    def get(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch a result from the cache, supporting exact and approximate matching.

        Args:
            tool_name: tool name
            function_args: function arguments

        Returns:
            the cached result, or None if it does not exist or has expired
        """
        # Try an exact match first
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        if cache_file.exists():
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Check whether the entry has expired
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    logger.debug(f"Cache expired: {cache_key}")
                    cache_file.unlink()  # delete the expired entry
                else:
                    logger.debug(f"Exact cache match: {tool_name}")
                    return cache_data["result"]

            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f"Failed to read cache file: {cache_file}, error: {e}")
                # Delete the corrupted cache file
                if cache_file.exists():
                    cache_file.unlink()

        # If the exact match fails, try an approximate match
        return self._find_similar_cache(tool_name, function_args)

    def set(
        self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
    ) -> None:
        """
        Save a result to the cache.

        Args:
            tool_name: tool name
            function_args: function arguments
            result: result to cache
        """
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        cache_data = {
            "tool_name": tool_name,
            "function_args": function_args,
            "result": result,
            "timestamp": datetime.now().isoformat(),
        }

        try:
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
            logger.debug(f"Cache saved: {tool_name} -> {cache_key}")
        except Exception as e:
            logger.error(f"Failed to save cache: {cache_file}, error: {e}")

    def clear_expired(self) -> int:
        """
        Remove expired cache entries.

        Returns:
            the number of files removed
        """
        removed_count = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"Removed expired cache: {cache_file}")

            except Exception as e:
                logger.warning(f"Error while cleaning cache file: {cache_file}, error: {e}")
                # Delete the corrupted file
                try:
                    cache_file.unlink()
                    removed_count += 1
                except (OSError, json.JSONDecodeError, KeyError, ValueError):
                    logger.warning(f"Failed to delete corrupted cache file: {cache_file}, error: {e}")

        logger.info(f"Cleanup finished, removed {removed_count} expired cache files")
        return removed_count

    def clear_all(self) -> int:
        """
        Clear the whole cache.

        Returns:
            the number of files removed
        """
        removed_count = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
                removed_count += 1
            except Exception as e:
                logger.warning(f"Failed to delete cache file: {cache_file}, error: {e}")

        logger.info(f"Cache cleared, removed {removed_count} files")
        return removed_count

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            a dict of cache statistics
        """
        total_files = 0
        expired_files = 0
        total_size = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                total_files += 1
                total_size += cache_file.stat().st_size

                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    expired_files += 1

            except (OSError, json.JSONDecodeError, KeyError, ValueError):
                expired_files += 1  # corrupted files count as expired

        return {
            "total_files": total_files,
            "expired_files": expired_files,
            "total_size_bytes": total_size,
            "cache_dir": str(self.cache_dir),
            "max_age_hours": self.max_age.total_seconds() / 3600,
            "similarity_threshold": self.similarity_threshold,
        }


tool_cache = ToolCache()
@@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server
 from src.mood.mood_manager import mood_manager
 from rich.traceback import install
 from src.manager.schedule_manager import schedule_manager
+from src.common.cache_manager import tool_cache
 # from src.api.main import start_api_server

 # Import the new plugin manager and hot-reload manager
@@ -88,7 +88,8 @@ class WebSurfingTool(BaseTool):
             return {"error": "Search query must not be empty."}

         # Check the cache
-        cached_result = tool_cache.get(self.name, function_args)
+        query = function_args.get("query")
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query)
         if cached_result:
             logger.info(f"Cache hit: {self.name} -> {function_args}")
             return cached_result
@@ -109,7 +110,8 @@ class WebSurfingTool(BaseTool):

         # Save to the cache
         if "error" not in result:
-            tool_cache.set(self.name, function_args, result)
+            query = function_args.get("query")
+            await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)

         return result
@@ -463,7 +465,7 @@ class URLParserTool(BaseTool):
         Extract and summarize URL content. Prefer Exa and fall back to local parsing on failure.
         """
         # Check the cache
-        cached_result = tool_cache.get(self.name, function_args)
+        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
         if cached_result:
             logger.info(f"Cache hit: {self.name} -> {function_args}")
             return cached_result
@@ -577,7 +579,7 @@ class URLParserTool(BaseTool):

         # Save to the cache
         if "error" not in result:
-            tool_cache.set(self.name, function_args, result)
+            await tool_cache.set(self.name, function_args, self.__class__, result)

         return result