feat(similarity): 重构相似度计算函数,优化性能并增加文档注释
This commit is contained in:
@@ -21,6 +21,7 @@ import numpy as np
|
||||
from src.common.logger import get_logger
|
||||
from src.memory_graph.models import MemoryBlock, PerceptualMemory
|
||||
from src.memory_graph.utils.embeddings import EmbeddingGenerator
|
||||
from src.memory_graph.utils.similarity import _compute_similarities_sync
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -406,47 +407,9 @@ class PerceptualMemoryManager:
|
||||
) -> np.ndarray:
|
||||
"""在后台线程中向量化计算相似度,避免阻塞事件循环。"""
|
||||
return await asyncio.to_thread(
|
||||
self._compute_similarities_sync, query_embedding, block_embeddings, block_norms
|
||||
_compute_similarities_sync, query_embedding, block_embeddings, block_norms
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_similarities_sync(
|
||||
query_embedding: np.ndarray,
|
||||
block_embeddings: list[np.ndarray],
|
||||
block_norms: list[float] | None = None,
|
||||
) -> np.ndarray:
|
||||
import numpy as np
|
||||
|
||||
if not block_embeddings:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
|
||||
query = np.asarray(query_embedding, dtype=np.float32)
|
||||
blocks = np.asarray(block_embeddings, dtype=np.float32)
|
||||
|
||||
if blocks.ndim == 1:
|
||||
blocks = blocks.reshape(1, -1)
|
||||
|
||||
query_norm = np.linalg.norm(query)
|
||||
if query_norm == 0.0:
|
||||
return np.zeros(blocks.shape[0], dtype=np.float32)
|
||||
|
||||
if block_norms is None:
|
||||
block_norms_array = np.linalg.norm(blocks, axis=1)
|
||||
else:
|
||||
block_norms_array = np.asarray(block_norms, dtype=np.float32)
|
||||
if block_norms_array.shape[0] != blocks.shape[0]:
|
||||
block_norms_array = np.linalg.norm(blocks, axis=1)
|
||||
|
||||
valid_mask = block_norms_array > 0
|
||||
similarities = np.zeros(blocks.shape[0], dtype=np.float32)
|
||||
|
||||
if valid_mask.any():
|
||||
normalized_blocks = blocks[valid_mask] / block_norms_array[valid_mask][:, None]
|
||||
normalized_query = query / query_norm
|
||||
similarities[valid_mask] = normalized_blocks @ normalized_query
|
||||
|
||||
return np.clip(similarities, 0.0, 1.0)
|
||||
|
||||
async def recall_blocks(
|
||||
self,
|
||||
query_text: str,
|
||||
|
||||
@@ -5,12 +5,69 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _compute_similarities_sync(
|
||||
query_embedding: "np.ndarray",
|
||||
block_embeddings: "np.ndarray | list[np.ndarray] | list[Any]",
|
||||
block_norms: "np.ndarray | list[float] | None" = None,
|
||||
) -> "np.ndarray":
|
||||
"""
|
||||
计算 query 向量与一组向量的余弦相似度(同步/向量化实现)。
|
||||
|
||||
- 返回 float32 ndarray
|
||||
- 输出范围裁剪到 [0.0, 1.0]
|
||||
- 支持可选的 block_norms 以减少重复 norm 计算
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
if block_embeddings is None:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
|
||||
query = np.asarray(query_embedding, dtype=np.float32)
|
||||
|
||||
if isinstance(block_embeddings, (list, tuple)) and len(block_embeddings) == 0:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
|
||||
blocks = np.asarray(block_embeddings, dtype=np.float32)
|
||||
if blocks.dtype == object:
|
||||
blocks = np.stack(
|
||||
[np.asarray(vec, dtype=np.float32) for vec in block_embeddings],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
if blocks.size == 0:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
|
||||
if blocks.ndim == 1:
|
||||
blocks = blocks.reshape(1, -1)
|
||||
|
||||
query_norm = float(np.linalg.norm(query))
|
||||
if query_norm == 0.0:
|
||||
return np.zeros(blocks.shape[0], dtype=np.float32)
|
||||
|
||||
if block_norms is None:
|
||||
block_norms_array = np.linalg.norm(blocks, axis=1).astype(np.float32, copy=False)
|
||||
else:
|
||||
block_norms_array = np.asarray(block_norms, dtype=np.float32)
|
||||
if block_norms_array.shape[0] != blocks.shape[0]:
|
||||
block_norms_array = np.linalg.norm(blocks, axis=1).astype(np.float32, copy=False)
|
||||
|
||||
dot_products = blocks @ query
|
||||
denom = block_norms_array * np.float32(query_norm)
|
||||
|
||||
similarities = np.zeros(blocks.shape[0], dtype=np.float32)
|
||||
valid_mask = denom > 0
|
||||
if valid_mask.any():
|
||||
np.divide(dot_products, denom, out=similarities, where=valid_mask)
|
||||
|
||||
return np.clip(similarities, 0.0, 1.0)
|
||||
|
||||
|
||||
def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
|
||||
"""
|
||||
计算两个向量的余弦相似度
|
||||
@@ -25,23 +82,16 @@ def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
# 确保是numpy数组
|
||||
if not isinstance(vec1, np.ndarray):
|
||||
vec1 = np.array(vec1)
|
||||
if not isinstance(vec2, np.ndarray):
|
||||
vec2 = np.array(vec2)
|
||||
vec1 = np.asarray(vec1, dtype=np.float32)
|
||||
vec2 = np.asarray(vec2, dtype=np.float32)
|
||||
|
||||
# 归一化
|
||||
vec1_norm = np.linalg.norm(vec1)
|
||||
vec2_norm = np.linalg.norm(vec2)
|
||||
vec1_norm = float(np.linalg.norm(vec1))
|
||||
vec2_norm = float(np.linalg.norm(vec2))
|
||||
|
||||
if vec1_norm == 0 or vec2_norm == 0:
|
||||
if vec1_norm == 0.0 or vec2_norm == 0.0:
|
||||
return 0.0
|
||||
|
||||
# 余弦相似度
|
||||
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
|
||||
|
||||
# 确保在 [0, 1] 范围内(处理浮点误差)
|
||||
similarity = float(np.dot(vec1, vec2) / (vec1_norm * vec2_norm))
|
||||
return float(np.clip(similarity, 0.0, 1.0))
|
||||
|
||||
except Exception:
|
||||
@@ -74,43 +124,10 @@ def batch_cosine_similarity(vec1: "np.ndarray", vec_list: list["np.ndarray"]) ->
|
||||
相似度列表
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
if not vec_list:
|
||||
return []
|
||||
|
||||
# 确保是numpy数组
|
||||
if not isinstance(vec1, np.ndarray):
|
||||
vec1 = np.array(vec1)
|
||||
|
||||
# 批量转换为numpy数组
|
||||
vec_list = [np.array(vec) for vec in vec_list]
|
||||
|
||||
# 计算归一化
|
||||
vec1_norm = np.linalg.norm(vec1)
|
||||
if vec1_norm == 0:
|
||||
return [0.0] * len(vec_list)
|
||||
|
||||
# 计算所有向量的归一化
|
||||
vec_norms = np.array([np.linalg.norm(vec) for vec in vec_list])
|
||||
|
||||
# 避免除以0
|
||||
valid_mask = vec_norms != 0
|
||||
similarities = np.zeros(len(vec_list))
|
||||
|
||||
if np.any(valid_mask):
|
||||
# 批量计算点积
|
||||
valid_vecs = np.array(vec_list)[valid_mask]
|
||||
dot_products = np.dot(valid_vecs, vec1)
|
||||
|
||||
# 计算相似度
|
||||
valid_norms = vec_norms[valid_mask]
|
||||
valid_similarities = dot_products / (vec1_norm * valid_norms)
|
||||
|
||||
# 确保在 [0, 1] 范围内
|
||||
valid_similarities = np.clip(valid_similarities, 0.0, 1.0)
|
||||
|
||||
# 填充结果
|
||||
similarities[valid_mask] = valid_similarities
|
||||
|
||||
return similarities.tolist()
|
||||
return _compute_similarities_sync(vec1, vec_list).tolist()
|
||||
|
||||
except Exception:
|
||||
return [0.0] * len(vec_list)
|
||||
@@ -134,5 +151,5 @@ __all__ = [
|
||||
"batch_cosine_similarity",
|
||||
"batch_cosine_similarity_async",
|
||||
"cosine_similarity",
|
||||
"cosine_similarity_async"
|
||||
"cosine_similarity_async",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user