From 48ed62deae679c7f428e98a1fa768f6b7bf4c383 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Mon, 18 Aug 2025 18:30:17 +0800 Subject: [PATCH] Refactor cache L2 storage to use SQLAlchemy DB Replaces the L2 cache layer's SQLite implementation with an async SQLAlchemy-based database model (CacheEntries). Updates cache_manager.py to use db_query and db_save for cache operations, adds semantic cache handling with ChromaDB, and introduces async cache clearing and expiration cleaning methods. Adds the CacheEntries model and integrates it into the database API. --- src/common/cache_manager.py | 284 ++++++++++++------ .../database/sqlalchemy_database_api.py | 4 +- src/common/database/sqlalchemy_models.py | 35 +++ 3 files changed, 233 insertions(+), 90 deletions(-) diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 28fcd0d87..efa28bb59 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -1,15 +1,16 @@ import time import json -import sqlite3 -import chromadb import hashlib import inspect import numpy as np import faiss +import chromadb from typing import Any, Dict, Optional from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config +from src.common.database.sqlalchemy_models import CacheEntries +from src.common.database.sqlalchemy_database_api import db_query, db_save logger = get_logger("cache_manager") @@ -18,7 +19,7 @@ class CacheManager: 一个支持分层和语义缓存的通用工具缓存管理器。 采用单例模式,确保在整个应用中只有一个缓存实例。 L1缓存: 内存字典 (KV) + FAISS (Vector)。 - L2缓存: SQLite (KV) + ChromaDB (Vector)。 + L2缓存: 数据库 (KV) + ChromaDB (Vector)。 """ _instance = None @@ -27,7 +28,7 @@ class CacheManager: cls._instance = super(CacheManager, cls).__new__(cls) return cls._instance - def __init__(self, default_ttl: int = 3600, db_path: str = "data/cache.db", chroma_path: str = "data/chroma_db"): + def __init__(self, default_ttl: int = 3600, chroma_path: str = "data/chroma_db"): """ 初始化缓存管理器。 """ @@ -40,30 +41,54 @@ class CacheManager: self.l1_vector_index = faiss.IndexFlatIP(embedding_dim) self.l1_vector_id_to_key: Dict[int, str] = {} - # L2 缓存 (持久化) - self.db_path = db_path - self._init_sqlite() + # 语义缓存 (ChromaDB) + self.chroma_client = chromadb.PersistentClient(path=chroma_path) self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") + # 嵌入模型 self.embedding_model = LLMRequest(model_config.model_task_config.embedding) self._initialized = True - logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (SQLite+ChromaDB)") + logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)") - def _init_sqlite(self): - """初始化SQLite数据库和表结构。""" - with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS cache ( - key TEXT PRIMARY KEY, - value TEXT, - expires_at REAL - ) - """) - conn.commit() + def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]: + """ + 验证和标准化嵌入向量格式 + """ + try: + if embedding_result is None: + return None + + # 确保embedding_result是一维数组或列表 + if isinstance(embedding_result, (list, tuple, np.ndarray)): + # 转换为numpy数组进行处理 + embedding_array = np.array(embedding_result) + + # 如果是多维数组,展平它 + if embedding_array.ndim > 1: + embedding_array = embedding_array.flatten() + + # 检查维度是否符合预期 + expected_dim = global_config.lpmm_knowledge.embedding_dimension + if embedding_array.shape[0] != expected_dim: + logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") + return None + + # 检查是否包含有效的数值 + if np.isnan(embedding_array).any() or np.isinf(embedding_array).any(): + logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)") + return None + + return embedding_array.astype('float32') + else: + logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}") + return None + + except Exception as e: + logger.error(f"验证嵌入向量时发生错误: {e}") + return None def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str: """生成确定性的缓存键,包含代码哈希以实现自动失效。""" @@ -102,7 +127,9 @@ class CacheManager: if semantic_query and self.embedding_model: embedding_result = await self.embedding_model.get_embedding(semantic_query) if embedding_result: - query_embedding = np.array([embedding_result], dtype='float32') + validated_embedding = self._validate_embedding(embedding_result) + if validated_embedding is not None: + query_embedding = np.array([validated_embedding], dtype='float32') # 步骤 2a: L1 语义缓存 (FAISS) if query_embedding is not None and self.l1_vector_index.ntotal > 0: @@ -115,49 +142,80 @@ class CacheManager: logger.info(f"命中L1语义缓存: {l1_hit_key}") return self.l1_kv_cache[l1_hit_key]["data"] - # 步骤 2b: L2 精确缓存 (SQLite) - with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (key,)) - row = cursor.fetchone() - if row: - value, expires_at = row - if time.time() < expires_at: - logger.info(f"命中L2键值缓存: {key}") - data = json.loads(value) - # 回填 L1 - self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} - return data - else: - cursor.execute("DELETE FROM cache WHERE key = ?", (key,)) - conn.commit() + # 步骤 2b: L2 精确缓存 (数据库) + cache_results = await db_query( + model_class=CacheEntries, + query_type="get", + filters={"cache_key": key}, + single_result=True + ) + + if cache_results: + expires_at = cache_results["expires_at"] + if time.time() < expires_at: + logger.info(f"命中L2键值缓存: {key}") + data = json.loads(cache_results["cache_value"]) + + # 更新访问统计 + await db_query( + model_class=CacheEntries, + query_type="update", + filters={"cache_key": key}, + data={ + "last_accessed": time.time(), + "access_count": cache_results["access_count"] + 1 + } + ) + + # 回填 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} + return data + else: + # 删除过期的缓存条目 + await db_query( + model_class=CacheEntries, + query_type="delete", + filters={"cache_key": key} + ) # 步骤 2c: L2 语义缓存 (ChromaDB) - if query_embedding is not None: - results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) - if results and results['ids'] and results['ids'][0]: - distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' - logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") - if distance != 'N/A' and distance < 0.75: - l2_hit_key = results['ids'][0] - logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") - with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],)) - row = cursor.fetchone() - if row: - value, expires_at = row - if time.time() < expires_at: - data = json.loads(value) - logger.debug(f"L2语义缓存返回的数据: {data}") - # 回填 L1 - self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} - if query_embedding is not None: - new_id = self.l1_vector_index.ntotal - faiss.normalize_L2(query_embedding) - self.l1_vector_index.add(x=query_embedding) - self.l1_vector_id_to_key[new_id] = key - return data + if query_embedding is not None and self.chroma_collection: + try: + results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) + if results and results['ids'] and results['ids'][0]: + distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' + logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") + if distance != 'N/A' and distance < 0.75: + l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0] + logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") + + # 从数据库获取缓存数据 + semantic_cache_results = await db_query( + model_class=CacheEntries, + query_type="get", + filters={"cache_key": l2_hit_key}, + single_result=True + ) + + if semantic_cache_results: + expires_at = semantic_cache_results["expires_at"] + if time.time() < expires_at: + data = json.loads(semantic_cache_results["cache_value"]) + logger.debug(f"L2语义缓存返回的数据: {data}") + + # 回填 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} + if query_embedding is not None: + try: + new_id = self.l1_vector_index.ntotal + faiss.normalize_L2(query_embedding) + self.l1_vector_index.add(x=query_embedding) + self.l1_vector_id_to_key[new_id] = key + except Exception as e: + logger.error(f"回填L1向量索引时发生错误: {e}") + return data + except Exception as e: + logger.warning(f"ChromaDB查询失败: {e}") logger.debug(f"缓存未命中: {key}") return None @@ -175,25 +233,41 @@ class CacheManager: # 写入 L1 self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} - # 写入 L2 - value = json.dumps(data) - with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute("REPLACE INTO cache (key, value, expires_at) VALUES (?, ?, ?)", (key, value, expires_at)) - conn.commit() + # 写入 L2 (数据库) + cache_data = { + "cache_key": key, + "cache_value": json.dumps(data, ensure_ascii=False), + "expires_at": expires_at, + "tool_name": tool_name, + "created_at": time.time(), + "last_accessed": time.time(), + "access_count": 1 + } + + await db_save( + model_class=CacheEntries, + data=cache_data, + key_field="cache_key", + key_value=key + ) # 写入语义缓存 - if semantic_query and self.embedding_model: - embedding_result = await self.embedding_model.get_embedding(semantic_query) - if embedding_result: - embedding = np.array([embedding_result], dtype='float32') - # 写入 L1 Vector - new_id = self.l1_vector_index.ntotal - faiss.normalize_L2(embedding) - self.l1_vector_index.add(x=embedding) - self.l1_vector_id_to_key[new_id] = key - # 写入 L2 Vector - self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) + if semantic_query and self.embedding_model and self.chroma_collection: + try: + embedding_result = await self.embedding_model.get_embedding(semantic_query) + if embedding_result: + validated_embedding = self._validate_embedding(embedding_result) + if validated_embedding is not None: + embedding = np.array([validated_embedding], dtype='float32') + # 写入 L1 Vector + new_id = self.l1_vector_index.ntotal + faiss.normalize_L2(embedding) + self.l1_vector_index.add(x=embedding) + self.l1_vector_id_to_key[new_id] = key + # 写入 L2 Vector + self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) + except Exception as e: + logger.warning(f"语义缓存写入失败: {e}") logger.info(f"已缓存条目: {key}, TTL: {ttl}s") @@ -204,21 +278,53 @@ class CacheManager: self.l1_vector_id_to_key.clear() logger.info("L1 (内存+FAISS) 缓存已清空。") - def clear_l2(self): + async def clear_l2(self): """清空L2缓存。""" - with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute("DELETE FROM cache") - conn.commit() - self.chroma_client.delete_collection(name="semantic_cache") - self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") - logger.info("L2 (SQLite & ChromaDB) 缓存已清空。") + # 清空数据库缓存 + await db_query( + model_class=CacheEntries, + query_type="delete", + filters={} # 删除所有记录 + ) + + # 清空ChromaDB + if self.chroma_collection: + try: + self.chroma_client.delete_collection(name="semantic_cache") + self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache") + except Exception as e: + logger.warning(f"清空ChromaDB失败: {e}") + + logger.info("L2 (数据库 & ChromaDB) 缓存已清空。") - def clear_all(self): + async def clear_all(self): """清空所有缓存。""" self.clear_l1() - self.clear_l2() + await self.clear_l2() logger.info("所有缓存层级已清空。") + async def clean_expired(self): + """清理过期的缓存条目""" + current_time = time.time() + + # 清理L1过期条目 + expired_keys = [] + for key, entry in self.l1_kv_cache.items(): + if current_time >= entry["expires_at"]: + expired_keys.append(key) + + for key in expired_keys: + del self.l1_kv_cache[key] + + # 清理L2过期条目 + await db_query( + model_class=CacheEntries, + query_type="delete", + filters={"expires_at": {"$lt": current_time}} + ) + + if expired_keys: + logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目") + # 全局实例 tool_cache = CacheManager() \ No newline at end of file diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 4c773f74e..e3c10ece6 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -15,7 +15,8 @@ from src.common.logger import get_logger from src.common.database.sqlalchemy_models import ( Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams, LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory, - Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus + Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus, + CacheEntries ) logger = get_logger("sqlalchemy_database_api") @@ -38,6 +39,7 @@ MODEL_MAPPING = { 'GraphEdges': GraphEdges, 'Schedule': Schedule, 'MaiZoneScheduleStatus': MaiZoneScheduleStatus, + 'CacheEntries': CacheEntries, } diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 3f1f5e080..11ae50133 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import QueuePool import os import datetime +import time from src.common.logger import get_logger import threading from contextlib import contextmanager @@ -476,6 +477,40 @@ class AntiInjectionStats(Base): ) +class CacheEntries(Base): + """工具缓存条目模型""" + __tablename__ = 'cache_entries' + + id = Column(Integer, primary_key=True, autoincrement=True) + cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True) + """缓存键,包含工具名、参数和代码哈希""" + + cache_value = Column(Text, nullable=False) + """缓存的数据,JSON格式""" + + expires_at = Column(Float, nullable=False, index=True) + """过期时间戳""" + + tool_name = Column(get_string_field(100), nullable=False, index=True) + """工具名称""" + + created_at = Column(Float, nullable=False, default=lambda: time.time()) + """创建时间戳""" + + last_accessed = Column(Float, nullable=False, default=lambda: time.time()) + """最后访问时间戳""" + + access_count = Column(Integer, nullable=False, default=0) + """访问次数""" + + __table_args__ = ( + Index('idx_cache_entries_key', 'cache_key'), + Index('idx_cache_entries_expires_at', 'expires_at'), + Index('idx_cache_entries_tool_name', 'tool_name'), + Index('idx_cache_entries_created_at', 'created_at'), + ) + + # 数据库引擎和会话管理 _engine = None _SessionLocal = None