From 6568ea49da4a459a1af4f3933bb6b85a11332dfb Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Mon, 18 Aug 2025 18:08:14 +0800 Subject: [PATCH] =?UTF-8?q?feat(cache):=20=E5=A2=9E=E5=BC=BA=E5=B5=8C?= =?UTF-8?q?=E5=85=A5=E5=90=91=E9=87=8F=E5=A4=84=E7=90=86=E7=9A=84=E5=81=A5?= =?UTF-8?q?=E5=A3=AE=E6=80=A7=E5=92=8C=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 `_validate_embedding` 方法,用于在存入缓存前对嵌入向量进行严格的格式检查、维度验证和数值有效性校验。 - 在缓存查询 (`get`) 和写入 (`set`) 流程中,集成此验证逻辑,确保只有合规的向量才能被处理和存储。 - 增加了在L1和L2向量索引操作中的异常捕获,防止因向量处理失败导致缓存功能中断,提升了系统的整体稳定性。 --- src/common/cache_manager.py | 119 ++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 45 deletions(-) diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index efa28bb59..78bb92454 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -90,6 +90,43 @@ class CacheManager: logger.error(f"验证嵌入向量时发生错误: {e}") return None + def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]: + """ + 验证和标准化嵌入向量格式 + """ + try: + if embedding_result is None: + return None + + # 确保embedding_result是一维数组或列表 + if isinstance(embedding_result, (list, tuple, np.ndarray)): + # 转换为numpy数组进行处理 + embedding_array = np.array(embedding_result) + + # 如果是多维数组,展平它 + if embedding_array.ndim > 1: + embedding_array = embedding_array.flatten() + + # 检查维度是否符合预期 + expected_dim = global_config.lpmm_knowledge.embedding_dimension + if embedding_array.shape[0] != expected_dim: + logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}") + return None + + # 检查是否包含有效的数值 + if np.isnan(embedding_array).any() or np.isinf(embedding_array).any(): + logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)") + return None + + return embedding_array.astype('float32') + else: + logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}") + return None + + except Exception as e: + logger.error(f"验证嵌入向量时发生错误: {e}") + return None + def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str: """生成确定性的缓存键,包含代码哈希以实现自动失效。""" try: @@ -179,43 +216,34 @@ class CacheManager: ) # 步骤 2c: L2 语义缓存 (ChromaDB) - if query_embedding is not None and self.chroma_collection: - try: - results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) - if results and results['ids'] and results['ids'][0]: - distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' - logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") - if distance != 'N/A' and distance < 0.75: - l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0] - logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") - - # 从数据库获取缓存数据 - semantic_cache_results = await db_query( - model_class=CacheEntries, - query_type="get", - filters={"cache_key": l2_hit_key}, - single_result=True - ) - - if semantic_cache_results: - expires_at = semantic_cache_results["expires_at"] - if time.time() < expires_at: - data = json.loads(semantic_cache_results["cache_value"]) - logger.debug(f"L2语义缓存返回的数据: {data}") - - # 回填 L1 - self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} - if query_embedding is not None: - try: - new_id = self.l1_vector_index.ntotal - faiss.normalize_L2(query_embedding) - self.l1_vector_index.add(x=query_embedding) - self.l1_vector_id_to_key[new_id] = key - except Exception as e: - logger.error(f"回填L1向量索引时发生错误: {e}") - return data - except Exception as e: - logger.warning(f"ChromaDB查询失败: {e}") + if query_embedding is not None: + results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1) + if results and results['ids'] and results['ids'][0]: + distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A' + logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}") + if distance != 'N/A' and distance < 0.75: + l2_hit_key = results['ids'][0] + logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}") + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],)) + row = cursor.fetchone() + if row: + value, expires_at = row + if time.time() < expires_at: + data = json.loads(value) + logger.debug(f"L2语义缓存返回的数据: {data}") + # 回填 L1 + self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at} + if query_embedding is not None: + try: + new_id = self.l1_vector_index.ntotal + faiss.normalize_L2(query_embedding) + self.l1_vector_index.add(x=query_embedding) + self.l1_vector_id_to_key[new_id] = key + except Exception as e: + logger.error(f"回填L1向量索引时发生错误: {e}") + return data logger.debug(f"缓存未命中: {key}") return None @@ -252,12 +280,12 @@ class CacheManager: ) # 写入语义缓存 - if semantic_query and self.embedding_model and self.chroma_collection: - try: - embedding_result = await self.embedding_model.get_embedding(semantic_query) - if embedding_result: - validated_embedding = self._validate_embedding(embedding_result) - if validated_embedding is not None: + if semantic_query and self.embedding_model: + embedding_result = await self.embedding_model.get_embedding(semantic_query) + if embedding_result: + validated_embedding = self._validate_embedding(embedding_result) + if validated_embedding is not None: + try: embedding = np.array([validated_embedding], dtype='float32') # 写入 L1 Vector new_id = self.l1_vector_index.ntotal @@ -266,8 +294,9 @@ class CacheManager: self.l1_vector_id_to_key[new_id] = key # 写入 L2 Vector self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key]) - except Exception as e: - logger.warning(f"语义缓存写入失败: {e}") + except Exception as e: + logger.error(f"写入语义缓存时发生错误: {e}") + # 继续执行,不影响主要缓存功能 logger.info(f"已缓存条目: {key}, TTL: {ttl}s")