Refactor cache L2 storage to use SQLAlchemy DB

Replaces the L2 cache layer's SQLite implementation with an async SQLAlchemy-based database model (CacheEntries). Updates cache_manager.py to use db_query and db_save for cache operations, adds semantic cache handling with ChromaDB, and introduces async cache clearing and expiration cleaning methods. Adds the CacheEntries model and integrates it into the database API.
This commit is contained in:
雅诺狐
2025-08-18 18:30:17 +08:00
parent 7856c6a8e9
commit bcbcabb0d8
3 changed files with 233 additions and 90 deletions

View File

@@ -1,15 +1,16 @@
import time import time
import json import json
import sqlite3
import chromadb
import hashlib import hashlib
import inspect import inspect
import numpy as np import numpy as np
import faiss import faiss
import chromadb
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.database.sqlalchemy_models import CacheEntries
from src.common.database.sqlalchemy_database_api import db_query, db_save
logger = get_logger("cache_manager") logger = get_logger("cache_manager")
@@ -18,7 +19,7 @@ class CacheManager:
一个支持分层和语义缓存的通用工具缓存管理器。 一个支持分层和语义缓存的通用工具缓存管理器。
采用单例模式,确保在整个应用中只有一个缓存实例。 采用单例模式,确保在整个应用中只有一个缓存实例。
L1缓存: 内存字典 (KV) + FAISS (Vector)。 L1缓存: 内存字典 (KV) + FAISS (Vector)。
L2缓存: SQLite (KV) + ChromaDB (Vector)。 L2缓存: 数据库 (KV) + ChromaDB (Vector)。
""" """
_instance = None _instance = None
@@ -27,7 +28,7 @@ class CacheManager:
cls._instance = super(CacheManager, cls).__new__(cls) cls._instance = super(CacheManager, cls).__new__(cls)
return cls._instance return cls._instance
def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"): def __init__(self, default_ttl: int = 3600, chroma_path: str = "data/chroma_db"):
""" """
初始化缓存管理器。 初始化缓存管理器。
""" """
@@ -40,30 +41,54 @@ class CacheManager:
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {} self.l1_vector_id_to_key: Dict[int, str] = {}
# L2 缓存 (持久化) # 语义缓存 (ChromaDB)
self.db_path = db_path
self._init_sqlite()
self.chroma_client = chromadb.PersistentClient(path=chroma_path) self.chroma_client = chromadb.PersistentClient(path=chroma_path)
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
# 嵌入模型 # 嵌入模型
self.embedding_model = LLMRequest(model_config.model_task_config.embedding) self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
self._initialized = True self._initialized = True
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)") logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
def _init_sqlite(self): def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]:
"""初始化SQLite数据库和表结构。""" """
with sqlite3.connect(self.db_path) as conn: 验证和标准化嵌入向量格式
cursor = conn.cursor() """
cursor.execute(""" try:
CREATE TABLE IF NOT EXISTS cache ( if embedding_result is None:
key TEXT PRIMARY KEY, return None
value TEXT,
expires_at REAL # 确保embedding_result是一维数组或列表
) if isinstance(embedding_result, (list, tuple, np.ndarray)):
""") # 转换为numpy数组进行处理
conn.commit() embedding_array = np.array(embedding_result)
# 如果是多维数组,展平它
if embedding_array.ndim > 1:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
expected_dim = global_config.lpmm_knowledge.embedding_dimension
if embedding_array.shape[0] != expected_dim:
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
return None
# 检查是否包含有效的数值
if np.isnan(embedding_array).any() or np.isinf(embedding_array).any():
logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
return None
return embedding_array.astype('float32')
else:
logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
return None
except Exception as e:
logger.error(f"验证嵌入向量时发生错误: {e}")
return None
def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str: def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
"""生成确定性的缓存键,包含代码哈希以实现自动失效。""" """生成确定性的缓存键,包含代码哈希以实现自动失效。"""
@@ -102,7 +127,9 @@ class CacheManager:
if semantic_query and self.embedding_model: if semantic_query and self.embedding_model:
embedding_result = await self.embedding_model.get_embedding(semantic_query) embedding_result = await self.embedding_model.get_embedding(semantic_query)
if embedding_result: if embedding_result:
query_embedding = np.array([embedding_result], dtype='float32') validated_embedding = self._validate_embedding(embedding_result)
if validated_embedding is not None:
query_embedding = np.array([validated_embedding], dtype='float32')
# 步骤 2a: L1 语义缓存 (FAISS) # 步骤 2a: L1 语义缓存 (FAISS)
if query_embedding is not None and self.l1_vector_index.ntotal > 0: if query_embedding is not None and self.l1_vector_index.ntotal > 0:
@@ -115,49 +142,80 @@ class CacheManager:
logger.info(f"命中L1语义缓存: {l1_hit_key}") logger.info(f"命中L1语义缓存: {l1_hit_key}")
return self.l1_kv_cache[l1_hit_key]["data"] return self.l1_kv_cache[l1_hit_key]["data"]
# 步骤 2b: L2 精确缓存 (SQLite) # 步骤 2b: L2 精确缓存 (数据库)
with sqlite3.connect(self.db_path) as conn: cache_results = await db_query(
cursor = conn.cursor() model_class=CacheEntries,
cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,)) query_type="get",
row = cursor.fetchone() filters={"cache_key": key},
if row: single_result=True
value, expires_at = row )
if cache_results:
expires_at = cache_results["expires_at"]
if time.time() < expires_at: if time.time() < expires_at:
logger.info(f"命中L2键值缓存: {key}") logger.info(f"命中L2键值缓存: {key}")
data = json.loads(value) data = json.loads(cache_results["cache_value"])
# 更新访问统计
await db_query(
model_class=CacheEntries,
query_type="update",
filters={"cache_key": key},
data={
"last_accessed": time.time(),
"access_count": cache_results["access_count"] + 1
}
)
# 回填 L1 # 回填 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
return data return data
else: else:
cursor.execute("DELETE FROM cache WHERE key = ?", (key,)) # 删除过期的缓存条目
conn.commit() await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"cache_key": key}
)
# 步骤 2c: L2 语义缓存 (ChromaDB) # 步骤 2c: L2 语义缓存 (ChromaDB)
if query_embedding is not None: if query_embedding is not None and self.chroma_collection:
try:
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
if results and results['ids'] and results['ids'][0]: if results and results['ids'] and results['ids'][0]:
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
if distance != 'N/A' and distance < 0.75: if distance != 'N/A' and distance < 0.75:
l2_hit_key = results['ids'][0] l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor() # 从数据库获取缓存数据
cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],)) semantic_cache_results = await db_query(
row = cursor.fetchone() model_class=CacheEntries,
if row: query_type="get",
value, expires_at = row filters={"cache_key": l2_hit_key},
single_result=True
)
if semantic_cache_results:
expires_at = semantic_cache_results["expires_at"]
if time.time() < expires_at: if time.time() < expires_at:
data = json.loads(value) data = json.loads(semantic_cache_results["cache_value"])
logger.debug(f"L2语义缓存返回的数据: {data}") logger.debug(f"L2语义缓存返回的数据: {data}")
# 回填 L1 # 回填 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
if query_embedding is not None: if query_embedding is not None:
try:
new_id = self.l1_vector_index.ntotal new_id = self.l1_vector_index.ntotal
faiss.normalize_L2(query_embedding) faiss.normalize_L2(query_embedding)
self.l1_vector_index.add(x=query_embedding) self.l1_vector_index.add(x=query_embedding)
self.l1_vector_id_to_key[new_id] = key self.l1_vector_id_to_key[new_id] = key
except Exception as e:
logger.error(f"回填L1向量索引时发生错误: {e}")
return data return data
except Exception as e:
logger.warning(f"ChromaDB查询失败: {e}")
logger.debug(f"缓存未命中: {key}") logger.debug(f"缓存未命中: {key}")
return None return None
@@ -175,18 +233,32 @@ class CacheManager:
# 写入 L1 # 写入 L1
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
# 写入 L2 # 写入 L2 (数据库)
value = json.dumps(data) cache_data = {
with sqlite3.connect(self.db_path) as conn: "cache_key": key,
cursor = conn.cursor() "cache_value": json.dumps(data, ensure_ascii=False),
cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at)) "expires_at": expires_at,
conn.commit() "tool_name": tool_name,
"created_at": time.time(),
"last_accessed": time.time(),
"access_count": 1
}
await db_save(
model_class=CacheEntries,
data=cache_data,
key_field="cache_key",
key_value=key
)
# 写入语义缓存 # 写入语义缓存
if semantic_query and self.embedding_model: if semantic_query and self.embedding_model and self.chroma_collection:
try:
embedding_result = await self.embedding_model.get_embedding(semantic_query) embedding_result = await self.embedding_model.get_embedding(semantic_query)
if embedding_result: if embedding_result:
embedding = np.array([embedding_result], dtype='float32') validated_embedding = self._validate_embedding(embedding_result)
if validated_embedding is not None:
embedding = np.array([validated_embedding], dtype='float32')
# 写入 L1 Vector # 写入 L1 Vector
new_id = self.l1_vector_index.ntotal new_id = self.l1_vector_index.ntotal
faiss.normalize_L2(embedding) faiss.normalize_L2(embedding)
@@ -194,6 +266,8 @@ class CacheManager:
self.l1_vector_id_to_key[new_id] = key self.l1_vector_id_to_key[new_id] = key
# 写入 L2 Vector # 写入 L2 Vector
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
except Exception as e:
logger.warning(f"语义缓存写入失败: {e}")
logger.info(f"已缓存条目: {key}, TTL: {ttl}s") logger.info(f"已缓存条目: {key}, TTL: {ttl}s")
@@ -204,21 +278,53 @@ class CacheManager:
self.l1_vector_id_to_key.clear() self.l1_vector_id_to_key.clear()
logger.info("L1 (内存+FAISS) 缓存已清空。") logger.info("L1 (内存+FAISS) 缓存已清空。")
def clear_l2(self): async def clear_l2(self):
"""清空L2缓存。""" """清空L2缓存。"""
with sqlite3.connect(self.db_path) as conn: # 清空数据库缓存
cursor = conn.cursor() await db_query(
cursor.execute("DELETE FROM cache") model_class=CacheEntries,
conn.commit() query_type="delete",
filters={} # 删除所有记录
)
# 清空ChromaDB
if self.chroma_collection:
try:
self.chroma_client.delete_collection(name="semantic_cache") self.chroma_client.delete_collection(name="semantic_cache")
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
logger.info("L2 (SQLite & ChromaDB) 缓存已清空。") except Exception as e:
logger.warning(f"清空ChromaDB失败: {e}")
def clear_all(self): logger.info("L2 (数据库 & ChromaDB) 缓存已清空。")
async def clear_all(self):
"""清空所有缓存。""" """清空所有缓存。"""
self.clear_l1() self.clear_l1()
self.clear_l2() await self.clear_l2()
logger.info("所有缓存层级已清空。") logger.info("所有缓存层级已清空。")
async def clean_expired(self):
"""清理过期的缓存条目"""
current_time = time.time()
# 清理L1过期条目
expired_keys = []
for key, entry in self.l1_kv_cache.items():
if current_time >= entry["expires_at"]:
expired_keys.append(key)
for key in expired_keys:
del self.l1_kv_cache[key]
# 清理L2过期条目
await db_query(
model_class=CacheEntries,
query_type="delete",
filters={"expires_at": {"$lt": current_time}}
)
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
# 全局实例 # 全局实例
tool_cache = CacheManager() tool_cache = CacheManager()

View File

@@ -15,7 +15,8 @@ from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import ( from src.common.database.sqlalchemy_models import (
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams, Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory, LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus,
CacheEntries
) )
logger = get_logger("sqlalchemy_database_api") logger = get_logger("sqlalchemy_database_api")
@@ -38,6 +39,7 @@ MODEL_MAPPING = {
'GraphEdges': GraphEdges, 'GraphEdges': GraphEdges,
'Schedule': Schedule, 'Schedule': Schedule,
'MaiZoneScheduleStatus': MaiZoneScheduleStatus, 'MaiZoneScheduleStatus': MaiZoneScheduleStatus,
'CacheEntries': CacheEntries,
} }

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
import os import os
import datetime import datetime
import time
from src.common.logger import get_logger from src.common.logger import get_logger
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
@@ -476,6 +477,40 @@ class AntiInjectionStats(Base):
) )
class CacheEntries(Base):
"""工具缓存条目模型"""
__tablename__ = 'cache_entries'
id = Column(Integer, primary_key=True, autoincrement=True)
cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
"""缓存键,包含工具名、参数和代码哈希"""
cache_value = Column(Text, nullable=False)
"""缓存的数据JSON格式"""
expires_at = Column(Float, nullable=False, index=True)
"""过期时间戳"""
tool_name = Column(get_string_field(100), nullable=False, index=True)
"""工具名称"""
created_at = Column(Float, nullable=False, default=lambda: time.time())
"""创建时间戳"""
last_accessed = Column(Float, nullable=False, default=lambda: time.time())
"""最后访问时间戳"""
access_count = Column(Integer, nullable=False, default=0)
"""访问次数"""
__table_args__ = (
Index('idx_cache_entries_key', 'cache_key'),
Index('idx_cache_entries_expires_at', 'expires_at'),
Index('idx_cache_entries_tool_name', 'tool_name'),
Index('idx_cache_entries_created_at', 'created_at'),
)
# 数据库引擎和会话管理 # 数据库引擎和会话管理
_engine = None _engine = None
_SessionLocal = None _SessionLocal = None