feat(similarity): 重构相似度计算函数,优化性能并增加文档注释

This commit is contained in:
Windpicker-owo
2025-12-13 16:59:47 +08:00
parent 66df05c37f
commit 464002a863
2 changed files with 70 additions and 90 deletions

View File

@@ -21,6 +21,7 @@ import numpy as np
from src.common.logger import get_logger
from src.memory_graph.models import MemoryBlock, PerceptualMemory
from src.memory_graph.utils.embeddings import EmbeddingGenerator
from src.memory_graph.utils.similarity import _compute_similarities_sync
logger = get_logger(__name__)
@@ -406,47 +407,9 @@ class PerceptualMemoryManager:
) -> np.ndarray:
"""在后台线程中向量化计算相似度,避免阻塞事件循环。"""
return await asyncio.to_thread(
self._compute_similarities_sync, query_embedding, block_embeddings, block_norms
_compute_similarities_sync, query_embedding, block_embeddings, block_norms
)
@staticmethod
def _compute_similarities_sync(
query_embedding: np.ndarray,
block_embeddings: list[np.ndarray],
block_norms: list[float] | None = None,
) -> np.ndarray:
import numpy as np
if not block_embeddings:
return np.zeros(0, dtype=np.float32)
query = np.asarray(query_embedding, dtype=np.float32)
blocks = np.asarray(block_embeddings, dtype=np.float32)
if blocks.ndim == 1:
blocks = blocks.reshape(1, -1)
query_norm = np.linalg.norm(query)
if query_norm == 0.0:
return np.zeros(blocks.shape[0], dtype=np.float32)
if block_norms is None:
block_norms_array = np.linalg.norm(blocks, axis=1)
else:
block_norms_array = np.asarray(block_norms, dtype=np.float32)
if block_norms_array.shape[0] != blocks.shape[0]:
block_norms_array = np.linalg.norm(blocks, axis=1)
valid_mask = block_norms_array > 0
similarities = np.zeros(blocks.shape[0], dtype=np.float32)
if valid_mask.any():
normalized_blocks = blocks[valid_mask] / block_norms_array[valid_mask][:, None]
normalized_query = query / query_norm
similarities[valid_mask] = normalized_blocks @ normalized_query
return np.clip(similarities, 0.0, 1.0)
async def recall_blocks(
self,
query_text: str,

View File

@@ -5,12 +5,69 @@
"""
import asyncio
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import numpy as np
def _compute_similarities_sync(
query_embedding: "np.ndarray",
block_embeddings: "np.ndarray | list[np.ndarray] | list[Any]",
block_norms: "np.ndarray | list[float] | None" = None,
) -> "np.ndarray":
"""
计算 query 向量与一组向量的余弦相似度(同步/向量化实现)。
- 返回 float32 ndarray
- 输出范围裁剪到 [0.0, 1.0]
- 支持可选的 block_norms 以减少重复 norm 计算
"""
import numpy as np
if block_embeddings is None:
return np.zeros(0, dtype=np.float32)
query = np.asarray(query_embedding, dtype=np.float32)
if isinstance(block_embeddings, (list, tuple)) and len(block_embeddings) == 0:
return np.zeros(0, dtype=np.float32)
blocks = np.asarray(block_embeddings, dtype=np.float32)
if blocks.dtype == object:
blocks = np.stack(
[np.asarray(vec, dtype=np.float32) for vec in block_embeddings],
axis=0,
)
if blocks.size == 0:
return np.zeros(0, dtype=np.float32)
if blocks.ndim == 1:
blocks = blocks.reshape(1, -1)
query_norm = float(np.linalg.norm(query))
if query_norm == 0.0:
return np.zeros(blocks.shape[0], dtype=np.float32)
if block_norms is None:
block_norms_array = np.linalg.norm(blocks, axis=1).astype(np.float32, copy=False)
else:
block_norms_array = np.asarray(block_norms, dtype=np.float32)
if block_norms_array.shape[0] != blocks.shape[0]:
block_norms_array = np.linalg.norm(blocks, axis=1).astype(np.float32, copy=False)
dot_products = blocks @ query
denom = block_norms_array * np.float32(query_norm)
similarities = np.zeros(blocks.shape[0], dtype=np.float32)
valid_mask = denom > 0
if valid_mask.any():
np.divide(dot_products, denom, out=similarities, where=valid_mask)
return np.clip(similarities, 0.0, 1.0)
def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
"""
计算两个向量的余弦相似度
@@ -25,23 +82,16 @@ def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
try:
import numpy as np
# 确保是numpy数组
if not isinstance(vec1, np.ndarray):
vec1 = np.array(vec1)
if not isinstance(vec2, np.ndarray):
vec2 = np.array(vec2)
vec1 = np.asarray(vec1, dtype=np.float32)
vec2 = np.asarray(vec2, dtype=np.float32)
# 归一化
vec1_norm = np.linalg.norm(vec1)
vec2_norm = np.linalg.norm(vec2)
vec1_norm = float(np.linalg.norm(vec1))
vec2_norm = float(np.linalg.norm(vec2))
if vec1_norm == 0 or vec2_norm == 0:
if vec1_norm == 0.0 or vec2_norm == 0.0:
return 0.0
# 余弦相似度
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
# 确保在 [0, 1] 范围内(处理浮点误差)
similarity = float(np.dot(vec1, vec2) / (vec1_norm * vec2_norm))
return float(np.clip(similarity, 0.0, 1.0))
except Exception:
@@ -74,43 +124,10 @@ def batch_cosine_similarity(vec1: "np.ndarray", vec_list: list["np.ndarray"]) ->
相似度列表
"""
try:
import numpy as np
if not vec_list:
return []
# 确保是numpy数组
if not isinstance(vec1, np.ndarray):
vec1 = np.array(vec1)
# 批量转换为numpy数组
vec_list = [np.array(vec) for vec in vec_list]
# 计算归一化
vec1_norm = np.linalg.norm(vec1)
if vec1_norm == 0:
return [0.0] * len(vec_list)
# 计算所有向量的归一化
vec_norms = np.array([np.linalg.norm(vec) for vec in vec_list])
# 避免除以0
valid_mask = vec_norms != 0
similarities = np.zeros(len(vec_list))
if np.any(valid_mask):
# 批量计算点积
valid_vecs = np.array(vec_list)[valid_mask]
dot_products = np.dot(valid_vecs, vec1)
# 计算相似度
valid_norms = vec_norms[valid_mask]
valid_similarities = dot_products / (vec1_norm * valid_norms)
# 确保在 [0, 1] 范围内
valid_similarities = np.clip(valid_similarities, 0.0, 1.0)
# 填充结果
similarities[valid_mask] = valid_similarities
return similarities.tolist()
return _compute_similarities_sync(vec1, vec_list).tolist()
except Exception:
return [0.0] * len(vec_list)
@@ -134,5 +151,5 @@ __all__ = [
"batch_cosine_similarity",
"batch_cosine_similarity_async",
"cosine_similarity",
"cosine_similarity_async"
"cosine_similarity_async",
]