feat(cache): 增强嵌入向量处理的健壮性和验证
- 新增 `_validate_embedding` 方法,用于在存入缓存前对嵌入向量进行严格的格式检查、维度验证和数值有效性校验。
- 在缓存查询 (`get`) 和写入 (`set`) 流程中,集成此验证逻辑,确保只有合规的向量才能被处理和存储。
- 增加了在L1和L2向量索引操作中的异常捕获,防止因向量处理失败导致缓存功能中断,提升了系统的整体稳定性。
This commit is contained in:
committed by
Windpicker-owo
parent
48ed62deae
commit
6568ea49da
@@ -90,6 +90,43 @@ class CacheManager:
|
|||||||
logger.error(f"验证嵌入向量时发生错误: {e}")
|
logger.error(f"验证嵌入向量时发生错误: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _validate_embedding(self, embedding_result: Any) -> Optional[np.ndarray]:
    """Validate an embedding and normalize it to a 1-D float32 array.

    Args:
        embedding_result: Raw embedding. Only ``list``, ``tuple`` and
            ``np.ndarray`` inputs are accepted; anything else is rejected.
            Multi-dimensional input is flattened before the length check.

    Returns:
        A 1-D ``float32`` array whose length equals
        ``global_config.lpmm_knowledge.embedding_dimension``, or ``None``
        when the input is ``None``, has an unsupported type or dtype, has
        the wrong length, or contains NaN/Inf (including values that
        overflow to Inf once cast to float32).

    This method never raises: any unexpected error is logged and reported
    as ``None``.
    """
    try:
        if embedding_result is None:
            return None

        if not isinstance(embedding_result, (list, tuple, np.ndarray)):
            logger.warning(f"嵌入结果格式不支持: {type(embedding_result)}")
            return None

        # reshape(-1) flattens N-d input and also turns a 0-d array into a
        # length-1 vector, so the shape[0] lookup below cannot fail.
        embedding_array = np.array(embedding_result).reshape(-1)

        # Reject strings, objects (e.g. lists containing None) and complex
        # values explicitly instead of relying on a later TypeError.
        if embedding_array.dtype.kind not in "biuf":
            logger.warning(f"嵌入向量包含非数值类型: {embedding_array.dtype}")
            return None

        expected_dim = global_config.lpmm_knowledge.embedding_dimension
        if embedding_array.shape[0] != expected_dim:
            logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
            return None

        # Cast BEFORE the finiteness check: a finite float64 outside the
        # float32 range becomes Inf during the cast and must be rejected.
        with np.errstate(over="ignore", invalid="ignore"):
            normalized = embedding_array.astype("float32")

        if not np.isfinite(normalized).all():
            logger.warning("嵌入向量包含无效的数值 (NaN 或 Inf)")
            return None

        return normalized

    except Exception as e:
        # Best-effort boundary: validation failures must never break caching.
        logger.error(f"验证嵌入向量时发生错误: {e}")
        return None
|
||||||
|
|
||||||
def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
|
def _generate_key(self, tool_name: str, function_args: Dict[str, Any], tool_class: Any) -> str:
|
||||||
"""生成确定性的缓存键,包含代码哈希以实现自动失效。"""
|
"""生成确定性的缓存键,包含代码哈希以实现自动失效。"""
|
||||||
try:
|
try:
|
||||||
@@ -179,30 +216,23 @@ class CacheManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 步骤 2c: L2 语义缓存 (ChromaDB)
|
# 步骤 2c: L2 语义缓存 (ChromaDB)
|
||||||
if query_embedding is not None and self.chroma_collection:
|
if query_embedding is not None:
|
||||||
try:
|
|
||||||
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
|
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
|
||||||
if results and results['ids'] and results['ids'][0]:
|
if results and results['ids'] and results['ids'][0]:
|
||||||
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
|
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
|
||||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||||
if distance != 'N/A' and distance < 0.75:
|
if distance != 'N/A' and distance < 0.75:
|
||||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
l2_hit_key = results['ids'][0]
|
||||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
# 从数据库获取缓存数据
|
cursor = conn.cursor()
|
||||||
semantic_cache_results = await db_query(
|
cursor.execute("SELECT value, expires_at FROM cache WHERE key = ?", (l2_hit_key if isinstance(l2_hit_key, str) else l2_hit_key[0],))
|
||||||
model_class=CacheEntries,
|
row = cursor.fetchone()
|
||||||
query_type="get",
|
if row:
|
||||||
filters={"cache_key": l2_hit_key},
|
value, expires_at = row
|
||||||
single_result=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if semantic_cache_results:
|
|
||||||
expires_at = semantic_cache_results["expires_at"]
|
|
||||||
if time.time() < expires_at:
|
if time.time() < expires_at:
|
||||||
data = json.loads(semantic_cache_results["cache_value"])
|
data = json.loads(value)
|
||||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||||
|
|
||||||
# 回填 L1
|
# 回填 L1
|
||||||
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
self.l1_kv_cache[key] = {"data": data, "expires_at": expires_at}
|
||||||
if query_embedding is not None:
|
if query_embedding is not None:
|
||||||
@@ -214,8 +244,6 @@ class CacheManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"回填L1向量索引时发生错误: {e}")
|
logger.error(f"回填L1向量索引时发生错误: {e}")
|
||||||
return data
|
return data
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"ChromaDB查询失败: {e}")
|
|
||||||
|
|
||||||
logger.debug(f"缓存未命中: {key}")
|
logger.debug(f"缓存未命中: {key}")
|
||||||
return None
|
return None
|
||||||
@@ -252,12 +280,12 @@ class CacheManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 写入语义缓存
|
# 写入语义缓存
|
||||||
if semantic_query and self.embedding_model and self.chroma_collection:
|
if semantic_query and self.embedding_model:
|
||||||
try:
|
|
||||||
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
||||||
if embedding_result:
|
if embedding_result:
|
||||||
validated_embedding = self._validate_embedding(embedding_result)
|
validated_embedding = self._validate_embedding(embedding_result)
|
||||||
if validated_embedding is not None:
|
if validated_embedding is not None:
|
||||||
|
try:
|
||||||
embedding = np.array([validated_embedding], dtype='float32')
|
embedding = np.array([validated_embedding], dtype='float32')
|
||||||
# 写入 L1 Vector
|
# 写入 L1 Vector
|
||||||
new_id = self.l1_vector_index.ntotal
|
new_id = self.l1_vector_index.ntotal
|
||||||
@@ -267,7 +295,8 @@ class CacheManager:
|
|||||||
# 写入 L2 Vector
|
# 写入 L2 Vector
|
||||||
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
|
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"语义缓存写入失败: {e}")
|
logger.error(f"写入语义缓存时发生错误: {e}")
|
||||||
|
# 继续执行,不影响主要缓存功能
|
||||||
|
|
||||||
logger.info(f"已缓存条目: {key}, TTL: {ttl}s")
|
logger.info(f"已缓存条目: {key}, TTL: {ttl}s")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user