feat(similarity): 重构相似度计算函数,优化性能并增加文档注释
This commit is contained in:
@@ -21,6 +21,7 @@ import numpy as np
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.models import MemoryBlock, PerceptualMemory
|
from src.memory_graph.models import MemoryBlock, PerceptualMemory
|
||||||
from src.memory_graph.utils.embeddings import EmbeddingGenerator
|
from src.memory_graph.utils.embeddings import EmbeddingGenerator
|
||||||
|
from src.memory_graph.utils.similarity import _compute_similarities_sync
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -406,47 +407,9 @@ class PerceptualMemoryManager:
|
|||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""在后台线程中向量化计算相似度,避免阻塞事件循环。"""
|
"""在后台线程中向量化计算相似度,避免阻塞事件循环。"""
|
||||||
return await asyncio.to_thread(
|
return await asyncio.to_thread(
|
||||||
self._compute_similarities_sync, query_embedding, block_embeddings, block_norms
|
_compute_similarities_sync, query_embedding, block_embeddings, block_norms
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_similarities_sync(
|
|
||||||
query_embedding: np.ndarray,
|
|
||||||
block_embeddings: list[np.ndarray],
|
|
||||||
block_norms: list[float] | None = None,
|
|
||||||
) -> np.ndarray:
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
if not block_embeddings:
|
|
||||||
return np.zeros(0, dtype=np.float32)
|
|
||||||
|
|
||||||
query = np.asarray(query_embedding, dtype=np.float32)
|
|
||||||
blocks = np.asarray(block_embeddings, dtype=np.float32)
|
|
||||||
|
|
||||||
if blocks.ndim == 1:
|
|
||||||
blocks = blocks.reshape(1, -1)
|
|
||||||
|
|
||||||
query_norm = np.linalg.norm(query)
|
|
||||||
if query_norm == 0.0:
|
|
||||||
return np.zeros(blocks.shape[0], dtype=np.float32)
|
|
||||||
|
|
||||||
if block_norms is None:
|
|
||||||
block_norms_array = np.linalg.norm(blocks, axis=1)
|
|
||||||
else:
|
|
||||||
block_norms_array = np.asarray(block_norms, dtype=np.float32)
|
|
||||||
if block_norms_array.shape[0] != blocks.shape[0]:
|
|
||||||
block_norms_array = np.linalg.norm(blocks, axis=1)
|
|
||||||
|
|
||||||
valid_mask = block_norms_array > 0
|
|
||||||
similarities = np.zeros(blocks.shape[0], dtype=np.float32)
|
|
||||||
|
|
||||||
if valid_mask.any():
|
|
||||||
normalized_blocks = blocks[valid_mask] / block_norms_array[valid_mask][:, None]
|
|
||||||
normalized_query = query / query_norm
|
|
||||||
similarities[valid_mask] = normalized_blocks @ normalized_query
|
|
||||||
|
|
||||||
return np.clip(similarities, 0.0, 1.0)
|
|
||||||
|
|
||||||
async def recall_blocks(
|
async def recall_blocks(
|
||||||
self,
|
self,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
|
|||||||
@@ -5,12 +5,69 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_similarities_sync(
    query_embedding: "np.ndarray",
    block_embeddings: "np.ndarray | list[np.ndarray] | list[Any]",
    block_norms: "np.ndarray | list[float] | None" = None,
) -> "np.ndarray":
    """Compute cosine similarity between a query vector and a set of vectors.

    Synchronous, vectorized implementation.

    Args:
        query_embedding: The query vector. It is flattened to 1-D, so a
            ``(1, d)`` row vector is accepted as well as ``(d,)``.
        block_embeddings: The vectors to compare against, as a 2-D array or
            a sequence of equal-length vectors. A single 1-D vector is
            treated as one block. ``None`` or an empty input yields an empty
            result.
        block_norms: Optional precomputed L2 norms, one per block, used to
            avoid recomputing them. The values are flattened to 1-D; if
            their count does not match the number of blocks they are ignored
            and the norms are recomputed.

    Returns:
        A float32 ndarray with one score per block, clipped to [0.0, 1.0].
        Blocks whose norm is not positive score 0.0, and a zero-norm query
        yields all zeros.

    Raises:
        ValueError: If the vectors cannot be converted to a rectangular
            float32 array, or the query length does not match the block
            dimension.
    """
    import numpy as np

    if block_embeddings is None:
        return np.zeros(0, dtype=np.float32)

    query = np.asarray(query_embedding, dtype=np.float32).reshape(-1)

    if isinstance(block_embeddings, (list, tuple)) and len(block_embeddings) == 0:
        return np.zeros(0, dtype=np.float32)

    # With an explicit float32 dtype this either yields a float32 array or
    # raises, so no object-dtype fallback is needed afterwards.
    blocks = np.asarray(block_embeddings, dtype=np.float32)
    if blocks.size == 0:
        return np.zeros(0, dtype=np.float32)
    if blocks.ndim == 1:
        blocks = blocks.reshape(1, -1)
    num_blocks = blocks.shape[0]

    similarities = np.zeros(num_blocks, dtype=np.float32)

    query_norm = float(np.linalg.norm(query))
    if query_norm == 0.0:
        return similarities

    # Accept caller-supplied norms only when there is exactly one per block.
    # Flattening first lets a bare scalar or an (n, 1) column be used instead
    # of failing on `.shape[0]` or on broadcasting in the division below.
    block_norms_array = None
    if block_norms is not None:
        supplied = np.asarray(block_norms, dtype=np.float32).reshape(-1)
        if supplied.shape[0] == num_blocks:
            block_norms_array = supplied
    if block_norms_array is None:
        block_norms_array = np.linalg.norm(blocks, axis=1).astype(np.float32, copy=False)

    dot_products = blocks @ query
    denom = block_norms_array * np.float32(query_norm)

    # Entries with a non-positive denominator keep their pre-filled 0.0.
    valid_mask = denom > 0
    np.divide(dot_products, denom, out=similarities, where=valid_mask)

    return np.clip(similarities, 0.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
|
def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
|
||||||
"""
|
"""
|
||||||
计算两个向量的余弦相似度
|
计算两个向量的余弦相似度
|
||||||
@@ -25,23 +82,16 @@ def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
|
|||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# 确保是numpy数组
|
vec1 = np.asarray(vec1, dtype=np.float32)
|
||||||
if not isinstance(vec1, np.ndarray):
|
vec2 = np.asarray(vec2, dtype=np.float32)
|
||||||
vec1 = np.array(vec1)
|
|
||||||
if not isinstance(vec2, np.ndarray):
|
|
||||||
vec2 = np.array(vec2)
|
|
||||||
|
|
||||||
# 归一化
|
vec1_norm = float(np.linalg.norm(vec1))
|
||||||
vec1_norm = np.linalg.norm(vec1)
|
vec2_norm = float(np.linalg.norm(vec2))
|
||||||
vec2_norm = np.linalg.norm(vec2)
|
|
||||||
|
|
||||||
if vec1_norm == 0 or vec2_norm == 0:
|
if vec1_norm == 0.0 or vec2_norm == 0.0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# 余弦相似度
|
similarity = float(np.dot(vec1, vec2) / (vec1_norm * vec2_norm))
|
||||||
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
|
|
||||||
|
|
||||||
# 确保在 [0, 1] 范围内(处理浮点误差)
|
|
||||||
return float(np.clip(similarity, 0.0, 1.0))
|
return float(np.clip(similarity, 0.0, 1.0))
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -74,43 +124,10 @@ def batch_cosine_similarity(vec1: "np.ndarray", vec_list: list["np.ndarray"]) ->
|
|||||||
相似度列表
|
相似度列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
if not vec_list:
|
||||||
|
return []
|
||||||
|
|
||||||
# 确保是numpy数组
|
return _compute_similarities_sync(vec1, vec_list).tolist()
|
||||||
if not isinstance(vec1, np.ndarray):
|
|
||||||
vec1 = np.array(vec1)
|
|
||||||
|
|
||||||
# 批量转换为numpy数组
|
|
||||||
vec_list = [np.array(vec) for vec in vec_list]
|
|
||||||
|
|
||||||
# 计算归一化
|
|
||||||
vec1_norm = np.linalg.norm(vec1)
|
|
||||||
if vec1_norm == 0:
|
|
||||||
return [0.0] * len(vec_list)
|
|
||||||
|
|
||||||
# 计算所有向量的归一化
|
|
||||||
vec_norms = np.array([np.linalg.norm(vec) for vec in vec_list])
|
|
||||||
|
|
||||||
# 避免除以0
|
|
||||||
valid_mask = vec_norms != 0
|
|
||||||
similarities = np.zeros(len(vec_list))
|
|
||||||
|
|
||||||
if np.any(valid_mask):
|
|
||||||
# 批量计算点积
|
|
||||||
valid_vecs = np.array(vec_list)[valid_mask]
|
|
||||||
dot_products = np.dot(valid_vecs, vec1)
|
|
||||||
|
|
||||||
# 计算相似度
|
|
||||||
valid_norms = vec_norms[valid_mask]
|
|
||||||
valid_similarities = dot_products / (vec1_norm * valid_norms)
|
|
||||||
|
|
||||||
# 确保在 [0, 1] 范围内
|
|
||||||
valid_similarities = np.clip(valid_similarities, 0.0, 1.0)
|
|
||||||
|
|
||||||
# 填充结果
|
|
||||||
similarities[valid_mask] = valid_similarities
|
|
||||||
|
|
||||||
return similarities.tolist()
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return [0.0] * len(vec_list)
|
return [0.0] * len(vec_list)
|
||||||
@@ -134,5 +151,5 @@ __all__ = [
|
|||||||
"batch_cosine_similarity",
|
"batch_cosine_similarity",
|
||||||
"batch_cosine_similarity_async",
|
"batch_cosine_similarity_async",
|
||||||
"cosine_similarity",
|
"cosine_similarity",
|
||||||
"cosine_similarity_async"
|
"cosine_similarity_async",
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user