diff --git a/src/memory_graph/perceptual_manager.py b/src/memory_graph/perceptual_manager.py index 6a41a09d2..b4861b3d8 100644 --- a/src/memory_graph/perceptual_manager.py +++ b/src/memory_graph/perceptual_manager.py @@ -21,7 +21,7 @@ import numpy as np from src.common.logger import get_logger from src.memory_graph.models import MemoryBlock, PerceptualMemory from src.memory_graph.utils.embeddings import EmbeddingGenerator -from src.memory_graph.utils.similarity import cosine_similarity +from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async logger = get_logger(__name__) @@ -430,14 +430,22 @@ class PerceptualMemoryManager: logger.warning("查询向量生成失败,返回空列表") return [] - # 计算所有块的相似度 + # 批量计算所有块的相似度(使用异步版本) + blocks_with_embeddings = [ + block for block in self.perceptual_memory.blocks + if block.embedding is not None + ] + + if not blocks_with_embeddings: + return [] + + # 批量计算相似度 + block_embeddings = [block.embedding for block in blocks_with_embeddings] + similarities = await batch_cosine_similarity_async(query_embedding, block_embeddings) + + # 过滤和排序 scored_blocks = [] - for block in self.perceptual_memory.blocks: - if block.embedding is None: - continue - - similarity = cosine_similarity(query_embedding, block.embedding) - + for block, similarity in zip(blocks_with_embeddings, similarities): # 过滤低于阈值的块 if similarity >= similarity_threshold: scored_blocks.append((block, similarity)) diff --git a/src/memory_graph/short_term_manager.py b/src/memory_graph/short_term_manager.py index dbe666e27..979529c4a 100644 --- a/src/memory_graph/short_term_manager.py +++ b/src/memory_graph/short_term_manager.py @@ -25,7 +25,7 @@ from src.memory_graph.models import ( ShortTermOperation, ) from src.memory_graph.utils.embeddings import EmbeddingGenerator -from src.memory_graph.utils.similarity import cosine_similarity +from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async logger = get_logger(__name__) @@ -457,7 +457,7 @@ class ShortTermMemoryManager: if existing_mem.embedding is None: continue - similarity = cosine_similarity(memory.embedding, existing_mem.embedding) + similarity = await cosine_similarity_async(memory.embedding, existing_mem.embedding) scored.append((existing_mem, similarity)) # 按相似度降序排序 @@ -567,7 +567,7 @@ class ShortTermMemoryManager: if memory.embedding is None: continue - similarity = cosine_similarity(query_embedding, memory.embedding) + similarity = await cosine_similarity_async(query_embedding, memory.embedding) if similarity >= similarity_threshold: scored.append((memory, similarity)) diff --git a/src/memory_graph/storage/graph_store.py b/src/memory_graph/storage/graph_store.py index 516714f49..2314ac529 100644 --- a/src/memory_graph/storage/graph_store.py +++ b/src/memory_graph/storage/graph_store.py @@ -4,6 +4,8 @@ from __future__ import annotations +from collections.abc import Iterable + import networkx as nx from src.common.logger import get_logger @@ -33,9 +35,51 @@ class GraphStore: # 索引:节点ID -> 所属记忆ID集合 self.node_to_memories: dict[str, set[str]] = {} + + # 节点 -> {memory_id: [MemoryEdge]},用于快速获取邻接边 + self.node_edge_index: dict[str, dict[str, list[MemoryEdge]]] = {} logger.info("初始化图存储") + + def _register_memory_edges(self, memory: Memory) -> None: + """在记忆中的边加入邻接索引""" + for edge in memory.edges: + self._register_edge_reference(memory.id, edge) + + def _register_edge_reference(self, memory_id: str, edge: MemoryEdge) -> None: + """在节点邻接索引中登记一条边""" + for node_id in (edge.source_id, 
edge.target_id):
+            node_edges = self.node_edge_index.setdefault(node_id, {})
+            edge_list = node_edges.setdefault(memory_id, [])
+            if not any(existing.id == edge.id for existing in edge_list):
+                edge_list.append(edge)
+
+    def _unregister_memory_edges(self, memory: Memory) -> None:
+        """从节点邻接索引中移除记忆相关的边"""
+        for edge in memory.edges:
+            self._unregister_edge_reference(memory.id, edge)
+
+    def _unregister_edge_reference(self, memory_id: str, edge: MemoryEdge) -> None:
+        """在节点邻接索引中删除一条边"""
+        for node_id in (edge.source_id, edge.target_id):
+            node_edges = self.node_edge_index.get(node_id)
+            if not node_edges:
+                continue
+            if memory_id not in node_edges:
+                continue
+            node_edges[memory_id] = [e for e in node_edges[memory_id] if e.id != edge.id]
+            if not node_edges[memory_id]:
+                del node_edges[memory_id]
+            if not node_edges:
+                del self.node_edge_index[node_id]
+
+    def _rebuild_node_edge_index(self) -> None:
+        """重建节点邻接索引"""
+        self.node_edge_index.clear()
+        for memory in self.memory_index.values():
+            self._register_memory_edges(memory)
+
     def add_memory(self, memory: Memory) -> None:
         """
         添加记忆到图
@@ -77,6 +121,9 @@ class GraphStore:
             # 3. 保存记忆对象
             self.memory_index[memory.id] = memory
 
+            # 4. 注册记忆中的边到邻接索引
+            self._register_memory_edges(memory)
+
             logger.debug(f"添加记忆到图: {memory}")
 
         except Exception as e:
@@ -112,6 +159,9 @@ class GraphStore:
 
         memory = self.memory_index[memory_id]
 
+        # 1.5. 注销记忆中的边的邻接索引记录
+        self._unregister_memory_edges(memory)
+
         # 2. 添加节点到图
         if not self.graph.has_node(node_id):
             from datetime import datetime
@@ -282,6 +335,7 @@ class GraphStore:
             memory = self.memory_index.get(mem_id)
             if memory:
                 memory.edges.append(new_edge)
+                self._register_edge_reference(mem_id, new_edge)
 
         logger.debug(f"添加边成功: {source_id} -> {target_id} ({relation})")
         return edge_id
@@ -393,6 +447,10 @@ class GraphStore:
         for mem_id in related_memory_ids:
             memory = self.memory_index.get(mem_id)
             if memory:
+                removed_edges = [e for e in memory.edges if e.id == edge_id]
+                if removed_edges:
+                    for edge_obj in removed_edges:
+                        self._unregister_edge_reference(mem_id, edge_obj)
                 memory.edges = [e for e in memory.edges if e.id != edge_id]
 
         return True
@@ -440,8 +498,11 @@ class GraphStore:
         # 2. 转移边
         for edge in source_memory.edges:
             # 添加到目标记忆(如果不存在)
-            if not any(e.id == edge.id for e in target_memory.edges):
+            already_exists = any(e.id == edge.id for e in target_memory.edges)
+            self._unregister_edge_reference(source_id, edge)
+            if not already_exists:
                 target_memory.edges.append(edge)
+                self._register_edge_reference(target_memory_id, edge)
 
         # 3. 删除源记忆(不清理孤立节点,因为节点已转移)
         del self.memory_index[source_id]
 
@@ -465,6 +526,32 @@ class GraphStore:
         """
         return self.memory_index.get(memory_id)
 
+    def get_memories_by_ids(self, memory_ids: Iterable[str]) -> dict[str, Memory]:
+        """
+        根据一组ID批量获取记忆
+
+        Args:
+            memory_ids: 记忆ID集合
+
+        Returns:
+            {memory_id: Memory} 映射
+        """
+        result: dict[str, Memory] = {}
+        missing_ids: list[str] = []
+
+        # dict.fromkeys 在去重的同时保持输入顺序
+        for mem_id in dict.fromkeys(memory_ids):
+            memory = self.memory_index.get(mem_id)
+            if memory is not None:
+                result[mem_id] = memory
+            else:
+                missing_ids.append(mem_id)
+
+        if missing_ids:
+            logger.debug(f"批量获取记忆: 未找到 {missing_ids}")
+
+        return result
+
     def get_all_memories(self) -> list[Memory]:
         """
         获取所有记忆
@@ -490,6 +577,32 @@ class GraphStore:
         memory_ids = self.node_to_memories[node_id]
         return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
 
+    def get_edges_for_node(self, node_id: str) -> list[MemoryEdge]:
+        """
+        获取节点相关的全部边(包含入边和出边)
+
+        Args:
+            node_id: 节点ID
+
+        Returns:
+            MemoryEdge 列表
+        """
+        node_edges = self.node_edge_index.get(node_id)
+        if not node_edges:
+            return []
+
+        unique_edges: dict[str | tuple[str, str, str, str], MemoryEdge] = {}
+        for edges in node_edges.values():
+            for edge in edges:
+                key: str | tuple[str, str, str, str]
+                if edge.id:
+                    key = edge.id
+                else:
+                    key = (edge.source_id, edge.target_id, edge.relation, edge.edge_type.value)
+                unique_edges[key] = edge
+
+        return list(unique_edges.values())
+
     def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
         """
         获取从指定节点出发的所有边
@@ -762,6 +875,8 @@ class GraphStore:
         except Exception:
             logger.exception("同步图边到记忆.edges 失败")
 
+        store._rebuild_node_edge_index()
+
         logger.info(f"从字典加载图: {store.get_statistics()}")
         return store
 
@@ -829,6 +944,7 @@ class GraphStore:
                     existing_edges.setdefault(mid, set()).add(mem_edge.id)
 
         logger.info("已将图中的边同步到 Memory.edges(保证 graph 与 memory 对象一致)")
+        self._rebuild_node_edge_index()
 
     def remove_memory(self, memory_id: str, cleanup_orphans: bool = True) -> bool:
         """
@@ -877,4 +993,5 @@ class GraphStore:
         self.graph.clear()
         self.memory_index.clear()
         self.node_to_memories.clear()
-        logger.warning("图存储已清空")
+        self.node_edge_index.clear()
+        logger.warning("图存储已清空")
diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py
index a970b2448..69fc6f450 100644
--- a/src/memory_graph/tools/memory_tools.py
+++ b/src/memory_graph/tools/memory_tools.py
@@ -1171,8 +1171,10 @@ class MemoryTools:
         query_embeddings = []
         query_weights = []
 
-        for sub_query, weight in multi_queries:
-            embedding = await self.builder.embedding_generator.generate(sub_query)
+        batch_texts = [sub_query for sub_query, _ in multi_queries]
+        batch_embeddings = await self.builder.embedding_generator.generate_batch(batch_texts)
+
+        for (sub_query, weight), embedding in zip(multi_queries, batch_embeddings):
             if embedding is not None:
                 query_embeddings.append(embedding)
                 query_weights.append(weight)
diff --git a/src/memory_graph/unified_manager.py b/src/memory_graph/unified_manager.py
index 9661fe614..7f033d682 100644
--- a/src/memory_graph/unified_manager.py
+++ b/src/memory_graph/unified_manager.py
@@ -228,9 +228,13 @@ class UnifiedMemoryManager:
         }
 
         # 步骤1: 并发检索感知记忆和短期记忆
-        perceptual_blocks = await self.perceptual_manager.recall_blocks(query_text)
-        short_term_memories = await self.short_term_manager.search_memories(query_text)
+        perceptual_blocks_task = asyncio.create_task(self.perceptual_manager.recall_blocks(query_text))
+        short_term_memories_task = asyncio.create_task(self.short_term_manager.search_memories(query_text))
+        perceptual_blocks, short_term_memories = await asyncio.gather(
+            perceptual_blocks_task,
+            short_term_memories_task,
+        )
 
         # 步骤1.5: 检查需要转移的感知块,推迟到后台处理
         blocks_to_transfer = [
             block
diff --git a/src/memory_graph/utils/__init__.py b/src/memory_graph/utils/__init__.py
index 72b64e611..dab583400 100644
--- a/src/memory_graph/utils/__init__.py
+++ b/src/memory_graph/utils/__init__.py
@@ -4,7 +4,12 @@
 
 from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
 from src.memory_graph.utils.path_expansion import Path, PathExpansionConfig, PathScoreExpansion
-from src.memory_graph.utils.similarity import cosine_similarity
+from src.memory_graph.utils.similarity import (
+    batch_cosine_similarity,
+    batch_cosine_similarity_async,
+    cosine_similarity,
+    cosine_similarity_async,
+)
 from src.memory_graph.utils.time_parser import TimeParser
 
 __all__ = [
@@ -14,5 +19,8 @@ __all__ = [
     "PathScoreExpansion",
     "TimeParser",
+    "batch_cosine_similarity",
+    "batch_cosine_similarity_async",
     "cosine_similarity",
+    "cosine_similarity_async",
     "get_embedding_generator",
 ]
diff --git a/src/memory_graph/utils/embeddings.py b/src/memory_graph/utils/embeddings.py
index 1432d1c8b..5f7836914 100644
--- a/src/memory_graph/utils/embeddings.py
+++ b/src/memory_graph/utils/embeddings.py
@@ -137,56 +137,68 @@ class EmbeddingGenerator:
 
         raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API")
 
     async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]:
-        """
-        批量生成嵌入向量
-
-        Args:
-            texts: 文本列表
-
-        Returns:
-            嵌入向量列表,失败的项目为 None
-        """
+        """批量生成嵌入向量,结果保留输入顺序,失败的项目为 None"""
         if not texts:
             return []
 
         try:
-            # 过滤空文本
-            valid_texts = [t for t in texts if t and t.strip()]
-            if not valid_texts:
-                logger.debug("所有文本为空,返回 None 列表")
-                return [None for _ in texts]
+            results: list[np.ndarray | None] = [None] * len(texts)
+            valid_entries = [
+                (idx, text) for idx, text in enumerate(texts) if text and text.strip()
+            ]
+            if not valid_entries:
+                logger.debug("所有文本为空,返回 None 列表")
+                return results
+
+            batch_texts = [text for _, text in valid_entries]
+            batch_embeddings: list[np.ndarray | None] | None = None
 
-            # 使用 API 批量生成(如果可用)
             if self.use_api:
-                results = await self._generate_batch_with_api(valid_texts)
-                if results:
-                    return results
+                batch_embeddings = await self._generate_batch_with_api(batch_texts)
 
-            # 回退到逐个生成
-            results = []
-            for text in valid_texts:
-                embedding = await self.generate(text)
-                results.append(embedding)
+            if not batch_embeddings:
+                batch_embeddings = []
+                for _, text in valid_entries:
+                    batch_embeddings.append(await self.generate(text))
+
+            for (idx, _), embedding in zip(valid_entries, batch_embeddings):
+                results[idx] = embedding
 
             success_count = sum(1 for r in results if r is not None)
-            logger.debug(f"✅ 批量生成嵌入: {success_count}/{len(texts)} 个成功")
+            logger.debug(f"批量生成嵌入: {success_count}/{len(texts)}")
             return results
 
         except Exception as e:
-            logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
+            logger.error(f"批量生成嵌入失败: {e}", exc_info=True)
             return [None for _ in texts]
 
     async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None:
-        """使用 API 批量生成"""
+        """使用嵌入 API 在单次请求中生成向量"""
+        if not texts:
+            return []
+
         try:
-            # 对于大多数 API,批量调用就是多次单独调用
-            # 这里保持简单,逐个调用
-            results = []
-            for text in texts:
-                embedding = await self._generate_with_api(text)
-                results.append(embedding)  # 失败的项目为 None,不中断整个批量处理
+            if not 
self._api_available: + await self._initialize_api() + + if not self._api_available or not self._llm_request: + return None + + embeddings, model_name = await self._llm_request.get_embedding(texts) + if not embeddings: + return None + + results: list[np.ndarray | None] = [] + for emb in embeddings: + if emb: + results.append(np.array(emb, dtype=np.float32)) + else: + results.append(None) + + logger.debug(f"API 批量生成 {len(texts)} 个嵌入向量,使用模型: {model_name}") return results + except Exception as e: logger.debug(f"API 批量生成失败: {e}") return None diff --git a/src/memory_graph/utils/path_expansion.py b/src/memory_graph/utils/path_expansion.py index 4c80e7553..d6b05b862 100644 --- a/src/memory_graph/utils/path_expansion.py +++ b/src/memory_graph/utils/path_expansion.py @@ -15,13 +15,14 @@ """ import asyncio +import heapq import time from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TYPE_CHECKING, Any from src.common.logger import get_logger -from src.memory_graph.utils.similarity import cosine_similarity +from src.memory_graph.utils.similarity import cosine_similarity_async if TYPE_CHECKING: import numpy as np @@ -273,9 +274,8 @@ class PathScoreExpansion: f"⚠️ 路径数量超限 ({len(next_paths)} > {self.config.max_active_paths})," f"保留 top {self.config.top_paths_retain}" ) - next_paths = sorted(next_paths, key=lambda p: p.score, reverse=True)[ - : self.config.top_paths_retain - ] + retain = min(self.config.top_paths_retain, len(next_paths)) + next_paths = heapq.nlargest(retain, next_paths, key=lambda p: p.score) # 🚀 早停检测:如果路径增长很少,提前终止 prev_path_count = len(active_paths) @@ -398,22 +398,14 @@ class PathScoreExpansion: if node_id in self._neighbor_cache: return self._neighbor_cache[node_id] - edges = [] + edges = self.graph_store.get_edges_for_node(node_id) - # 从图存储中获取与该节点相关的所有边 - # 需要遍历所有记忆找到包含该节点的边 - for memory_id in self.graph_store.node_to_memories.get(node_id, []): - memory = self.graph_store.get_memory_by_id(memory_id) - if memory: - for edge in memory.edges: - if edge.source_id == node_id or edge.target_id == node_id: - edges.append(edge) - - # 去重(同一条边可能出现多次) - unique_edges = list({(e.source_id, e.target_id, e.edge_type): e for e in edges}.values()) + if not edges: + self._neighbor_cache[node_id] = [] + return [] # 按边权重排序 - unique_edges.sort(key=lambda e: self._get_edge_weight(e), reverse=True) + unique_edges = sorted(edges, key=lambda e: self._get_edge_weight(e), reverse=True) # 🚀 存入缓存 self._neighbor_cache[node_id] = unique_edges @@ -461,7 +453,7 @@ class PathScoreExpansion: base_score = 0.3 # 无向量的节点给低分 else: node_embedding = node_data["embedding"] - similarity = cosine_similarity(query_embedding, node_embedding) + similarity = await cosine_similarity_async(query_embedding, node_embedding) base_score = max(0.0, min(1.0, similarity)) # 限制在[0, 1] # 🆕 偏好类型加成 @@ -522,14 +514,8 @@ class PathScoreExpansion: node_metadata_map[nid] = node_data.get("metadata", {}) if valid_embeddings: - # 批量计算相似度(使用矩阵运算) - embeddings_matrix = np.array(valid_embeddings) - query_norm = np.linalg.norm(query_embedding) - embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1) - - # 向量化计算余弦相似度 - similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8) - similarities = np.clip(similarities, 0.0, 1.0) + # 批量计算相似度(使用矩阵运算)- 移至to_thread执行 + similarities = await asyncio.to_thread(self._batch_compute_similarities, valid_embeddings, query_embedding) # 应用偏好类型加成 for nid, sim in zip(valid_node_ids, similarities): @@ -706,11 +692,7 @@ class 
PathScoreExpansion:
 
-        # 🚀 批量获取记忆对象(如果graph_store支持批量获取)
-        # 注意:这里假设逐个获取,如果有批量API可以进一步优化
-        memory_cache: dict[str, Any] = {}
-        for mem_id in all_memory_ids:
-            memory = self.graph_store.get_memory_by_id(mem_id)
-            if memory:
-                memory_cache[mem_id] = memory
+        # 🚀 批量获取记忆对象(使用 graph_store 的批量 API)
+        # get_memories_by_ids 会去重并一次性返回全部命中的记忆
+        memory_cache: dict[str, Any] = self.graph_store.get_memories_by_ids(all_memory_ids)
 
         # 构建映射关系
         for path in paths:
@@ -749,30 +731,31 @@ class PathScoreExpansion:
         node_type_cache: dict[str, str | None] = {}
 
         if self.prefer_node_types:
-            # 收集所有需要查询的节点ID
-            all_node_ids = set()
+            # 收集所有需要查询的节点ID,并记录记忆中的类型提示
+            all_node_ids: set[str] = set()
+            node_type_hints: dict[str, str | None] = {}
             for memory, _ in memory_paths.values():
                 memory_nodes = getattr(memory, "nodes", [])
                 for node in memory_nodes:
                     node_id = node.id if hasattr(node, "id") else str(node)
                     all_node_ids.add(node_id)
+                    if node_id not in node_type_hints:
+                        node_obj_type = getattr(node, "node_type", None)
+                        if node_obj_type is not None:
+                            node_type_hints[node_id] = getattr(node_obj_type, "value", str(node_obj_type))
 
-            # 批量获取节点数据
             if all_node_ids:
-                logger.debug(f"🔍 批量预加载 {len(all_node_ids)} 个节点的类型信息")
-                node_data_list = await asyncio.gather(
-                    *[self.vector_store.get_node_by_id(nid) for nid in all_node_ids],
-                    return_exceptions=True
-                )
+                logger.info(f"🧠 预处理 {len(all_node_ids)} 个节点的类型信息")
+                for nid in all_node_ids:
+                    node_attrs = self.graph_store.graph.nodes.get(nid, {}) if hasattr(self.graph_store, "graph") else {}
+                    metadata = node_attrs.get("metadata", {}) if isinstance(node_attrs, dict) else {}
+                    node_type = metadata.get("node_type") or node_attrs.get("node_type")
 
-                # 构建类型缓存
-                for nid, node_data in zip(all_node_ids, node_data_list):
-                    if isinstance(node_data, Exception) or not node_data or not isinstance(node_data, dict):
-                        node_type_cache[nid] = None
-                    else:
-                        metadata = node_data.get("metadata", {})
-                        node_type_cache[nid] = metadata.get("node_type")
+                    if not node_type:
+                        # 回退到记忆中的节点定义
+                        node_type = node_type_hints.get(nid)
+                    node_type_cache[nid] = node_type
 
         # 遍历所有记忆进行评分
         for mem_id, (memory, paths) in memory_paths.items():
            # 1. 
聚合路径分数 @@ -868,5 +851,33 @@ class PathScoreExpansion: return recency_score + def _batch_compute_similarities( + self, + valid_embeddings: list["np.ndarray"], + query_embedding: "np.ndarray" + ) -> "np.ndarray": + """ + 批量计算向量相似度(CPU密集型操作,移至to_thread中执行) + + Args: + valid_embeddings: 有效的嵌入向量列表 + query_embedding: 查询向量 + + Returns: + 相似度数组 + """ + import numpy as np + + # 批量计算相似度(使用矩阵运算) + embeddings_matrix = np.array(valid_embeddings) + query_norm = np.linalg.norm(query_embedding) + embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1) + + # 向量化计算余弦相似度 + similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8) + similarities = np.clip(similarities, 0.0, 1.0) + + return similarities + __all__ = ["Path", "PathExpansionConfig", "PathScoreExpansion"] diff --git a/src/memory_graph/utils/similarity.py b/src/memory_graph/utils/similarity.py index d610cfda4..0c0c3c13c 100644 --- a/src/memory_graph/utils/similarity.py +++ b/src/memory_graph/utils/similarity.py @@ -4,6 +4,7 @@ 提供统一的向量相似度计算函数 """ +import asyncio from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -47,4 +48,91 @@ def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float: return 0.0 -__all__ = ["cosine_similarity"] +async def cosine_similarity_async(vec1: "np.ndarray", vec2: "np.ndarray") -> float: + """ + 异步计算两个向量的余弦相似度,使用to_thread避免阻塞 + + Args: + vec1: 第一个向量 + vec2: 第二个向量 + + Returns: + 余弦相似度 (0.0-1.0) + """ + return await asyncio.to_thread(cosine_similarity, vec1, vec2) + + +def batch_cosine_similarity(vec1: "np.ndarray", vec_list: list["np.ndarray"]) -> list[float]: + """ + 批量计算向量相似度 + + Args: + vec1: 基础向量 + vec_list: 待比较的向量列表 + + Returns: + 相似度列表 + """ + try: + import numpy as np + + # 确保是numpy数组 + if not isinstance(vec1, np.ndarray): + vec1 = np.array(vec1) + + # 批量转换为numpy数组 + vec_list = [np.array(vec) for vec in vec_list] + + # 计算归一化 + vec1_norm = np.linalg.norm(vec1) + if vec1_norm == 0: + return [0.0] * len(vec_list) + + # 计算所有向量的归一化 + vec_norms = np.array([np.linalg.norm(vec) for vec in vec_list]) + + # 避免除以0 + valid_mask = vec_norms != 0 + similarities = np.zeros(len(vec_list)) + + if np.any(valid_mask): + # 批量计算点积 + valid_vecs = np.array(vec_list)[valid_mask] + dot_products = np.dot(valid_vecs, vec1) + + # 计算相似度 + valid_norms = vec_norms[valid_mask] + valid_similarities = dot_products / (vec1_norm * valid_norms) + + # 确保在 [0, 1] 范围内 + valid_similarities = np.clip(valid_similarities, 0.0, 1.0) + + # 填充结果 + similarities[valid_mask] = valid_similarities + + return similarities.tolist() + + except Exception: + return [0.0] * len(vec_list) + + +async def batch_cosine_similarity_async(vec1: "np.ndarray", vec_list: list["np.ndarray"]) -> list[float]: + """ + 异步批量计算向量相似度,使用to_thread避免阻塞 + + Args: + vec1: 基础向量 + vec_list: 待比较的向量列表 + + Returns: + 相似度列表 + """ + return await asyncio.to_thread(batch_cosine_similarity, vec1, vec_list) + + +__all__ = [ + "cosine_similarity", + "cosine_similarity_async", + "batch_cosine_similarity", + "batch_cosine_similarity_async" +]
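
Usage sketch (illustrative, not part of the patch): the snippet below mirrors how perceptual_manager.py consumes batch_cosine_similarity_async after this change — drop items without embeddings, score the survivors in one vectorized call that runs on a worker thread via asyncio.to_thread, then threshold and sort. The Block dataclass, the rank_blocks helper, and the 0.5 threshold are hypothetical; only the imported function and its semantics come from this diff.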
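# Illustrative sketch only -- assumes this patch is applied.
# `Block` and `rank_blocks` are hypothetical stand-ins; only
# `batch_cosine_similarity_async` comes from the patched module.
import asyncio
from dataclasses import dataclass

import numpy as np

from src.memory_graph.utils.similarity import batch_cosine_similarity_async


@dataclass
class Block:
    name: str
    embedding: np.ndarray | None


async def rank_blocks(query: np.ndarray, blocks: list[Block], threshold: float = 0.5):
    # Same shape as the patched recall path: filter out blocks without
    # embeddings, score the rest in one call (offloaded to a worker thread),
    # then keep and sort everything at or above the threshold.
    candidates = [b for b in blocks if b.embedding is not None]
    if not candidates:
        return []
    sims = await batch_cosine_similarity_async(query, [b.embedding for b in candidates])
    scored = [(b, s) for b, s in zip(candidates, sims) if s >= threshold]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    blocks = [Block(f"block-{i}", rng.random(8, dtype=np.float32)) for i in range(4)]
    blocks.append(Block("no-embedding", None))
    ranked = asyncio.run(rank_blocks(rng.random(8, dtype=np.float32), blocks))
    for block, score in ranked:
        print(f"{block.name}: {score:.3f}")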
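
The heart of the graph_store.py change is the adjacency index node_id -> {memory_id: [MemoryEdge]} that get_edges_for_node reads, turning neighbor lookup into an O(degree) operation instead of a scan over every related memory's edge list. Below is a minimal self-contained model of that bookkeeping; Edge is a hypothetical stand-in for MemoryEdge (the real class carries relation, edge_type, and more, and falls back to a tuple key when an edge has no id).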
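# Simplified model of the adjacency index this patch adds to GraphStore.
# `Edge` is a hypothetical stand-in for MemoryEdge.
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    id: str
    source_id: str
    target_id: str


class EdgeIndex:
    def __init__(self) -> None:
        # node_id -> {memory_id: [Edge]}, mirroring GraphStore.node_edge_index
        self.node_edge_index: dict[str, dict[str, list[Edge]]] = {}

    def register(self, memory_id: str, edge: Edge) -> None:
        # An edge is reachable from both endpoints, so index it under each;
        # skip duplicates within the same memory bucket, as the patch does.
        for node_id in (edge.source_id, edge.target_id):
            edges = self.node_edge_index.setdefault(node_id, {}).setdefault(memory_id, [])
            if not any(e.id == edge.id for e in edges):
                edges.append(edge)

    def edges_for_node(self, node_id: str) -> list[Edge]:
        # O(degree) lookup; deduplicate across memory buckets by edge id.
        seen: dict[str, Edge] = {}
        for edges in self.node_edge_index.get(node_id, {}).values():
            for edge in edges:
                seen[edge.id] = edge
        return list(seen.values())


index = EdgeIndex()
index.register("mem-1", Edge("e1", "a", "b"))
index.register("mem-2", Edge("e2", "b", "c"))
index.register("mem-2", Edge("e1", "a", "b"))  # shared edge is deduplicated on read
print([e.id for e in index.edges_for_node("b")])  # ['e1', 'e2']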