feat(similarity): add async and batch similarity computation to improve performance

feat(graph_store): enhance graph store management with edge registration and unregistration
feat(memory_tools): support batch embedding generation
feat(unified_manager): optimize perceptual and short-term memory retrieval logic
Windpicker-owo
2025-11-20 22:40:53 +08:00
parent 8dc754e562
commit ddc68b9257
9 changed files with 349 additions and 97 deletions
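For orientation, here is a minimal usage sketch of the new helpers this commit adds to src/memory_graph/utils/similarity.py. The call signatures come from the diff below; the vectors are made up, and the snippet assumes the repository is on the import path.

import asyncio
import numpy as np

from src.memory_graph.utils.similarity import (
    cosine_similarity_async,
    batch_cosine_similarity_async,
)

async def main() -> None:
    query = np.array([0.1, 0.9, 0.0], dtype=np.float32)
    candidates = [
        np.array([0.1, 0.8, 0.1], dtype=np.float32),
        np.array([0.9, 0.0, 0.1], dtype=np.float32),
    ]

    # Single pair: runs the sync cosine_similarity in a worker thread.
    single = await cosine_similarity_async(query, candidates[0])

    # Whole list in one call: returns a list[float] aligned with candidates.
    scores = await batch_cosine_similarity_async(query, candidates)
    print(single, scores)

asyncio.run(main())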

View File

@@ -21,7 +21,7 @@ import numpy as np
from src.common.logger import get_logger
from src.memory_graph.models import MemoryBlock, PerceptualMemory
from src.memory_graph.utils.embeddings import EmbeddingGenerator
-from src.memory_graph.utils.similarity import cosine_similarity
+from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async

logger = get_logger(__name__)
@@ -430,14 +430,22 @@ class PerceptualMemoryManager:
logger.warning("查询向量生成失败,返回空列表") logger.warning("查询向量生成失败,返回空列表")
return [] return []
# 计算所有块的相似度 # 批量计算所有块的相似度(使用异步版本)
blocks_with_embeddings = [
block for block in self.perceptual_memory.blocks
if block.embedding is not None
]
if not blocks_with_embeddings:
return []
# 批量计算相似度
block_embeddings = [block.embedding for block in blocks_with_embeddings]
similarities = await batch_cosine_similarity_async(query_embedding, block_embeddings)
# 过滤和排序
scored_blocks = [] scored_blocks = []
for block in self.perceptual_memory.blocks: for block, similarity in zip(blocks_with_embeddings, similarities):
if block.embedding is None:
continue
similarity = cosine_similarity(query_embedding, block.embedding)
# 过滤低于阈值的块 # 过滤低于阈值的块
if similarity >= similarity_threshold: if similarity >= similarity_threshold:
scored_blocks.append((block, similarity)) scored_blocks.append((block, similarity))
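The recall path above reduces to a simple pattern: keep only blocks that have embeddings, score them all with one batch call, then threshold and sort. A standalone sketch of that pattern follows; it uses plain (id, embedding) tuples instead of MemoryBlock objects, the sync batch_cosine_similarity for brevity where the manager awaits the async wrapper, and made-up threshold/top_k values.

import numpy as np
from src.memory_graph.utils.similarity import batch_cosine_similarity

def recall(query_emb, blocks, threshold=0.6, top_k=3):
    # blocks: list of (block_id, embedding-or-None); keep only embedded ones.
    embedded = [(bid, emb) for bid, emb in blocks if emb is not None]
    if not embedded:
        return []
    sims = batch_cosine_similarity(query_emb, [emb for _, emb in embedded])
    scored = [(bid, s) for (bid, _), s in zip(embedded, sims) if s >= threshold]
    return sorted(scored, key=lambda x: x[1], reverse=True)[:top_k]

print(recall(np.array([1.0, 0.0]),
             [("a", np.array([0.9, 0.1])), ("b", None), ("c", np.array([0.0, 1.0]))]))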

View File

@@ -25,7 +25,7 @@ from src.memory_graph.models import (
    ShortTermOperation,
)
from src.memory_graph.utils.embeddings import EmbeddingGenerator
-from src.memory_graph.utils.similarity import cosine_similarity
+from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async

logger = get_logger(__name__)
@@ -457,7 +457,7 @@ class ShortTermMemoryManager:
            if existing_mem.embedding is None:
                continue
-            similarity = cosine_similarity(memory.embedding, existing_mem.embedding)
+            similarity = await cosine_similarity_async(memory.embedding, existing_mem.embedding)
            scored.append((existing_mem, similarity))

        # 按相似度降序排序
@@ -567,7 +567,7 @@ class ShortTermMemoryManager:
            if memory.embedding is None:
                continue
-            similarity = cosine_similarity(query_embedding, memory.embedding)
+            similarity = await cosine_similarity_async(query_embedding, memory.embedding)
            if similarity >= similarity_threshold:
                scored.append((memory, similarity))

View File

@@ -4,6 +4,8 @@
from __future__ import annotations

+from collections.abc import Iterable
+
import networkx as nx

from src.common.logger import get_logger
@@ -34,8 +36,50 @@ class GraphStore:
        # 索引节点ID -> 所属记忆ID集合
        self.node_to_memories: dict[str, set[str]] = {}
+        # 节点 -> {memory_id: [MemoryEdge]},用于快速获取邻接边
+        self.node_edge_index: dict[str, dict[str, list[MemoryEdge]]] = {}

        logger.info("初始化图存储")

+    def _register_memory_edges(self, memory: Memory) -> None:
+        """在记忆中的边加入邻接索引"""
+        for edge in memory.edges:
+            self._register_edge_reference(memory.id, edge)
+
+    def _register_edge_reference(self, memory_id: str, edge: MemoryEdge) -> None:
+        """在节点邻接索引中登记一条边"""
+        for node_id in (edge.source_id, edge.target_id):
+            node_edges = self.node_edge_index.setdefault(node_id, {})
+            edge_list = node_edges.setdefault(memory_id, [])
+            if not any(existing.id == edge.id for existing in edge_list):
+                edge_list.append(edge)
+
+    def _unregister_memory_edges(self, memory: Memory) -> None:
+        """从节点邻接索引中移除记忆相关的边"""
+        for edge in memory.edges:
+            self._unregister_edge_reference(memory.id, edge)
+
+    def _unregister_edge_reference(self, memory_id: str, edge: MemoryEdge) -> None:
+        """在节点邻接索引中删除一条边"""
+        for node_id in (edge.source_id, edge.target_id):
+            node_edges = self.node_edge_index.get(node_id)
+            if not node_edges:
+                continue
+            if memory_id not in node_edges:
+                continue
+            node_edges[memory_id] = [e for e in node_edges[memory_id] if e.id != edge.id]
+            if not node_edges[memory_id]:
+                del node_edges[memory_id]
+            if not node_edges:
+                del self.node_edge_index[node_id]
+
+    def _rebuild_node_edge_index(self) -> None:
+        """重建节点邻接索引"""
+        self.node_edge_index.clear()
+        for memory in self.memory_index.values():
+            self._register_memory_edges(memory)
+
    def add_memory(self, memory: Memory) -> None:
        """
        添加记忆到图
@@ -77,6 +121,9 @@ class GraphStore:
            # 3. 保存记忆对象
            self.memory_index[memory.id] = memory

+            # 4. 注册记忆中的边到邻接索引
+            self._register_memory_edges(memory)
+
            logger.debug(f"添加记忆到图: {memory}")

        except Exception as e:
@@ -112,6 +159,12 @@ class GraphStore:
        memory = self.memory_index[memory_id]

+        # 1.5. 注销记忆中的边的邻接索引记录
+        self._unregister_memory_edges(memory)
+
+        # 1.5. 注销记忆中的边的邻接索引记录
+        self._unregister_memory_edges(memory)
+
        # 2. 添加节点到图
        if not self.graph.has_node(node_id):
            from datetime import datetime
@@ -282,6 +335,7 @@ class GraphStore:
            memory = self.memory_index.get(mem_id)
            if memory:
                memory.edges.append(new_edge)
+                self._register_edge_reference(mem_id, new_edge)

        logger.debug(f"添加边成功: {source_id} -> {target_id} ({relation})")
        return edge_id
@@ -393,6 +447,10 @@ class GraphStore:
        for mem_id in related_memory_ids:
            memory = self.memory_index.get(mem_id)
            if memory:
+                removed_edges = [e for e in memory.edges if e.id == edge_id]
+                if removed_edges:
+                    for edge_obj in removed_edges:
+                        self._unregister_edge_reference(mem_id, edge_obj)
                memory.edges = [e for e in memory.edges if e.id != edge_id]

        return True
@@ -440,8 +498,11 @@ class GraphStore:
        # 2. 转移边
        for edge in source_memory.edges:
            # 添加到目标记忆(如果不存在)
-            if not any(e.id == edge.id for e in target_memory.edges):
+            already_exists = any(e.id == edge.id for e in target_memory.edges)
+            self._unregister_edge_reference(source_id, edge)
+            if not already_exists:
                target_memory.edges.append(edge)
+                self._register_edge_reference(target_memory_id, edge)

        # 3. 删除源记忆(不清理孤立节点,因为节点已转移)
        del self.memory_index[source_id]
@@ -465,6 +526,32 @@ class GraphStore:
""" """
return self.memory_index.get(memory_id) return self.memory_index.get(memory_id)
def get_memories_by_ids(self, memory_ids: Iterable[str]) -> dict[str, Memory]:
"""
根据一组ID批量获取记忆
Args:
memory_ids: 记忆ID集合
Returns:
{memory_id: Memory} 映射
"""
result: dict[str, Memory] = {}
missing_ids: list[str] = []
# dict.fromkeys 可以保持参数序列的原始顺序同时帮忙去重
for mem_id in dict.fromkeys(memory_ids):
memory = self.memory_index.get(mem_id)
if memory is not None:
result[mem_id] = memory
else:
missing_ids.append(mem_id)
if missing_ids:
logger.debug(f"批量获取记忆: 未找到 {missing_ids}")
return result
def get_all_memories(self) -> list[Memory]: def get_all_memories(self) -> list[Memory]:
""" """
获取所有记忆 获取所有记忆
@@ -490,6 +577,32 @@ class GraphStore:
        memory_ids = self.node_to_memories[node_id]
        return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]

+    def get_edges_for_node(self, node_id: str) -> list[MemoryEdge]:
+        """
+        获取节点相关的全部边(包含入边和出边)
+
+        Args:
+            node_id: 节点ID
+
+        Returns:
+            MemoryEdge 列表
+        """
+        node_edges = self.node_edge_index.get(node_id)
+        if not node_edges:
+            return []
+
+        unique_edges: dict[str | tuple[str, str, str, str], MemoryEdge] = {}
+        for edges in node_edges.values():
+            for edge in edges:
+                key: str | tuple[str, str, str, str]
+                if edge.id:
+                    key = edge.id
+                else:
+                    key = (edge.source_id, edge.target_id, edge.relation, edge.edge_type.value)
+                unique_edges[key] = edge
+
+        return list(unique_edges.values())
+
    def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
        """
        获取从指定节点出发的所有边
@@ -762,6 +875,8 @@ class GraphStore:
        except Exception:
            logger.exception("同步图边到记忆.edges 失败")

+        store._rebuild_node_edge_index()
+
        logger.info(f"从字典加载图: {store.get_statistics()}")
        return store
@@ -829,6 +944,7 @@ class GraphStore:
                    existing_edges.setdefault(mid, set()).add(mem_edge.id)

        logger.info("已将图中的边同步到 Memory.edges(保证 graph 与 memory 对象一致)")
+        self._rebuild_node_edge_index()

    def remove_memory(self, memory_id: str, cleanup_orphans: bool = True) -> bool:
        """
@@ -877,4 +993,5 @@ class GraphStore:
        self.graph.clear()
        self.memory_index.clear()
        self.node_to_memories.clear()
+        self.node_edge_index.clear()

        logger.warning("图存储已清空")

View File

@@ -1171,8 +1171,10 @@ class MemoryTools:
        query_embeddings = []
        query_weights = []

-        for sub_query, weight in multi_queries:
-            embedding = await self.builder.embedding_generator.generate(sub_query)
+        batch_texts = [sub_query for sub_query, _ in multi_queries]
+        batch_embeddings = await self.builder.embedding_generator.generate_batch(batch_texts)
+        for (sub_query, weight), embedding in zip(multi_queries, batch_embeddings):
            if embedding is not None:
                query_embeddings.append(embedding)
                query_weights.append(weight)

View File

@@ -228,8 +228,13 @@ class UnifiedMemoryManager:
        }

        # 步骤1: 检索感知记忆和短期记忆
-        perceptual_blocks = await self.perceptual_manager.recall_blocks(query_text)
-        short_term_memories = await self.short_term_manager.search_memories(query_text)
+        perceptual_blocks_task = asyncio.create_task(self.perceptual_manager.recall_blocks(query_text))
+        short_term_memories_task = asyncio.create_task(self.short_term_manager.search_memories(query_text))
+        perceptual_blocks, short_term_memories = await asyncio.gather(
+            perceptual_blocks_task,
+            short_term_memories_task,
+        )

        # 步骤1.5: 检查需要转移的感知块,推迟到后台处理
        blocks_to_transfer = [
View File

@@ -4,7 +4,12 @@
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
from src.memory_graph.utils.path_expansion import Path, PathExpansionConfig, PathScoreExpansion
-from src.memory_graph.utils.similarity import cosine_similarity
+from src.memory_graph.utils.similarity import (
+    cosine_similarity,
+    cosine_similarity_async,
+    batch_cosine_similarity,
+    batch_cosine_similarity_async
+)
from src.memory_graph.utils.time_parser import TimeParser

__all__ = [
@@ -14,5 +19,8 @@ __all__ = [
"PathScoreExpansion", "PathScoreExpansion",
"TimeParser", "TimeParser",
"cosine_similarity", "cosine_similarity",
"cosine_similarity_async",
"batch_cosine_similarity",
"batch_cosine_similarity_async",
"get_embedding_generator", "get_embedding_generator",
] ]

View File

@@ -137,56 +137,69 @@ class EmbeddingGenerator:
raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API") raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API")
async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]: async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]:
""" """保留输入顺序的批量嵌入生成"""
批量生成嵌入向量
Args:
texts: 文本列表
Returns:
嵌入向量列表,失败的项目为 None
"""
if not texts: if not texts:
return [] return []
try: try:
# 过滤空文本 results: list[np.ndarray | None] = [None] * len(texts)
valid_texts = [t for t in texts if t and t.strip()] valid_entries = [
if not valid_texts: (idx, text) for idx, text in enumerate(texts) if text and text.strip()
logger.debug("所有文本为空,返回 None 列表") ]
return [None for _ in texts] if not valid_entries:
logger.debug('批量文本为空,返回空列表')
# 使用 API 批量生成(如果可用)
if self.use_api:
results = await self._generate_batch_with_api(valid_texts)
if results:
return results return results
# 回退到逐个生成 batch_texts = [text for _, text in valid_entries]
results = [] batch_embeddings: list[np.ndarray | None] | None = None
for text in valid_texts:
embedding = await self.generate(text) if self.use_api:
results.append(embedding) batch_embeddings = await self._generate_batch_with_api(batch_texts)
if not batch_embeddings:
batch_embeddings = []
for _, text in valid_entries:
batch_embeddings.append(await self.generate(text))
for (idx, _), embedding in zip(valid_entries, batch_embeddings):
results[idx] = embedding
success_count = sum(1 for r in results if r is not None) success_count = sum(1 for r in results if r is not None)
logger.debug(f"批量生成嵌入: {success_count}/{len(texts)} 个成功") logger.debug(f"批量生成嵌入: {success_count}/{len(texts)}")
return results return results
except Exception as e: except Exception as e:
logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True) logger.error(f"批量生成嵌入失败: {e}", exc_info=True)
return [None for _ in texts] return [None for _ in texts]
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None: async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None:
"""使用 API 批量生成""" """使用嵌入 API 在单次请求中生成向量"""
if not texts:
return []
try: try:
# 对于大多数 API批量调用就是多次单独调用 if not self._api_available:
# 这里保持简单,逐个调用 await self._initialize_api()
results = []
for text in texts: if not self._api_available or not self._llm_request:
embedding = await self._generate_with_api(text) return None
results.append(embedding) # 失败的项目为 None不中断整个批量处理
embeddings, model_name = await self._llm_request.get_embedding(texts)
if not embeddings:
return None
results: list[np.ndarray | None] = []
for emb in embeddings:
if emb:
results.append(np.array(emb, dtype=np.float32))
else:
results.append(None)
logger.debug(f"API 批量生成 {len(texts)} 个嵌入向量,使用模型: {model_name}")
return results return results
except Exception as e: except Exception as e:
logger.debug(f"API 批量生成失败: {e}") logger.debug(f"API 批量生成失败: {e}")
return None return None
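The rewritten generate_batch keeps results aligned with the input order by remembering the original index of every non-empty text. A minimal standalone sketch of that index-mapping idea; the embed step here is a made-up stand-in, not the project's API call:

def embed_batch(texts: list[str]) -> list[list[float] | None]:
    # One result slot per input text, preserving input order.
    results: list[list[float] | None] = [None] * len(texts)

    # Track (original_index, text) for non-empty inputs only.
    valid_entries = [(i, t) for i, t in enumerate(texts) if t and t.strip()]
    if not valid_entries:
        return results

    # Stand-in for a single batched embedding request.
    embeddings = [[float(len(t))] for _, t in valid_entries]

    # Write each embedding back into its original slot.
    for (i, _), emb in zip(valid_entries, embeddings):
        results[i] = emb
    return results

print(embed_batch(["hello", "", "world"]))  # [[5.0], None, [5.0]]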

View File

@@ -15,13 +15,14 @@
""" """
import asyncio import asyncio
import heapq
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.memory_graph.utils.similarity import cosine_similarity from src.memory_graph.utils.similarity import cosine_similarity_async
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy as np import numpy as np
@@ -273,9 +274,8 @@ class PathScoreExpansion:
f"⚠️ 路径数量超限 ({len(next_paths)} > {self.config.max_active_paths})" f"⚠️ 路径数量超限 ({len(next_paths)} > {self.config.max_active_paths})"
f"保留 top {self.config.top_paths_retain}" f"保留 top {self.config.top_paths_retain}"
) )
next_paths = sorted(next_paths, key=lambda p: p.score, reverse=True)[ retain = min(self.config.top_paths_retain, len(next_paths))
: self.config.top_paths_retain next_paths = heapq.nlargest(retain, next_paths, key=lambda p: p.score)
]
# 🚀 早停检测:如果路径增长很少,提前终止 # 🚀 早停检测:如果路径增长很少,提前终止
prev_path_count = len(active_paths) prev_path_count = len(active_paths)
@@ -398,22 +398,14 @@ class PathScoreExpansion:
        if node_id in self._neighbor_cache:
            return self._neighbor_cache[node_id]

-        edges = []
-        # 从图存储中获取与该节点相关的所有边
-        # 需要遍历所有记忆找到包含该节点的边
-        for memory_id in self.graph_store.node_to_memories.get(node_id, []):
-            memory = self.graph_store.get_memory_by_id(memory_id)
-            if memory:
-                for edge in memory.edges:
-                    if edge.source_id == node_id or edge.target_id == node_id:
-                        edges.append(edge)
-
-        # 去重(同一条边可能出现多次)
-        unique_edges = list({(e.source_id, e.target_id, e.edge_type): e for e in edges}.values())
+        edges = self.graph_store.get_edges_for_node(node_id)
+        if not edges:
+            self._neighbor_cache[node_id] = []
+            return []

        # 按边权重排序
-        unique_edges.sort(key=lambda e: self._get_edge_weight(e), reverse=True)
+        unique_edges = sorted(edges, key=lambda e: self._get_edge_weight(e), reverse=True)

        # 🚀 存入缓存
        self._neighbor_cache[node_id] = unique_edges
@@ -461,7 +453,7 @@ class PathScoreExpansion:
            base_score = 0.3  # 无向量的节点给低分
        else:
            node_embedding = node_data["embedding"]
-            similarity = cosine_similarity(query_embedding, node_embedding)
+            similarity = await cosine_similarity_async(query_embedding, node_embedding)
            base_score = max(0.0, min(1.0, similarity))  # 限制在[0, 1]

        # 🆕 偏好类型加成
@@ -522,14 +514,8 @@ class PathScoreExpansion:
                node_metadata_map[nid] = node_data.get("metadata", {})

        if valid_embeddings:
-            # 批量计算相似度(使用矩阵运算)
-            embeddings_matrix = np.array(valid_embeddings)
-            query_norm = np.linalg.norm(query_embedding)
-            embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1)
-
-            # 向量化计算余弦相似度
-            similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8)
-            similarities = np.clip(similarities, 0.0, 1.0)
+            # 批量计算相似度(使用矩阵运算)- 移至to_thread执行
+            similarities = await asyncio.to_thread(self._batch_compute_similarities, valid_embeddings, query_embedding)

            # 应用偏好类型加成
            for nid, sim in zip(valid_node_ids, similarities):
@@ -706,11 +692,7 @@ class PathScoreExpansion:
        # 🚀 批量获取记忆对象如果graph_store支持批量获取
        # 注意这里假设逐个获取如果有批量API可以进一步优化
-        memory_cache: dict[str, Any] = {}
-        for mem_id in all_memory_ids:
-            memory = self.graph_store.get_memory_by_id(mem_id)
-            if memory:
-                memory_cache[mem_id] = memory
+        memory_cache: dict[str, Any] = self.graph_store.get_memories_by_ids(all_memory_ids)

        # 构建映射关系
        for path in paths:
@@ -749,30 +731,31 @@ class PathScoreExpansion:
        node_type_cache: dict[str, str | None] = {}
        if self.prefer_node_types:
-            # 收集所有需要查询的节点ID
-            all_node_ids = set()
+            # 收集所有需要查询的节点ID,并记录记忆中的类型提示
+            all_node_ids: set[str] = set()
+            node_type_hints: dict[str, str | None] = {}
            for memory, _ in memory_paths.values():
                memory_nodes = getattr(memory, "nodes", [])
                for node in memory_nodes:
                    node_id = node.id if hasattr(node, "id") else str(node)
                    all_node_ids.add(node_id)
+                    if node_id not in node_type_hints:
+                        node_obj_type = getattr(node, "node_type", None)
+                        if node_obj_type is not None:
+                            node_type_hints[node_id] = getattr(node_obj_type, "value", str(node_obj_type))

-            # 批量获取节点数据
            if all_node_ids:
-                logger.debug(f"🔍 批量预加载 {len(all_node_ids)} 个节点的类型信息")
-                node_data_list = await asyncio.gather(
-                    *[self.vector_store.get_node_by_id(nid) for nid in all_node_ids],
-                    return_exceptions=True
-                )
-
-                # 构建类型缓存
-                for nid, node_data in zip(all_node_ids, node_data_list):
-                    if isinstance(node_data, Exception) or not node_data or not isinstance(node_data, dict):
-                        node_type_cache[nid] = None
-                    else:
-                        metadata = node_data.get("metadata", {})
-                        node_type_cache[nid] = metadata.get("node_type")
+                logger.info(f"🧠 预处理 {len(all_node_ids)} 个节点的类型信息")
+                for nid in all_node_ids:
+                    node_attrs = self.graph_store.graph.nodes.get(nid, {}) if hasattr(self.graph_store, "graph") else {}
+                    metadata = node_attrs.get("metadata", {}) if isinstance(node_attrs, dict) else {}
+                    node_type = metadata.get("node_type") or node_attrs.get("node_type")
+                    if not node_type:
+                        # 回退到记忆中的节点定义
+                        node_type = node_type_hints.get(nid)
+                    node_type_cache[nid] = node_type

        # 遍历所有记忆进行评分
        for mem_id, (memory, paths) in memory_paths.items():
            # 1. 聚合路径分数
@@ -868,5 +851,33 @@ class PathScoreExpansion:
        return recency_score

+    def _batch_compute_similarities(
+        self,
+        valid_embeddings: list["np.ndarray"],
+        query_embedding: "np.ndarray"
+    ) -> "np.ndarray":
+        """
+        批量计算向量相似度CPU密集型操作移至to_thread中执行
+
+        Args:
+            valid_embeddings: 有效的嵌入向量列表
+            query_embedding: 查询向量
+
+        Returns:
+            相似度数组
+        """
+        import numpy as np
+
+        # 批量计算相似度(使用矩阵运算)
+        embeddings_matrix = np.array(valid_embeddings)
+        query_norm = np.linalg.norm(query_embedding)
+        embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1)
+
+        # 向量化计算余弦相似度
+        similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8)
+        similarities = np.clip(similarities, 0.0, 1.0)
+
+        return similarities
+
__all__ = ["Path", "PathExpansionConfig", "PathScoreExpansion"]

View File

@@ -4,6 +4,7 @@
提供统一的向量相似度计算函数
"""

+import asyncio
from typing import TYPE_CHECKING

if TYPE_CHECKING:
@@ -47,4 +48,91 @@ def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
        return 0.0


-__all__ = ["cosine_similarity"]
+async def cosine_similarity_async(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
+    """
+    异步计算两个向量的余弦相似度使用to_thread避免阻塞
+
+    Args:
+        vec1: 第一个向量
+        vec2: 第二个向量
+
+    Returns:
+        余弦相似度 (0.0-1.0)
+    """
+    return await asyncio.to_thread(cosine_similarity, vec1, vec2)
+
+
+def batch_cosine_similarity(vec1: "np.ndarray", vec_list: list["np.ndarray"]) -> list[float]:
+    """
+    批量计算向量相似度
+
+    Args:
+        vec1: 基础向量
+        vec_list: 待比较的向量列表
+
+    Returns:
+        相似度列表
+    """
+    try:
+        import numpy as np
+
+        # 确保是numpy数组
+        if not isinstance(vec1, np.ndarray):
+            vec1 = np.array(vec1)
+
+        # 批量转换为numpy数组
+        vec_list = [np.array(vec) for vec in vec_list]
+
+        # 计算归一化
+        vec1_norm = np.linalg.norm(vec1)
+        if vec1_norm == 0:
+            return [0.0] * len(vec_list)
+
+        # 计算所有向量的归一化
+        vec_norms = np.array([np.linalg.norm(vec) for vec in vec_list])
+
+        # 避免除以0
+        valid_mask = vec_norms != 0
+        similarities = np.zeros(len(vec_list))
+
+        if np.any(valid_mask):
+            # 批量计算点积
+            valid_vecs = np.array(vec_list)[valid_mask]
+            dot_products = np.dot(valid_vecs, vec1)
+
+            # 计算相似度
+            valid_norms = vec_norms[valid_mask]
+            valid_similarities = dot_products / (vec1_norm * valid_norms)
+
+            # 确保在 [0, 1] 范围内
+            valid_similarities = np.clip(valid_similarities, 0.0, 1.0)
+
+            # 填充结果
+            similarities[valid_mask] = valid_similarities
+
+        return similarities.tolist()
+
+    except Exception:
+        return [0.0] * len(vec_list)
+
+
+async def batch_cosine_similarity_async(vec1: "np.ndarray", vec_list: list["np.ndarray"]) -> list[float]:
+    """
+    异步批量计算向量相似度使用to_thread避免阻塞
+
+    Args:
+        vec1: 基础向量
+        vec_list: 待比较的向量列表
+
+    Returns:
+        相似度列表
+    """
+    return await asyncio.to_thread(batch_cosine_similarity, vec1, vec_list)
+
+
+__all__ = [
+    "cosine_similarity",
+    "cosine_similarity_async",
+    "batch_cosine_similarity",
+    "batch_cosine_similarity_async"
+]