Refactor cache L2 storage to use SQLAlchemy DB
Replaces the L2 cache layer's SQLite implementation with an async SQLAlchemy-based database model (CacheEntries). Updates cache_manager.py to use db_query and db_save for cache operations, adds semantic cache handling with ChromaDB, and introduces async cache clearing and expiration cleaning methods. Adds the CacheEntries model and integrates it into the database API.
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
import time
|
||||
import json
|
||||
import sqlite3
|
||||
import chromadb
|
||||
import hashlib
|
||||
import inspect
|
||||
import numpy as np
|
||||
import faiss
|
||||
import chromadb
|
||||
from typing import Any, Dict, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
@@ -18,7 +19,7 @@ class CacheManager:
|
||||
一个支持分层和语义缓存的通用工具缓存管理器。
|
||||
采用单例模式,确保在整个应用中只有一个缓存实例。
|
||||
L1缓存: 内存字典 (KV) + FAISS (Vector)。
|
||||
L2缓存: SQLite (KV) + ChromaDB (Vector)。
|
||||
L2缓存: 数据库 (KV) + ChromaDB (Vector)。
|
||||
"""
|
||||
_instance = None
|
||||
|
||||
@@ -27,7 +28,7 @@ class CacheManager:
|
||||
cls._instance = super(CacheManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"):
|
||||
def __init__(self, default_ttl: int = 3600, chroma_path: str = "data/chroma_db"):
|
||||
"""
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
@@ -40,30 +41,54 @@ class CacheManager:
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
# L2 缓存 (持久化)
|
||||
self.db_path = db_path
|
||||
self._init_sqlite()
|
||||
# 语义缓存 (ChromaDB)
|
||||
|
||||
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
|
||||
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
|
||||
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)")
|
||||
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
|
||||
|
||||
def _init_sqlite(self):
    """Create the SQLite cache database and its ``cache`` table if missing.

    Opens the database file at ``self.db_path`` and issues a
    ``CREATE TABLE IF NOT EXISTS`` statement for a table with three columns:
    ``key`` (TEXT, primary key), ``value`` (TEXT) and ``expires_at`` (REAL).
    Calling it repeatedly is harmless because of the ``IF NOT EXISTS`` clause.
    """
    from contextlib import closing

    # A sqlite3 connection used as a context manager only commits or rolls
    # back the transaction; it does not close the connection. ``closing``
    # releases the handle as soon as the schema is in place.
    with closing(sqlite3.connect(self.db_path)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expires_at REAL
                )
            """)
|
||||
def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]:
    """Validate an embedding and normalise it to a 1-D float32 vector.

    Args:
        embedding_result: Raw embedding. A list, tuple or numpy array is
            accepted; multi-dimensional input is flattened before checking.

    Returns:
        The embedding as a 1-D ``float32`` numpy array, or ``None`` when the
        input is ``None``, has an unsupported type, does not contain exactly
        ``global_config.lpmm_knowledge.embedding_dimension`` elements,
        contains NaN/Inf (before or after the float32 cast), or raises while
        being processed.
    """
    if embedding_result is None:
        return None

    try:
        if not isinstance(embedding_result, (list, tuple, np.ndarray)):
            logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
            return None

        embedding_array = np.array(embedding_result)

        # Collapse nested input such as [[...]] into a flat vector.
        if embedding_array.ndim > 1:
            embedding_array = embedding_array.flatten()

        expected_dim = global_config.lpmm_knowledge.embedding_dimension
        if embedding_array.shape[0] != expected_dim:
            logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
            return None

        if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
            logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
            return None

        # A finite float64 value can overflow to Inf when narrowed to
        # float32, so the vector must be re-checked after the cast.
        with np.errstate(over="ignore"):
            embedding_f32 = embedding_array.astype('float32')
        if not np.isfinite(embedding_f32).all():
            logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
            return None

        return embedding_f32

    except Exception as e:
        # Best-effort validation: any failure (e.g. non-numeric content)
        # is reported and treated as "no usable embedding".
        logger.error(f"验证嵌入向量时发生错误: {e}")
        return None
|
||||
|
||||
def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
|
||||
"""生成确定性的缓存键,包含代码哈希以实现自动失效。"""
|
||||
@@ -102,7 +127,9 @@ class CacheManager:
|
||||
if semantic_query and self.embedding_model:
|
||||
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
||||
if embedding_result:
|
||||
query_embedding = np.array([embedding_result], dtype='float32')
|
||||
validated_embedding = self._validate_embedding(embedding_result)
|
||||
if validated_embedding is not None:
|
||||
query_embedding = np.array([validated_embedding], dtype='float32')
|
||||
|
||||
# 步骤 2a: L1 语义缓存 (FAISS)
|
||||
if query_embedding is not None and self.l1_vector_index.ntotal > 0:
|
||||
@@ -115,49 +142,80 @@ class CacheManager:
|
||||
logger.info(f"命中L1语义缓存: {l1_hit_key}")
|
||||
return self.l1_kv_cache[l1_hit_key]["data"]
|
||||
|
||||
# 步骤 2b: L2 精确缓存 (SQLite)
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
value, expires_at = row
|
||||
if time.time() < expires_at:
|
||||
logger.info(f"命中L2键值缓存: {key}")
|
||||
data = json.loads(value)
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
return data
|
||||
else:
|
||||
cursor.execute("DELETE FROM cache WHERE key = ?", (key,))
|
||||
conn.commit()
|
||||
# 步骤 2b: L2 精确缓存 (数据库)
|
||||
cache_results = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": key},
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if cache_results:
|
||||
expires_at = cache_results["expires_at"]
|
||||
if time.time() < expires_at:
|
||||
logger.info(f"命中L2键值缓存: {key}")
|
||||
data = json.loads(cache_results["cache_value"])
|
||||
|
||||
# 更新访问统计
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="update",
|
||||
filters={"cache_key": key},
|
||||
data={
|
||||
"last_accessed": time.time(),
|
||||
"access_count": cache_results["access_count"] + 1
|
||||
}
|
||||
)
|
||||
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
return data
|
||||
else:
|
||||
# 删除过期的缓存条目
|
||||
await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="delete",
|
||||
filters={"cache_key": key}
|
||||
)
|
||||
|
||||
# 步骤 2c: L2 语义缓存 (ChromaDB)
|
||||
if query_embedding is not None:
|
||||
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
|
||||
if results and results['ids'] and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
value, expires_at = row
|
||||
if time.time() < expires_at:
|
||||
data = json.loads(value)
|
||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
if query_embedding is not None:
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(query_embedding)
|
||||
self.l1_vector_index.add(x=query_embedding)
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
return data
|
||||
if query_embedding is not None and self.chroma_collection:
|
||||
try:
|
||||
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
|
||||
if results and results['ids'] and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
semantic_cache_results = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": l2_hit_key},
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if semantic_cache_results:
|
||||
expires_at = semantic_cache_results["expires_at"]
|
||||
if time.time() < expires_at:
|
||||
data = json.loads(semantic_cache_results["cache_value"])
|
||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||
|
||||
# 回填 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
if query_embedding is not None:
|
||||
try:
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(query_embedding)
|
||||
self.l1_vector_index.add(x=query_embedding)
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
except Exception as e:
|
||||
logger.error(f"回填L1向量索引时发生错误: {e}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB查询失败: {e}")
|
||||
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return None
|
||||
@@ -175,25 +233,41 @@ class CacheManager:
|
||||
# 写入 L1
|
||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||
|
||||
# 写入 L2
|
||||
value = json.dumps(data)
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at))
|
||||
conn.commit()
|
||||
# 写入 L2 (数据库)
|
||||
cache_data = {
|
||||
"cache_key": key,
|
||||
"cache_value": json.dumps(data, ensure_ascii=False),
|
||||
"expires_at": expires_at,
|
||||
"tool_name": tool_name,
|
||||
"created_at": time.time(),
|
||||
"last_accessed": time.time(),
|
||||
"access_count": 1
|
||||
}
|
||||
|
||||
await db_save(
|
||||
model_class=CacheEntries,
|
||||
data=cache_data,
|
||||
key_field="cache_key",
|
||||
key_value=key
|
||||
)
|
||||
|
||||
# 写入语义缓存
|
||||
if semantic_query and self.embedding_model:
|
||||
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
||||
if embedding_result:
|
||||
embedding = np.array([embedding_result], dtype='float32')
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(embedding)
|
||||
self.l1_vector_index.add(x=embedding)
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
# 写入 L2 Vector
|
||||
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
|
||||
if semantic_query and self.embedding_model and self.chroma_collection:
|
||||
try:
|
||||
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
||||
if embedding_result:
|
||||
validated_embedding = self._validate_embedding(embedding_result)
|
||||
if validated_embedding is not None:
|
||||
embedding = np.array([validated_embedding], dtype='float32')
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(embedding)
|
||||
self.l1_vector_index.add(x=embedding)
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
# 写入 L2 Vector
|
||||
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
|
||||
except Exception as e:
|
||||
logger.warning(f"语义缓存写入失败: {e}")
|
||||
|
||||
logger.info(f"已缓存条目: {key}, TTL: {ttl}s")
|
||||
|
||||
@@ -204,21 +278,53 @@ class CacheManager:
|
||||
self.l1_vector_id_to_key.clear()
|
||||
logger.info("L1 (内存+FAISS) 缓存已清空。")
|
||||
|
||||
async def clear_l2(self):
    """Clear the L2 cache: the database rows and the ChromaDB collection.

    Issues a ``delete`` query against ``CacheEntries`` with an empty filter,
    then drops and recreates the ``semantic_cache`` ChromaDB collection.
    A ChromaDB failure is logged as a warning and not re-raised.
    """
    # The empty filter is intended to match every row.
    # NOTE(review): confirm db_query treats {} as "no filter" for deletes.
    await db_query(
        model_class=CacheEntries,
        query_type="delete",
        filters={},
    )

    if self.chroma_collection:
        try:
            # Dropping and recreating the collection removes all vectors.
            self.chroma_client.delete_collection(name="semantic_cache")
            self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
        except Exception as e:
            logger.warning(f"清空ChromaDB失败: {e}")

    logger.info("L2 (数据库 & ChromaDB) 缓存已清空。")
|
||||
|
||||
async def clear_all(self):
    """Clear every cache tier: L1 first, then L2.

    ``clear_l2`` is a coroutine, so it must be awaited; calling it without
    ``await`` would create a coroutine that never runs.
    """
    self.clear_l1()
    await self.clear_l2()
    logger.info("所有缓存层级已清空。")
|
||||
|
||||
async def clean_expired(self):
    """Remove expired entries from the L1 key-value cache and the L2 database.

    An L1 entry is expired when its ``expires_at`` is at or before the current
    time. The L2 cleanup is delegated to a ``delete`` query filtered on
    ``expires_at`` being less than the current time. Only the number of
    removed L1 entries is logged.
    """
    now = time.time()

    # Collect the keys first so the dict is not mutated while being iterated.
    stale_keys = [
        cache_key
        for cache_key, item in self.l1_kv_cache.items()
        if item["expires_at"] <= now
    ]
    for cache_key in stale_keys:
        self.l1_kv_cache.pop(cache_key)

    # Purge expired rows from the database-backed L2 cache.
    await db_query(
        model_class=CacheEntries,
        query_type="delete",
        filters={"expires_at": {"$lt": now}},
    )

    if len(stale_keys) > 0:
        logger.info(f"清理了 {len(stale_keys)} 个过期的L1缓存条目")
|
||||
|
||||
# Global instance
|
||||
tool_cache = CacheManager()
|
||||
@@ -15,7 +15,8 @@ from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus,
|
||||
CacheEntries
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
@@ -38,6 +39,7 @@ MODEL_MAPPING = {
|
||||
'GraphEdges': GraphEdges,
|
||||
'Schedule': Schedule,
|
||||
'MaiZoneScheduleStatus': MaiZoneScheduleStatus,
|
||||
'CacheEntries': CacheEntries,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
@@ -476,6 +477,40 @@ class AntiInjectionStats(Base):
|
||||
)
|
||||
|
||||
|
||||
class CacheEntries(Base):
    """ORM model for persisted tool-cache entries (table ``cache_entries``)."""
    __tablename__ = 'cache_entries'

    id = Column(Integer, primary_key=True, autoincrement=True)

    # Cache key; built from the tool name, its arguments and a code hash.
    cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)

    # Cached payload, stored as JSON text.
    cache_value = Column(Text, nullable=False)

    # Expiry time as a timestamp.
    expires_at = Column(Float, nullable=False, index=True)

    # Name of the tool that produced the entry.
    tool_name = Column(get_string_field(100), nullable=False, index=True)

    # Creation timestamp; defaults to the time of insertion.
    created_at = Column(Float, nullable=False, default=lambda: time.time())

    # Last-access timestamp; defaults to the time of insertion.
    last_accessed = Column(Float, nullable=False, default=lambda: time.time())

    # Number of times the entry has been accessed.
    access_count = Column(Integer, nullable=False, default=0)

    # cache_key, expires_at and tool_name are already indexed through
    # ``index=True`` on their columns; repeating them here would create a
    # second, redundant index per column. Only created_at needs an explicit
    # index.
    __table_args__ = (
        Index('idx_cache_entries_created_at', 'created_at'),
    )
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
Reference in New Issue
Block a user