refactor(cache): rework the cache system into a tiered semantic cache

Replaces the old file-based `ToolCache` with a new `CacheManager`, introducing a more sophisticated and efficient tiered semantic caching mechanism.

Features of the new system:
- **Tiered caching**:
  - L1 cache: in-memory dict (KV) + FAISS (vectors), for very fast access.
  - L2 cache: SQLite (KV) + ChromaDB (vectors), for persistent storage.
- **Semantic caching**: queries are vectorized with an embedding model, enabling cache hits based on semantic similarity and significantly improving the hit rate.
- **Automatic invalidation**: cache keys include a hash of the tool's source code, so related entries are invalidated automatically when the tool's code changes, avoiding stale data (see the key-generation sketch after this list).
- **Async support**: the cache's `get` and `set` methods are now asynchronous, matching the project's asynchronous tool-call flow.
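
A minimal, hypothetical sketch of the key-generation idea behind the automatic-invalidation bullet; the helper name `make_cache_key` is illustrative, while the actual logic lives in `CacheManager._generate_key` in the diff below:

```python
import hashlib
import inspect
import json
from typing import Any, Dict


def make_cache_key(tool_name: str, function_args: Dict[str, Any], tool_class: type) -> str:
    """Illustrative helper: binds a cache entry to the tool's current source code."""
    try:
        # Any edit to the tool's source changes this hash, so older cache
        # entries simply stop matching (automatic invalidation).
        code_hash = hashlib.md5(inspect.getsource(tool_class).encode()).hexdigest()
    except (TypeError, OSError):
        code_hash = "unknown"
    sorted_args = json.dumps(function_args, sort_keys=True)
    return f"{tool_name}::{sorted_args}::{code_hash}"
```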

`web_search_tool` has been updated to use the new `CacheManager`, passing `tool_class` and `semantic_query` on cache calls to take full advantage of the new features, as sketched in the usage example below.
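
A hedged usage sketch of how a tool might call the new asynchronous cache; `MyTool` and its `execute` method are hypothetical stand-ins, while the `tool_cache` import and the `get`/`set` signatures follow the code added in this commit:

```python
from src.common.cache_manager import tool_cache  # global CacheManager singleton added in this commit


class MyTool:  # hypothetical tool, used only for illustration
    name = "my_tool"

    async def execute(self, function_args: dict) -> dict:
        query = function_args.get("query")
        # Exact-key lookup first, then semantic lookup on the embedded `query`.
        cached = await tool_cache.get(
            self.name, function_args, tool_class=self.__class__, semantic_query=query
        )
        if cached:
            return cached
        result = {"data": f"fresh result for {query}"}  # stand-in for the real tool work
        # Write to the L1/L2 KV stores and, via the query embedding, to both vector stores.
        await tool_cache.set(
            self.name, function_args, self.__class__, result, semantic_query=query
        )
        return result
```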

Co-Authored-By: tt-P607 <68868379+tt-P607@users.noreply.github.com>
minecraft1024a
2025-08-17 22:18:26 +08:00
parent b45a992484
commit 51c0d2a1e8
4 changed files with 546 additions and 318 deletions

View File

@@ -1,344 +1,225 @@
import time
import json
import sqlite3
import chromadb
import hashlib
import inspect
import numpy as np
import faiss
from typing import Any, Dict, Optional

from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.api_ada_configs import ModelTaskConfig
from src.config.config import global_config, model_config

logger = get_logger("cache_manager")


class CacheManager:
    """
    一个支持分层和语义缓存的通用工具缓存管理器。
    采用单例模式,确保在整个应用中只有一个缓存实例。

    L1缓存: 内存字典 (KV) + FAISS (Vector)。
    L2缓存: SQLite (KV) + ChromaDB (Vector)。
    """
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(CacheManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
        """
        初始化缓存管理器
        """
        if not hasattr(self, '_initialized'):
            self.default_ttl = default_ttl

            # L1 缓存 (内存)
            self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
            self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
            self.l1_vector_id_to_key: Dict[int, str] = {}

            # L2 缓存 (持久化)
            self.db_path = db_path
            self._init_sqlite()
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")

            # 嵌入模型
            self.embedding_model = LLMRequest(model_config.model_task_config.embedding)

            self._initialized = True
            logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)")

    def _init_sqlite(self):
        """初始化SQLite数据库和表结构。"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expires_at REAL
                )
            """)
            conn.commit()

    def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
        """生成确定性的缓存键,包含代码哈希以实现自动失效。"""
        try:
            source_code = inspect.getsource(tool_class)
            code_hash = hashlib.md5(source_code.encode()).hexdigest()
        except (TypeError, OSError) as e:
            code_hash = "unknown"
            logger.warning(f"无法获取 {tool_class.__name__} 的源代码,代码哈希将为 'unknown'。错误: {e}")

        try:
            sorted_args = json.dumps(function_args, sort_keys=True)
        except TypeError:
            sorted_args = repr(sorted(function_args.items()))

        return f"{tool_name}::{sorted_args}::{code_hash}"

    async def get(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, semantic_query: Optional[str] = None) -> Optional[Any]:
        """
        从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
        """
        # 步骤 1: L1 精确缓存查询
        key = self._generate_key(tool_name, function_args, tool_class)
        logger.debug(f"生成的缓存键: {key}")
        if semantic_query:
            logger.debug(f"使用的语义查询: '{semantic_query}'")

        if key in self.l1_kv_cache:
            entry = self.l1_kv_cache[key]
            if time.time() < entry["expires_at"]:
                logger.info(f"命中L1键值缓存: {key}")
                return entry["data"]
            else:
                del self.l1_kv_cache[key]

        # 步骤 2: L1/L2 语义和L2精确缓存查询
        query_embedding = None
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                query_embedding = np.array([embedding_result], dtype='float32')

        # 步骤 2a: L1 语义缓存 (FAISS)
        if query_embedding is not None and self.l1_vector_index.ntotal > 0:
            faiss.normalize_L2(query_embedding)
            distances, indices = self.l1_vector_index.search(query_embedding, 1)
            if indices.size > 0 and distances[0][0] > 0.75:  # IP 越大越相似
                hit_index = indices[0][0]
                l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
                if l1_hit_key and l1_hit_key in self.l1_kv_cache:
                    logger.info(f"命中L1语义缓存: {l1_hit_key}")
                    return self.l1_kv_cache[l1_hit_key]["data"]

        # 步骤 2b: L2 精确缓存 (SQLite)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
            row = cursor.fetchone()
            if row:
                value, expires_at = row
                if time.time() < expires_at:
                    logger.info(f"命中L2键值缓存: {key}")
                    data = json.loads(value)
                    # 回填 L1
                    self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                    return data
                else:
                    cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
                    conn.commit()

        # 步骤 2c: L2 语义缓存 (ChromaDB)
        if query_embedding is not None:
            results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
            if results and results['ids'] and results['ids'][0]:
                distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
                logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
                if distance != 'N/A' and distance < 0.75:
                    l2_hit_key = results['ids'][0]
                    logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
                    with sqlite3.connect(self.db_path) as conn:
                        cursor = conn.cursor()
                        cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],))
                        row = cursor.fetchone()
                        if row:
                            value, expires_at = row
                            if time.time() < expires_at:
                                data = json.loads(value)
                                logger.debug(f"L2语义缓存返回的数据: {data}")
                                # 回填 L1
                                self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
                                if query_embedding is not None:
                                    new_id = self.l1_vector_index.ntotal
                                    faiss.normalize_L2(query_embedding)
                                    self.l1_vector_index.add(x=query_embedding)
                                    self.l1_vector_id_to_key[new_id] = key
                                return data

        logger.debug(f"缓存未命中: {key}")
        return None

    async def set(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any, data: Any, ttl: Optional[int] = None, semantic_query: Optional[str] = None):
        """将结果存入所有缓存层。"""
        if ttl is None:
            ttl = self.default_ttl
        if ttl <= 0:
            return

        key = self._generate_key(tool_name, function_args, tool_class)
        expires_at = time.time() + ttl

        # 写入 L1
        self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}

        # 写入 L2
        value = json.dumps(data)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at))
            conn.commit()

        # 写入语义缓存
        if semantic_query and self.embedding_model:
            embedding_result = await self.embedding_model.get_embedding(semantic_query)
            if embedding_result:
                embedding = np.array([embedding_result], dtype='float32')
                # 写入 L1 Vector
                new_id = self.l1_vector_index.ntotal
                faiss.normalize_L2(embedding)
                self.l1_vector_index.add(x=embedding)
                self.l1_vector_id_to_key[new_id] = key
                # 写入 L2 Vector
                self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])

        logger.info(f"已缓存条目: {key}, TTL: {ttl}s")

    def clear_l1(self):
        """清空L1缓存。"""
        self.l1_kv_cache.clear()
        self.l1_vector_index.reset()
        self.l1_vector_id_to_key.clear()
        logger.info("L1 (内存+FAISS) 缓存已清空。")

    def clear_l2(self):
        """清空L2缓存。"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM cache")
            conn.commit()
        self.chroma_client.delete_collection(name="semantic_cache")
        self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
        logger.info("L2 (SQLite & ChromaDB) 缓存已清空。")

    def clear_all(self):
        """清空所有缓存。"""
        self.clear_l1()
        self.clear_l2()
        logger.info("所有缓存层级已清空。")


# 全局实例
tool_cache = CacheManager()

View File

@@ -0,0 +1,344 @@
import json
import hashlib
import re
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from pathlib import Path
from difflib import SequenceMatcher

from src.common.logger import get_logger

logger = get_logger("cache_manager")


class ToolCache:
    """工具缓存管理器,用于缓存工具调用结果,支持近似匹配"""

    def __init__(
        self,
        cache_dir: str = "data/tool_cache",
        max_age_hours: int = 24,
        similarity_threshold: float = 0.65,
    ):
        """
        初始化缓存管理器

        Args:
            cache_dir: 缓存目录路径
            max_age_hours: 缓存最大存活时间(小时)
            similarity_threshold: 近似匹配的相似度阈值 (0-1)
        """
        self.cache_dir = Path(cache_dir)
        self.max_age = timedelta(hours=max_age_hours)
        self.max_age_seconds = max_age_hours * 3600
        self.similarity_threshold = similarity_threshold
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _normalize_query(query: str) -> str:
        """
        标准化查询文本,用于相似度比较

        Args:
            query: 原始查询文本

        Returns:
            标准化后的查询文本
        """
        if not query:
            return ""

        # 纯 Python 实现
        normalized = query.lower()
        normalized = re.sub(r"[^\w\s]", " ", normalized)
        normalized = " ".join(normalized.split())
        return normalized

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """
        计算两个文本的相似度

        Args:
            text1: 文本1
            text2: 文本2

        Returns:
            相似度分数 (0-1)
        """
        if not text1 or not text2:
            return 0.0

        # 纯 Python 实现
        norm_text1 = self._normalize_query(text1)
        norm_text2 = self._normalize_query(text2)
        if norm_text1 == norm_text2:
            return 1.0
        return SequenceMatcher(None, norm_text1, norm_text2).ratio()

    @staticmethod
    def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str:
        """
        生成缓存键

        Args:
            tool_name: 工具名称
            function_args: 函数参数

        Returns:
            缓存键字符串
        """
        # 将参数排序后序列化,确保相同参数产生相同的键
        sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False)
        # 纯 Python 实现
        cache_string = f"{tool_name}:{sorted_args}"
        return hashlib.md5(cache_string.encode("utf-8")).hexdigest()

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """获取缓存文件路径"""
        return self.cache_dir / f"{cache_key}.json"

    def _is_cache_expired(self, cached_time: datetime) -> bool:
        """检查缓存是否过期"""
        return datetime.now() - cached_time > self.max_age

    def _find_similar_cache(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        查找相似的缓存条目

        Args:
            tool_name: 工具名称
            function_args: 函数参数

        Returns:
            相似的缓存结果,如果不存在则返回None
        """
        query = function_args.get("query", "")
        if not query:
            return None

        candidates = []
        cache_data_list = []

        # 遍历所有缓存文件,收集候选项
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # 检查是否是同一个工具
                if cache_data.get("tool_name") != tool_name:
                    continue

                # 检查缓存是否过期
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    continue

                # 检查其他参数是否匹配(除了query)
                cached_args = cache_data.get("function_args", {})
                args_match = True
                for key, value in function_args.items():
                    if key != "query" and cached_args.get(key) != value:
                        args_match = False
                        break
                if not args_match:
                    continue

                # 收集候选项
                cached_query = cached_args.get("query", "")
                candidates.append((cached_query, len(cache_data_list)))
                cache_data_list.append(cache_data)
            except Exception as e:
                logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}")
                continue

        if not candidates:
            logger.debug(
                f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
            )
            return None

        # 纯 Python 实现
        best_match = None
        best_similarity = 0.0
        for cached_query, index in candidates:
            similarity = self._calculate_similarity(query, cached_query)
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_similarity = similarity
                best_match = cache_data_list[index]

        if best_match is not None:
            cached_query = best_match["function_args"].get("query", "")
            logger.info(
                f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'"
            )
            return best_match["result"]

        logger.debug(
            f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}"
        )
        return None

    def get(
        self, tool_name: str, function_args: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        从缓存获取结果,支持精确匹配和近似匹配

        Args:
            tool_name: 工具名称
            function_args: 函数参数

        Returns:
            缓存的结果,如果不存在或已过期则返回None
        """
        # 首先尝试精确匹配
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        if cache_file.exists():
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # 检查缓存是否过期
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    logger.debug(f"缓存已过期: {cache_key}")
                    cache_file.unlink()  # 删除过期缓存
                else:
                    logger.debug(f"精确匹配缓存: {tool_name}")
                    return cache_data["result"]
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}")
                # 删除损坏的缓存文件
                if cache_file.exists():
                    cache_file.unlink()

        # 如果精确匹配失败,尝试近似匹配
        return self._find_similar_cache(tool_name, function_args)

    def set(
        self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any]
    ) -> None:
        """
        将结果保存到缓存

        Args:
            tool_name: 工具名称
            function_args: 函数参数
            result: 缓存结果
        """
        cache_key = self._generate_cache_key(tool_name, function_args)
        cache_file = self._get_cache_file_path(cache_key)

        cache_data = {
            "tool_name": tool_name,
            "function_args": function_args,
            "result": result,
            "timestamp": datetime.now().isoformat(),
        }

        try:
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(cache_data, f, ensure_ascii=False, indent=2)
            logger.debug(f"缓存已保存: {tool_name} -> {cache_key}")
        except Exception as e:
            logger.error(f"保存缓存失败: {cache_file}, 错误: {e}")

    def clear_expired(self) -> int:
        """
        清理过期缓存

        Returns:
            删除的文件数量
        """
        removed_count = 0
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"删除过期缓存: {cache_file}")
            except Exception as e:
                logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}")
                # 删除损坏的文件
                try:
                    cache_file.unlink()
                    removed_count += 1
                except (OSError, json.JSONDecodeError, KeyError, ValueError):
                    logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}")

        logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件")
        return removed_count

    def clear_all(self) -> int:
        """
        清空所有缓存

        Returns:
            删除的文件数量
        """
        removed_count = 0
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
                removed_count += 1
            except Exception as e:
                logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}")

        logger.info(f"清空缓存完成,删除了 {removed_count} 个文件")
        return removed_count

    def get_stats(self) -> Dict[str, Any]:
        """
        获取缓存统计信息

        Returns:
            缓存统计信息字典
        """
        total_files = 0
        expired_files = 0
        total_size = 0

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                total_files += 1
                total_size += cache_file.stat().st_size
                with open(cache_file, "r", encoding="utf-8") as f:
                    cache_data = json.load(f)
                cached_time = datetime.fromisoformat(cache_data["timestamp"])
                if self._is_cache_expired(cached_time):
                    expired_files += 1
            except (OSError, json.JSONDecodeError, KeyError, ValueError):
                expired_files += 1  # 损坏的文件也算作过期

        return {
            "total_files": total_files,
            "expired_files": expired_files,
            "total_size_bytes": total_size,
            "cache_dir": str(self.cache_dir),
            "max_age_hours": self.max_age.total_seconds() / 3600,
            "similarity_threshold": self.similarity_threshold,
        }


tool_cache = ToolCache()

View File

@@ -17,6 +17,7 @@ from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager
from rich.traceback import install
from src.manager.schedule_manager import schedule_manager
from src.common.cache_manager import tool_cache
# from src.api.main import start_api_server

# 导入新的插件管理器和热重载管理器

View File

@@ -88,7 +88,8 @@ class WebSurfingTool(BaseTool):
return {"error": "搜索查询不能为空。"} return {"error": "搜索查询不能为空。"}
# 检查缓存 # 检查缓存
cached_result = tool_cache.get(self.name, function_args) query = function_args.get("query")
cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__, semantic_query=query)
if cached_result: if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}") logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result return cached_result
@@ -109,7 +110,8 @@ class WebSurfingTool(BaseTool):
        # 保存到缓存
        if "error" not in result:
            query = function_args.get("query")
            await tool_cache.set(self.name, function_args, self.__class__, result, semantic_query=query)

        return result
@@ -463,7 +465,7 @@ class URLParserTool(BaseTool):
        执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。
        """
        # 检查缓存
        cached_result = await tool_cache.get(self.name, function_args, tool_class=self.__class__)
        if cached_result:
            logger.info(f"缓存命中: {self.name} -> {function_args}")
            return cached_result
@@ -577,7 +579,7 @@ class URLParserTool(BaseTool):
        # 保存到缓存
        if "error" not in result:
            await tool_cache.set(self.name, function_args, self.__class__, result)

        return result