feat(similarity): add async and batch similarity computation to improve performance
feat(graph_store): strengthen graph store management with edge registration and unregistration
feat(memory_tools): support batch embedding generation
feat(unified_manager): optimize retrieval of perceptual and short-term memory
@@ -21,7 +21,7 @@ import numpy as np
from src.common.logger import get_logger
from src.memory_graph.models import MemoryBlock, PerceptualMemory
from src.memory_graph.utils.embeddings import EmbeddingGenerator
from src.memory_graph.utils.similarity import cosine_similarity
from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async

logger = get_logger(__name__)

@@ -430,14 +430,22 @@ class PerceptualMemoryManager:
logger.warning("查询向量生成失败,返回空列表")
return []

# Compute similarity for all blocks
# Compute similarity for all blocks in one batch (async version)
blocks_with_embeddings = [
block for block in self.perceptual_memory.blocks
if block.embedding is not None
]

if not blocks_with_embeddings:
return []

# Batch similarity computation
block_embeddings = [block.embedding for block in blocks_with_embeddings]
similarities = await batch_cosine_similarity_async(query_embedding, block_embeddings)

# Filter and sort
scored_blocks = []
for block in self.perceptual_memory.blocks:
if block.embedding is None:
continue

similarity = cosine_similarity(query_embedding, block.embedding)

for block, similarity in zip(blocks_with_embeddings, similarities):
# Filter out blocks below the threshold
if similarity >= similarity_threshold:
scored_blocks.append((block, similarity))
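The new retrieval path collects every block that has an embedding, scores them in a single batched call, and only then applies the threshold. Below is a minimal, self-contained sketch of that pattern; the Block dataclass and the local batch helper are hypothetical stand-ins for MemoryBlock and batch_cosine_similarity_async, used only for illustration.

    import asyncio
    from dataclasses import dataclass

    import numpy as np


    @dataclass
    class Block:  # stand-in for MemoryBlock; illustrative only
        text: str
        embedding: np.ndarray | None


    async def batch_cosine(query: np.ndarray, vectors: list[np.ndarray]) -> list[float]:
        # Offload the vectorized computation to a worker thread, as the commit does.
        def _compute() -> list[float]:
            matrix = np.stack(vectors)
            norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query) + 1e-8
            return np.clip(matrix @ query / norms, 0.0, 1.0).tolist()
        return await asyncio.to_thread(_compute)


    async def recall(blocks: list[Block], query: np.ndarray, threshold: float = 0.5) -> list[tuple[Block, float]]:
        candidates = [b for b in blocks if b.embedding is not None]
        if not candidates:
            return []
        sims = await batch_cosine(query, [b.embedding for b in candidates])
        scored = [(b, s) for b, s in zip(candidates, sims) if s >= threshold]
        return sorted(scored, key=lambda pair: pair[1], reverse=True)


    if __name__ == "__main__":
        rng = np.random.default_rng(0)
        blocks = [Block(f"block-{i}", rng.normal(size=8)) for i in range(4)] + [Block("no-vec", None)]
        print(asyncio.run(recall(blocks, rng.normal(size=8), threshold=0.0)))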
@@ -25,7 +25,7 @@ from src.memory_graph.models import (
ShortTermOperation,
)
from src.memory_graph.utils.embeddings import EmbeddingGenerator
from src.memory_graph.utils.similarity import cosine_similarity
from src.memory_graph.utils.similarity import cosine_similarity_async, batch_cosine_similarity_async

logger = get_logger(__name__)

@@ -457,7 +457,7 @@ class ShortTermMemoryManager:
if existing_mem.embedding is None:
continue

similarity = cosine_similarity(memory.embedding, existing_mem.embedding)
similarity = await cosine_similarity_async(memory.embedding, existing_mem.embedding)
scored.append((existing_mem, similarity))

# Sort by similarity in descending order

@@ -567,7 +567,7 @@ class ShortTermMemoryManager:
if memory.embedding is None:
continue

similarity = cosine_similarity(query_embedding, memory.embedding)
similarity = await cosine_similarity_async(query_embedding, memory.embedding)
if similarity >= similarity_threshold:
scored.append((memory, similarity))

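Note that awaiting cosine_similarity_async inside the loop still runs the pairwise computations one after another; each call only moves the work off the event loop. A rough sketch, using just the standard library and numpy rather than the project's helpers, of how such per-pair calls could also be fanned out concurrently:

    import asyncio

    import numpy as np


    def cosine(a: np.ndarray, b: np.ndarray) -> float:
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        return float(np.dot(a, b) / denom) if denom else 0.0


    async def cosine_async(a: np.ndarray, b: np.ndarray) -> float:
        # Same idea as cosine_similarity_async: run the sync function in a thread.
        return await asyncio.to_thread(cosine, a, b)


    async def score_all(query: np.ndarray, candidates: list[np.ndarray]) -> list[float]:
        # Fan the per-pair calls out concurrently instead of awaiting them one by one.
        return list(await asyncio.gather(*(cosine_async(query, c) for c in candidates)))


    if __name__ == "__main__":
        rng = np.random.default_rng(1)
        vecs = [rng.normal(size=16) for _ in range(3)]
        print(asyncio.run(score_all(rng.normal(size=16), vecs)))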
@@ -4,6 +4,8 @@

from __future__ import annotations

from collections.abc import Iterable

import networkx as nx

from src.common.logger import get_logger

@@ -33,9 +35,51 @@ class GraphStore:

# Index: node ID -> set of memory IDs it belongs to
self.node_to_memories: dict[str, set[str]] = {}

# Node -> {memory_id: [MemoryEdge]}, used for fast adjacency-edge lookups
self.node_edge_index: dict[str, dict[str, list[MemoryEdge]]] = {}

logger.info("初始化图存储")


def _register_memory_edges(self, memory: Memory) -> None:
"""Add the edges of a memory to the adjacency index."""
for edge in memory.edges:
self._register_edge_reference(memory.id, edge)

def _register_edge_reference(self, memory_id: str, edge: MemoryEdge) -> None:
"""Register a single edge in the node adjacency index."""
for node_id in (edge.source_id, edge.target_id):
node_edges = self.node_edge_index.setdefault(node_id, {})
edge_list = node_edges.setdefault(memory_id, [])
if not any(existing.id == edge.id for existing in edge_list):
edge_list.append(edge)

def _unregister_memory_edges(self, memory: Memory) -> None:
"""Remove a memory's edges from the node adjacency index."""
for edge in memory.edges:
self._unregister_edge_reference(memory.id, edge)

def _unregister_edge_reference(self, memory_id: str, edge: MemoryEdge) -> None:
"""Remove a single edge from the node adjacency index."""
for node_id in (edge.source_id, edge.target_id):
node_edges = self.node_edge_index.get(node_id)
if not node_edges:
continue
if memory_id not in node_edges:
continue
node_edges[memory_id] = [e for e in node_edges[memory_id] if e.id != edge.id]
if not node_edges[memory_id]:
del node_edges[memory_id]
if not node_edges:
del self.node_edge_index[node_id]

def _rebuild_node_edge_index(self) -> None:
"""Rebuild the node adjacency index from scratch."""
self.node_edge_index.clear()
for memory in self.memory_index.values():
self._register_memory_edges(memory)

def add_memory(self, memory: Memory) -> None:
"""
Add a memory to the graph
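The adjacency index is a two-level mapping, node_id -> {memory_id: [edges]}, kept consistent by the register/unregister helpers above. A small standalone sketch of the same bookkeeping, with a simplified Edge type standing in for MemoryEdge:

    from dataclasses import dataclass


    @dataclass
    class Edge:  # simplified stand-in for MemoryEdge
        id: str
        source_id: str
        target_id: str


    node_edge_index: dict[str, dict[str, list[Edge]]] = {}


    def register(memory_id: str, edge: Edge) -> None:
        # Index the edge under both endpoints, skipping duplicates by edge id.
        for node_id in (edge.source_id, edge.target_id):
            edge_list = node_edge_index.setdefault(node_id, {}).setdefault(memory_id, [])
            if not any(e.id == edge.id for e in edge_list):
                edge_list.append(edge)


    def unregister(memory_id: str, edge: Edge) -> None:
        # Remove the edge and prune empty inner dicts so the index stays compact.
        for node_id in (edge.source_id, edge.target_id):
            node_edges = node_edge_index.get(node_id)
            if not node_edges or memory_id not in node_edges:
                continue
            node_edges[memory_id] = [e for e in node_edges[memory_id] if e.id != edge.id]
            if not node_edges[memory_id]:
                del node_edges[memory_id]
            if not node_edges:
                del node_edge_index[node_id]


    e = Edge("e1", "alice", "bob")
    register("m1", e)
    assert node_edge_index["alice"]["m1"][0].id == "e1"
    unregister("m1", e)
    assert "alice" not in node_edge_index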
@@ -77,6 +121,9 @@ class GraphStore:
# 3. Store the memory object
self.memory_index[memory.id] = memory

# 4. Register the memory's edges in the adjacency index
self._register_memory_edges(memory)

logger.debug(f"添加记忆到图: {memory}")

except Exception as e:

@@ -112,6 +159,12 @@ class GraphStore:

memory = self.memory_index[memory_id]

# 1.5. Unregister the memory's edges from the adjacency index
self._unregister_memory_edges(memory)

# 1.5. Unregister the memory's edges from the adjacency index
self._unregister_memory_edges(memory)

# 2. Add nodes to the graph
if not self.graph.has_node(node_id):
from datetime import datetime

@@ -282,6 +335,7 @@ class GraphStore:
memory = self.memory_index.get(mem_id)
if memory:
memory.edges.append(new_edge)
self._register_edge_reference(mem_id, new_edge)

logger.debug(f"添加边成功: {source_id} -> {target_id} ({relation})")
return edge_id

@@ -393,6 +447,10 @@ class GraphStore:
for mem_id in related_memory_ids:
memory = self.memory_index.get(mem_id)
if memory:
removed_edges = [e for e in memory.edges if e.id == edge_id]
if removed_edges:
for edge_obj in removed_edges:
self._unregister_edge_reference(mem_id, edge_obj)
memory.edges = [e for e in memory.edges if e.id != edge_id]

return True

@@ -440,8 +498,11 @@ class GraphStore:
# 2. Transfer edges
for edge in source_memory.edges:
# Add to the target memory (if not already present)
if not any(e.id == edge.id for e in target_memory.edges):
already_exists = any(e.id == edge.id for e in target_memory.edges)
self._unregister_edge_reference(source_id, edge)
if not already_exists:
target_memory.edges.append(edge)
self._register_edge_reference(target_memory_id, edge)

# 3. Delete the source memory (do not clean up orphan nodes; they were transferred)
del self.memory_index[source_id]

@@ -465,6 +526,32 @@ class GraphStore:
"""
return self.memory_index.get(memory_id)

def get_memories_by_ids(self, memory_ids: Iterable[str]) -> dict[str, Memory]:
"""
Fetch multiple memories by ID in one call

Args:
memory_ids: collection of memory IDs

Returns:
{memory_id: Memory} mapping
"""
result: dict[str, Memory] = {}
missing_ids: list[str] = []

# dict.fromkeys preserves the original argument order while deduplicating
for mem_id in dict.fromkeys(memory_ids):
memory = self.memory_index.get(mem_id)
if memory is not None:
result[mem_id] = memory
else:
missing_ids.append(mem_id)

if missing_ids:
logger.debug(f"批量获取记忆: 未找到 {missing_ids}")

return result

def get_all_memories(self) -> list[Memory]:
"""
Get all memories
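get_memories_by_ids deduplicates the requested IDs while preserving the caller's order via dict.fromkeys. The same idiom in isolation, with hypothetical data for illustration:

    # dict.fromkeys keeps the first occurrence of each key in insertion order,
    # so it deduplicates without losing the caller's ordering (unlike set()).
    memory_index = {"m1": "memory one", "m3": "memory three"}

    requested = ["m3", "m1", "m3", "m2", "m1"]
    found: dict[str, str] = {}
    missing: list[str] = []

    for mem_id in dict.fromkeys(requested):   # -> "m3", "m1", "m2"
        if mem_id in memory_index:
            found[mem_id] = memory_index[mem_id]
        else:
            missing.append(mem_id)

    assert list(found) == ["m3", "m1"]
    assert missing == ["m2"]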
@@ -490,6 +577,32 @@ class GraphStore:
memory_ids = self.node_to_memories[node_id]
return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]

def get_edges_for_node(self, node_id: str) -> list[MemoryEdge]:
"""
Get all edges related to a node (both incoming and outgoing)

Args:
node_id: node ID

Returns:
List of MemoryEdge
"""
node_edges = self.node_edge_index.get(node_id)
if not node_edges:
return []

unique_edges: dict[str | tuple[str, str, str, str], MemoryEdge] = {}
for edges in node_edges.values():
for edge in edges:
key: str | tuple[str, str, str, str]
if edge.id:
key = edge.id
else:
key = (edge.source_id, edge.target_id, edge.relation, edge.edge_type.value)
unique_edges[key] = edge

return list(unique_edges.values())

def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
"""
Get all edges that start from the given node

@@ -762,6 +875,8 @@ class GraphStore:
except Exception:
logger.exception("同步图边到记忆.edges 失败")

store._rebuild_node_edge_index()

logger.info(f"从字典加载图: {store.get_statistics()}")
return store

@@ -829,6 +944,7 @@ class GraphStore:
existing_edges.setdefault(mid, set()).add(mem_edge.id)

logger.info("已将图中的边同步到 Memory.edges(保证 graph 与 memory 对象一致)")
self._rebuild_node_edge_index()

def remove_memory(self, memory_id: str, cleanup_orphans: bool = True) -> bool:
"""

@@ -877,4 +993,5 @@ class GraphStore:
self.graph.clear()
self.memory_index.clear()
self.node_to_memories.clear()
logger.warning("图存储已清空")
self.node_edge_index.clear()
logger.warning("图存储已清空")

@@ -1171,8 +1171,10 @@ class MemoryTools:
query_embeddings = []
query_weights = []

for sub_query, weight in multi_queries:
embedding = await self.builder.embedding_generator.generate(sub_query)
batch_texts = [sub_query for sub_query, _ in multi_queries]
batch_embeddings = await self.builder.embedding_generator.generate_batch(batch_texts)

for (sub_query, weight), embedding in zip(multi_queries, batch_embeddings):
if embedding is not None:
query_embeddings.append(embedding)
query_weights.append(weight)

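Instead of one embedding request per sub-query, the texts are now embedded in a single batch call and then re-paired with their weights via zip. A self-contained sketch of that pairing; the fake embedder below is only a placeholder for builder.embedding_generator.generate_batch:

    import asyncio

    import numpy as np


    async def generate_batch(texts: list[str]) -> list[np.ndarray | None]:
        # Placeholder embedder: returns a deterministic vector per text, None for blanks.
        return [np.full(4, float(len(t))) if t.strip() else None for t in texts]


    async def embed_weighted_queries(multi_queries: list[tuple[str, float]]):
        batch_texts = [sub_query for sub_query, _ in multi_queries]
        batch_embeddings = await generate_batch(batch_texts)

        query_embeddings: list[np.ndarray] = []
        query_weights: list[float] = []
        # zip keeps each embedding aligned with its (sub_query, weight) pair.
        for (sub_query, weight), embedding in zip(multi_queries, batch_embeddings):
            if embedding is not None:
                query_embeddings.append(embedding)
                query_weights.append(weight)
        return query_embeddings, query_weights


    if __name__ == "__main__":
        embs, weights = asyncio.run(embed_weighted_queries([("what happened", 1.0), (" ", 0.5), ("where", 0.8)]))
        print(len(embs), weights)  # 2 [1.0, 0.8]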
@@ -228,9 +228,14 @@ class UnifiedMemoryManager:
}

# Step 1: retrieve perceptual memory and short-term memory
perceptual_blocks = await self.perceptual_manager.recall_blocks(query_text)
short_term_memories = await self.short_term_manager.search_memories(query_text)
perceptual_blocks_task = asyncio.create_task(self.perceptual_manager.recall_blocks(query_text))
short_term_memories_task = asyncio.create_task(self.short_term_manager.search_memories(query_text))

perceptual_blocks, short_term_memories = await asyncio.gather(
perceptual_blocks_task,
short_term_memories_task,
)

# Step 1.5: check for perceptual blocks that need transfer; defer to background processing
blocks_to_transfer = [
block

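Creating the two retrieval tasks up front and awaiting them with asyncio.gather lets the perceptual and short-term lookups overlap instead of running back to back. A minimal sketch of the pattern with stub coroutines; the timings are illustrative only:

    import asyncio
    import time


    async def recall_blocks(query: str) -> list[str]:
        await asyncio.sleep(0.2)          # simulated perceptual-memory lookup
        return [f"block for {query}"]


    async def search_memories(query: str) -> list[str]:
        await asyncio.sleep(0.2)          # simulated short-term-memory lookup
        return [f"memory for {query}"]


    async def retrieve(query: str) -> None:
        start = time.perf_counter()
        blocks_task = asyncio.create_task(recall_blocks(query))
        memories_task = asyncio.create_task(search_memories(query))
        blocks, memories = await asyncio.gather(blocks_task, memories_task)
        # Roughly 0.2s total instead of 0.4s, because both lookups run concurrently.
        print(f"{time.perf_counter() - start:.2f}s", blocks, memories)


    if __name__ == "__main__":
        asyncio.run(retrieve("lunch plans"))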
@@ -4,7 +4,12 @@

from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
from src.memory_graph.utils.path_expansion import Path, PathExpansionConfig, PathScoreExpansion
from src.memory_graph.utils.similarity import cosine_similarity
from src.memory_graph.utils.similarity import (
cosine_similarity,
cosine_similarity_async,
batch_cosine_similarity,
batch_cosine_similarity_async
)
from src.memory_graph.utils.time_parser import TimeParser

__all__ = [

@@ -14,5 +19,8 @@ __all__ = [
"PathScoreExpansion",
"TimeParser",
"cosine_similarity",
"cosine_similarity_async",
"batch_cosine_similarity",
"batch_cosine_similarity_async",
"get_embedding_generator",
]

@@ -137,56 +137,69 @@ class EmbeddingGenerator:

raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API")


async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]:
"""
Generate embeddings in batch

Args:
texts: list of texts

Returns:
List of embeddings; failed items are None
"""
"""Batch embedding generation that preserves input order."""
if not texts:
return []

try:
# Filter out empty texts
valid_texts = [t for t in texts if t and t.strip()]
if not valid_texts:
logger.debug("所有文本为空,返回 None 列表")
return [None for _ in texts]
results: list[np.ndarray | None] = [None] * len(texts)
valid_entries = [
(idx, text) for idx, text in enumerate(texts) if text and text.strip()
]
if not valid_entries:
logger.debug('批量文本为空,返回空列表')
return results

batch_texts = [text for _, text in valid_entries]
batch_embeddings: list[np.ndarray | None] | None = None

# Use the API for batch generation (if available)
if self.use_api:
results = await self._generate_batch_with_api(valid_texts)
if results:
return results
batch_embeddings = await self._generate_batch_with_api(batch_texts)

# Fall back to generating one by one
results = []
for text in valid_texts:
embedding = await self.generate(text)
results.append(embedding)
if not batch_embeddings:
batch_embeddings = []
for _, text in valid_entries:
batch_embeddings.append(await self.generate(text))

for (idx, _), embedding in zip(valid_entries, batch_embeddings):
results[idx] = embedding

success_count = sum(1 for r in results if r is not None)
logger.debug(f"✅ 批量生成嵌入: {success_count}/{len(texts)} 个成功")
logger.debug(f"批量生成嵌入: {success_count}/{len(texts)}")
return results

except Exception as e:
logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
logger.error(f"批量生成嵌入失败: {e}", exc_info=True)
return [None for _ in texts]

async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None:
"""Batch generation via the API."""
"""Generate embeddings via the embedding API in a single request."""
if not texts:
return []

try:
# For most APIs, a batch call is just several individual calls
# Keep it simple here and call one by one
results = []
for text in texts:
embedding = await self._generate_with_api(text)
results.append(embedding)  # Failed items are None; do not abort the whole batch
if not self._api_available:
await self._initialize_api()

if not self._api_available or not self._llm_request:
return None

embeddings, model_name = await self._llm_request.get_embedding(texts)
if not embeddings:
return None

results: list[np.ndarray | None] = []
for emb in embeddings:
if emb:
results.append(np.array(emb, dtype=np.float32))
else:
results.append(None)

logger.debug(f"API 批量生成 {len(texts)} 个嵌入向量,使用模型: {model_name}")
return results

except Exception as e:
logger.debug(f"API 批量生成失败: {e}")
return None

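The rewritten generate_batch pre-allocates one result slot per input, embeds only the non-empty texts, and writes each embedding back by its original index, so callers can zip the output against their inputs even when some items fail. A stripped-down sketch of that index bookkeeping; the embed function below is a stand-in for the real API call:

    import asyncio

    import numpy as np


    async def embed(text: str) -> np.ndarray | None:
        # Stand-in for the real embedding call; pretend texts containing "bad" fail.
        return None if "bad" in text else np.full(3, float(len(text)))


    async def generate_batch(texts: list[str]) -> list[np.ndarray | None]:
        results: list[np.ndarray | None] = [None] * len(texts)
        valid_entries = [(idx, t) for idx, t in enumerate(texts) if t and t.strip()]
        if not valid_entries:
            return results

        embeddings = [await embed(text) for _, text in valid_entries]
        # Write each embedding back to the slot of its original position.
        for (idx, _), embedding in zip(valid_entries, embeddings):
            results[idx] = embedding
        return results


    if __name__ == "__main__":
        out = asyncio.run(generate_batch(["hello", "", "bad input", "world"]))
        print([None if v is None else v.tolist() for v in out])
        # [[5.0, 5.0, 5.0], None, None, [5.0, 5.0, 5.0]]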
@@ -15,13 +15,14 @@
"""

import asyncio
import heapq
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from src.common.logger import get_logger
from src.memory_graph.utils.similarity import cosine_similarity
from src.memory_graph.utils.similarity import cosine_similarity_async

if TYPE_CHECKING:
import numpy as np

@@ -273,9 +274,8 @@ class PathScoreExpansion:
f"⚠️ 路径数量超限 ({len(next_paths)} > {self.config.max_active_paths}),"
f"保留 top {self.config.top_paths_retain}"
)
next_paths = sorted(next_paths, key=lambda p: p.score, reverse=True)[
: self.config.top_paths_retain
]
retain = min(self.config.top_paths_retain, len(next_paths))
next_paths = heapq.nlargest(retain, next_paths, key=lambda p: p.score)

# 🚀 Early-stop check: terminate early if path growth is small
prev_path_count = len(active_paths)
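heapq.nlargest(k, items, key=...) returns the top k items without fully sorting the list, which is cheaper than sorted(...)[:k] when k is much smaller than the number of active paths (roughly O(n log k) versus O(n log n)). For example:

    import heapq
    import random
    from dataclasses import dataclass


    @dataclass
    class Path:          # minimal stand-in with just a score
        score: float


    random.seed(0)
    next_paths = [Path(random.random()) for _ in range(10_000)]
    top_paths_retain = 50

    retain = min(top_paths_retain, len(next_paths))
    # Same result as sorted(next_paths, key=..., reverse=True)[:retain],
    # but only a heap of size `retain` is kept while scanning the list once.
    best = heapq.nlargest(retain, next_paths, key=lambda p: p.score)

    assert len(best) == retain
    assert best[0].score == max(p.score for p in next_paths)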
@@ -398,22 +398,14 @@ class PathScoreExpansion:
if node_id in self._neighbor_cache:
return self._neighbor_cache[node_id]

edges = []
edges = self.graph_store.get_edges_for_node(node_id)

# Get all edges related to this node from the graph store
# We need to walk all memories to find the edges containing this node
for memory_id in self.graph_store.node_to_memories.get(node_id, []):
memory = self.graph_store.get_memory_by_id(memory_id)
if memory:
for edge in memory.edges:
if edge.source_id == node_id or edge.target_id == node_id:
edges.append(edge)

# Deduplicate (the same edge may appear multiple times)
unique_edges = list({(e.source_id, e.target_id, e.edge_type): e for e in edges}.values())
if not edges:
self._neighbor_cache[node_id] = []
return []

# Sort by edge weight
unique_edges.sort(key=lambda e: self._get_edge_weight(e), reverse=True)
unique_edges = sorted(edges, key=lambda e: self._get_edge_weight(e), reverse=True)

# 🚀 Save to cache
self._neighbor_cache[node_id] = unique_edges

@@ -461,7 +453,7 @@ class PathScoreExpansion:
base_score = 0.3  # Give nodes without embeddings a low score
else:
node_embedding = node_data["embedding"]
similarity = cosine_similarity(query_embedding, node_embedding)
similarity = await cosine_similarity_async(query_embedding, node_embedding)
base_score = max(0.0, min(1.0, similarity))  # Clamp to [0, 1]

# 🆕 Preferred-type bonus

@@ -522,14 +514,8 @@ class PathScoreExpansion:
node_metadata_map[nid] = node_data.get("metadata", {})

if valid_embeddings:
# Batch similarity computation (matrix operations)
embeddings_matrix = np.array(valid_embeddings)
query_norm = np.linalg.norm(query_embedding)
embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1)

# Vectorized cosine similarity
similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8)
similarities = np.clip(similarities, 0.0, 1.0)
# Batch similarity computation (matrix operations) - moved into to_thread
similarities = await asyncio.to_thread(self._batch_compute_similarities, valid_embeddings, query_embedding)

# Apply preferred-type bonus
for nid, sim in zip(valid_node_ids, similarities):

@@ -706,11 +692,7 @@ class PathScoreExpansion:

# 🚀 Fetch memory objects in batch (if graph_store supports it)
# Note: fetched one by one here; a batch API would allow further optimization
memory_cache: dict[str, Any] = {}
for mem_id in all_memory_ids:
memory = self.graph_store.get_memory_by_id(mem_id)
if memory:
memory_cache[mem_id] = memory
memory_cache: dict[str, Any] = self.graph_store.get_memories_by_ids(all_memory_ids)

# Build the mapping
for path in paths:

@@ -749,30 +731,31 @@ class PathScoreExpansion:
node_type_cache: dict[str, str | None] = {}

if self.prefer_node_types:
# Collect all node IDs that need to be queried
all_node_ids = set()
# Collect all node IDs to query and record type hints from the memories
all_node_ids: set[str] = set()
node_type_hints: dict[str, str | None] = {}
for memory, _ in memory_paths.values():
memory_nodes = getattr(memory, "nodes", [])
for node in memory_nodes:
node_id = node.id if hasattr(node, "id") else str(node)
all_node_ids.add(node_id)
if node_id not in node_type_hints:
node_obj_type = getattr(node, "node_type", None)
if node_obj_type is not None:
node_type_hints[node_id] = getattr(node_obj_type, "value", str(node_obj_type))

# Fetch node data in batch
if all_node_ids:
logger.debug(f"🔍 批量预加载 {len(all_node_ids)} 个节点的类型信息")
node_data_list = await asyncio.gather(
*[self.vector_store.get_node_by_id(nid) for nid in all_node_ids],
return_exceptions=True
)
logger.info(f"🧠 预处理 {len(all_node_ids)} 个节点的类型信息")
for nid in all_node_ids:
node_attrs = self.graph_store.graph.nodes.get(nid, {}) if hasattr(self.graph_store, "graph") else {}
metadata = node_attrs.get("metadata", {}) if isinstance(node_attrs, dict) else {}
node_type = metadata.get("node_type") or node_attrs.get("node_type")

# Build the type cache
for nid, node_data in zip(all_node_ids, node_data_list):
if isinstance(node_data, Exception) or not node_data or not isinstance(node_data, dict):
node_type_cache[nid] = None
else:
metadata = node_data.get("metadata", {})
node_type_cache[nid] = metadata.get("node_type")
if not node_type:
# Fall back to the node definition stored in the memory
node_type = node_type_hints.get(nid)

node_type_cache[nid] = node_type
# Iterate over all memories for scoring
for mem_id, (memory, paths) in memory_paths.items():
# 1. Aggregate path scores
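The reworked preloading reads each node's type straight out of the networkx node attributes (metadata first, then a top-level field) and only falls back to the type hint captured from the memory's own node objects. A compact sketch of that lookup order, with hand-built dictionaries in place of the real graph:

    # Hypothetical node attributes as they might appear on graph.nodes[node_id].
    graph_nodes = {
        "n1": {"metadata": {"node_type": "person"}},
        "n2": {"node_type": "place"},          # only the top-level field is set
        "n3": {},                               # nothing on the graph node
    }
    node_type_hints = {"n3": "event"}           # captured from Memory.nodes earlier

    node_type_cache: dict[str, str | None] = {}
    for nid in ("n1", "n2", "n3", "n4"):
        node_attrs = graph_nodes.get(nid, {})
        metadata = node_attrs.get("metadata", {}) if isinstance(node_attrs, dict) else {}
        # Priority: metadata["node_type"] -> node_attrs["node_type"] -> hint from the memory.
        node_type = metadata.get("node_type") or node_attrs.get("node_type")
        if not node_type:
            node_type = node_type_hints.get(nid)
        node_type_cache[nid] = node_type

    assert node_type_cache == {"n1": "person", "n2": "place", "n3": "event", "n4": None}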
@@ -868,5 +851,33 @@ class PathScoreExpansion:

return recency_score

def _batch_compute_similarities(
self,
valid_embeddings: list["np.ndarray"],
query_embedding: "np.ndarray"
) -> "np.ndarray":
"""
Batch-compute vector similarities (CPU-bound; executed via to_thread)

Args:
valid_embeddings: list of valid embedding vectors
query_embedding: query vector

Returns:
Array of similarities
"""
import numpy as np

# Batch similarity computation (matrix operations)
embeddings_matrix = np.array(valid_embeddings)
query_norm = np.linalg.norm(query_embedding)
embeddings_norms = np.linalg.norm(embeddings_matrix, axis=1)

# Vectorized cosine similarity
similarities = np.dot(embeddings_matrix, query_embedding) / (embeddings_norms * query_norm + 1e-8)
similarities = np.clip(similarities, 0.0, 1.0)

return similarities


__all__ = ["Path", "PathExpansionConfig", "PathScoreExpansion"]

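Because the matrix math in _batch_compute_similarities is pure CPU work, running it through asyncio.to_thread keeps the event loop responsive while the computation happens on a worker thread. A minimal standalone version of the same offload:

    import asyncio

    import numpy as np


    def batch_similarities(embeddings: list[np.ndarray], query: np.ndarray) -> np.ndarray:
        # Vectorized cosine similarity, clipped to [0, 1]; mirrors the helper above.
        matrix = np.array(embeddings)
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query) + 1e-8
        return np.clip(matrix @ query / denom, 0.0, 1.0)


    async def main() -> None:
        rng = np.random.default_rng(2)
        embeddings = [rng.normal(size=64) for _ in range(1000)]
        query = rng.normal(size=64)
        # The event loop stays free to run other coroutines while the thread computes.
        sims = await asyncio.to_thread(batch_similarities, embeddings, query)
        print(sims.shape, float(sims.max()))


    asyncio.run(main())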
@@ -4,6 +4,7 @@
Provides unified vector similarity computation functions
"""

import asyncio
from typing import TYPE_CHECKING

if TYPE_CHECKING:

@@ -47,4 +48,91 @@ def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
return 0.0


__all__ = ["cosine_similarity"]
async def cosine_similarity_async(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
"""
Asynchronously compute the cosine similarity of two vectors, using to_thread to avoid blocking

Args:
vec1: first vector
vec2: second vector

Returns:
Cosine similarity (0.0-1.0)
"""
return await asyncio.to_thread(cosine_similarity, vec1, vec2)


def batch_cosine_similarity(vec1: "np.ndarray", vec_list: list["np.ndarray"]) -> list[float]:
"""
Batch-compute vector similarities

Args:
vec1: base vector
vec_list: list of vectors to compare against

Returns:
List of similarities
"""
try:
import numpy as np

# Make sure it is a numpy array
if not isinstance(vec1, np.ndarray):
vec1 = np.array(vec1)

# Convert the batch to numpy arrays
vec_list = [np.array(vec) for vec in vec_list]

# Compute the norm
vec1_norm = np.linalg.norm(vec1)
if vec1_norm == 0:
return [0.0] * len(vec_list)

# Compute norms for all vectors
vec_norms = np.array([np.linalg.norm(vec) for vec in vec_list])

# Avoid division by zero
valid_mask = vec_norms != 0
similarities = np.zeros(len(vec_list))

if np.any(valid_mask):
# Batch dot products
valid_vecs = np.array(vec_list)[valid_mask]
dot_products = np.dot(valid_vecs, vec1)

# Compute similarities
valid_norms = vec_norms[valid_mask]
valid_similarities = dot_products / (vec1_norm * valid_norms)

# Keep within [0, 1]
valid_similarities = np.clip(valid_similarities, 0.0, 1.0)

# Fill in the results
similarities[valid_mask] = valid_similarities

return similarities.tolist()

except Exception:
return [0.0] * len(vec_list)


async def batch_cosine_similarity_async(vec1: "np.ndarray", vec_list: list["np.ndarray"]) -> list[float]:
"""
Asynchronously batch-compute vector similarities, using to_thread to avoid blocking

Args:
vec1: base vector
vec_list: list of vectors to compare against

Returns:
List of similarities
"""
return await asyncio.to_thread(batch_cosine_similarity, vec1, vec_list)


__all__ = [
"cosine_similarity",
"cosine_similarity_async",
"batch_cosine_similarity",
"batch_cosine_similarity_async"
]
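With these additions, callers can use the sync helpers inside already-threaded code and the *_async variants inside coroutines. A hedged usage sketch of the new public surface; the import path follows the module shown above, and actual availability depends on the package layout:

    import asyncio

    import numpy as np

    # Assumed import path, matching the module in this diff.
    from src.memory_graph.utils.similarity import (
        batch_cosine_similarity_async,
        cosine_similarity,
        cosine_similarity_async,
    )


    async def main() -> None:
        query = np.array([1.0, 0.0, 0.0])
        candidates = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]

        sync_score = cosine_similarity(query, candidates[0])               # blocking call
        async_score = await cosine_similarity_async(query, candidates[1])  # runs in a thread
        batch_scores = await batch_cosine_similarity_async(query, candidates)

        print(sync_score, async_score, batch_scores)  # expected: 1.0 0.0 [1.0, 0.0]


    if __name__ == "__main__":
        asyncio.run(main())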