refactor: remove duplicated code and streamline the memory system structure

- Extract shared utility functions into a utils module
  - Add utils/similarity.py: unified cosine similarity computation
  - Add utils/graph_expansion.py: unified graph expansion algorithm
- Remove duplicated implementations
  - manager.py: delete _cosine_similarity and _fast_cosine_similarity (60 lines)
  - tools/memory_tools.py: delete _expand_with_semantic_filter and _cosine_similarity (150 lines)
- Clean up dead code
  - Delete 10 lines of commented-out legacy code in tools/memory_tools.py
  - Delete the empty retrieval/ module
- Net reduction of ~150 lines of duplicated code

Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
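Editor's note: for orientation before the diff, a minimal sketch of how call sites consume the two shared helpers after this refactor. The vectors and store instances here are hypothetical placeholders, not code from this commit.

# Hypothetical usage sketch of the shared helpers introduced by this commit.
import numpy as np

from src.memory_graph.utils.similarity import cosine_similarity
from src.memory_graph.utils.graph_expansion import expand_memories_with_semantic_filter

# cosine_similarity is a plain synchronous function over two vectors.
score = cosine_similarity(np.array([1.0, 0.0]), np.array([0.7, 0.7]))  # ~0.71

# expand_memories_with_semantic_filter is async and takes the stores explicitly,
# which is what lets MemoryManager and MemoryTools share one implementation.
# results = await expand_memories_with_semantic_filter(
#     graph_store=graph_store,      # hypothetical GraphStore instance
#     vector_store=vector_store,    # hypothetical VectorStore instance
#     initial_memory_ids=["mem-1"],
#     query_embedding=query_vec,    # hypothetical query vector
#     max_depth=2,
#     semantic_threshold=0.5,
#     max_expanded=20,
# )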
src/memory_graph/manager.py

@@ -25,6 +25,8 @@ from src.memory_graph.storage.persistence import PersistenceManager
 from src.memory_graph.storage.vector_store import VectorStore
 from src.memory_graph.tools.memory_tools import MemoryTools
 from src.memory_graph.utils.embeddings import EmbeddingGenerator
+from src.memory_graph.utils.graph_expansion import expand_memories_with_semantic_filter as _expand_graph
+from src.memory_graph.utils.similarity import cosine_similarity

 if TYPE_CHECKING:
     import numpy as np
@@ -708,151 +710,15 @@ class MemoryManager:
         Returns:
             List[(memory_id, relevance_score)] sorted by relevance
         """
-        if not initial_memory_ids or query_embedding is None:
-            return []
-
-        try:
-            # Track visited memories to avoid duplicates
-            visited_memories = set(initial_memory_ids)
-            # Track expanded memories and their scores
-            expanded_memories: dict[str, float] = {}
-
-            # BFS expansion
-            current_level = initial_memory_ids
-
-            for depth in range(max_depth):
-                next_level = []
-
-                for memory_id in current_level:
-                    memory = self.graph_store.get_memory_by_id(memory_id)
-                    if not memory:
-                        continue
-
-                    # Walk all nodes of this memory
-                    for node in memory.nodes:
-                        if not node.has_embedding():
-                            continue
-
-                        # Fetch neighbor nodes
-                        try:
-                            neighbors = list(self.graph_store.graph.neighbors(node.id))
-                        except Exception:
-                            continue
-
-                        for neighbor_id in neighbors:
-                            # Fetch neighbor node data
-                            neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
-                            if not neighbor_node_data:
-                                continue
-
-                            # Fetch the neighbor's embedding (from the vector store)
-                            neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
-                            if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None:
-                                continue
-
-                            neighbor_embedding = neighbor_vector_data["embedding"]
-
-                            # Semantic similarity to the query
-                            semantic_sim = self._cosine_similarity(
-                                query_embedding,
-                                neighbor_embedding
-                            )
-
-                            # Fetch the edge weight
-                            try:
-                                edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
-                                edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                            except Exception:
-                                edge_importance = 0.5
-
-                            # Composite score: semantic similarity (70%) + graph structure weight (20%) + depth decay (10%)
-                            depth_decay = 1.0 / (depth + 1)  # deeper levels get lower weight
-                            relevance_score = (
-                                semantic_sim * 0.7 +
-                                edge_importance * 0.2 +
-                                depth_decay * 0.1
-                            )
-
-                            # Keep only nodes above the threshold
-                            if relevance_score < semantic_threshold:
-                                continue
-
-                            # Extract the memories the neighbor node belongs to
-                            neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
-                            if isinstance(neighbor_memory_ids, str):
-                                import json
-                                try:
-                                    neighbor_memory_ids = json.loads(neighbor_memory_ids)
-                                except Exception:
-                                    neighbor_memory_ids = [neighbor_memory_ids]
-
-                            for neighbor_mem_id in neighbor_memory_ids:
-                                if neighbor_mem_id in visited_memories:
-                                    continue
-
-                                # Record this expanded memory
-                                if neighbor_mem_id not in expanded_memories:
-                                    expanded_memories[neighbor_mem_id] = relevance_score
-                                    visited_memories.add(neighbor_mem_id)
-                                    next_level.append(neighbor_mem_id)
-                                else:
-                                    # If already present, keep the highest score
-                                    expanded_memories[neighbor_mem_id] = max(
-                                        expanded_memories[neighbor_mem_id],
-                                        relevance_score
-                                    )
-
-                # Stop early if there are no new nodes or the limit is reached
-                if not next_level or len(expanded_memories) >= max_expanded:
-                    break
-
-                current_level = next_level[:max_expanded]  # cap the fan-out per level
-
-            # Sort and return
-            sorted_results = sorted(
-                expanded_memories.items(),
-                key=lambda x: x[1],
-                reverse=True
-            )[:max_expanded]
-
-            logger.info(
-                f"Graph expansion done: {len(initial_memory_ids)} initial → "
-                f"{len(sorted_results)} newly expanded memories "
-                f"(depth={max_depth}, threshold={semantic_threshold:.2f})"
-            )
-
-            return sorted_results
-
-        except Exception as e:
-            logger.error(f"Semantic graph expansion failed: {e}", exc_info=True)
-            return []
-
-    def _cosine_similarity(self, vec1: "np.ndarray", vec2: "np.ndarray") -> float:
-        """Compute cosine similarity"""
-        try:
-            import numpy as np
-
-            # Ensure numpy arrays
-            if not isinstance(vec1, np.ndarray):
-                vec1 = np.array(vec1)
-            if not isinstance(vec2, np.ndarray):
-                vec2 = np.array(vec2)
-
-            # Norms
-            vec1_norm = np.linalg.norm(vec1)
-            vec2_norm = np.linalg.norm(vec2)
-
-            if vec1_norm == 0 or vec2_norm == 0:
-                return 0.0
-
-            # Cosine similarity
-            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
-            return float(similarity)
-
-        except Exception as e:
-            logger.warning(f"Cosine similarity computation failed: {e}")
-            return 0.0
+        return await _expand_graph(
+            graph_store=self.graph_store,
+            vector_store=self.vector_store,
+            initial_memory_ids=initial_memory_ids,
+            query_embedding=query_embedding,
+            max_depth=max_depth,
+            semantic_threshold=semantic_threshold,
+            max_expanded=max_expanded,
+        )

     async def forget_memory(self, memory_id: str) -> bool:
         """
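Editor's note: to make the scoring formula retained by this refactor concrete, a small worked example (illustrative numbers only, not from this commit):

# Illustrative arithmetic for the composite expansion score:
# relevance = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
semantic_sim = 0.8      # cosine similarity between query and neighbor
edge_importance = 0.5   # default edge weight when none is stored
depth = 1               # second BFS level (depth counts from 0)
depth_decay = 1.0 / (depth + 1)  # = 0.5
relevance = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
print(relevance)        # 0.56 + 0.10 + 0.05 = 0.71, above the 0.5 threshold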
@@ -1114,7 +980,7 @@ class MemoryManager:
                 embedding_j = embeddings_map[mem_j.id]

                 # Optimized cosine similarity computation
-                similarity = self._fast_cosine_similarity(embedding_i, embedding_j)
+                similarity = cosine_similarity(embedding_i, embedding_j)

                 if similarity >= similarity_threshold:
                     # Decide which memory to keep
@@ -1169,40 +1035,6 @@ class MemoryManager:
         except Exception as e:
             logger.error(f"❌ Memory consolidation failed: {e}", exc_info=True)

-    def _fast_cosine_similarity(self, vec1: "np.ndarray", vec2: "np.ndarray") -> float:
-        """
-        Fast cosine similarity computation (optimized version)
-
-        Args:
-            vec1, vec2: vectors
-
-        Returns:
-            cosine similarity
-        """
-        try:
-            import numpy as np
-
-            # Avoid repeated type checks and conversions:
-            # vectors should already be numpy arrays; if not, convert once
-            if not isinstance(vec1, np.ndarray):
-                vec1 = np.asarray(vec1, dtype=np.float32)
-            if not isinstance(vec2, np.ndarray):
-                vec2 = np.asarray(vec2, dtype=np.float32)
-
-            # Use the more efficient norm computation
-            norm1 = np.linalg.norm(vec1)
-            norm2 = np.linalg.norm(vec2)
-
-            if norm1 == 0 or norm2 == 0:
-                return 0.0
-
-            # Compute the dot product and division directly
-            return float(np.dot(vec1, vec2) / (norm1 * norm2))
-
-        except Exception as e:
-            logger.warning(f"Cosine similarity computation failed: {e}")
-            return 0.0
-
     async def auto_link_memories(
         self,
         time_window_hours: float | None = None,
@@ -1724,7 +1556,7 @@ class MemoryManager:
                     continue

                 # Fast similarity computation
-                similarity = self._fast_cosine_similarity(
+                similarity = cosine_similarity(
                     topic_node.embedding,
                     other_topic.embedding
                 )
src/memory_graph/retrieval/__init__.py (deleted)

@@ -1,10 +0,0 @@
-"""
-Memory retrieval module
-
-Provides simplified multi-query retrieval:
-- Directly uses a small model to generate multiple query strings
-- Multi-query fusion retrieval
-- Avoids complex query decomposition logic
-"""
-
-__all__ = []
src/memory_graph/tools/memory_tools.py

@@ -14,6 +14,7 @@ from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.persistence import PersistenceManager
 from src.memory_graph.storage.vector_store import VectorStore
 from src.memory_graph.utils.embeddings import EmbeddingGenerator
+from src.memory_graph.utils.graph_expansion import expand_memories_with_semantic_filter

 logger = get_logger(__name__)
@@ -505,8 +506,10 @@ class MemoryTools:
         try:
             query_embedding = await self.builder.embedding_generator.generate(query)

-            # Use the graph expansion logic directly (avoids a circular dependency)
-            expanded_results = await self._expand_with_semantic_filter(
+            # Use the shared graph expansion utility
+            expanded_results = await expand_memories_with_semantic_filter(
+                graph_store=self.graph_store,
+                vector_store=self.vector_store,
                 initial_memory_ids=list(initial_memory_ids),
                 query_embedding=query_embedding,
                 max_depth=expand_depth,
@@ -514,17 +517,6 @@
                 max_expanded=top_k * 2
             )

-            # Legacy code (if the Manager is needed):
-            # from src.memory_graph.manager import MemoryManager
-            # manager = MemoryManager.get_instance()
-            # expanded_results = await manager.expand_memories_with_semantic_filter(
-            #     initial_memory_ids=list(initial_memory_ids),
-            #     query_embedding=query_embedding,
-            #     max_depth=expand_depth,
-            #     semantic_threshold=0.5,
-            #     max_expanded=top_k * 2
-            # )
-
             # Merge expanded results
             expanded_memory_scores.update(dict(expanded_results))
@@ -861,154 +853,6 @@ class MemoryTools:

         return f"{subject} - {memory_type}: {topic}"

-    async def _expand_with_semantic_filter(
-        self,
-        initial_memory_ids: list[str],
-        query_embedding,
-        max_depth: int = 2,
-        semantic_threshold: float = 0.5,
-        max_expanded: int = 20
-    ) -> list[tuple[str, float]]:
-        """
-        Starting from an initial memory set, expand along the graph structure and filter by semantic similarity
-
-        Args:
-            initial_memory_ids: initial memory IDs
-            query_embedding: query vector
-            max_depth: maximum expansion depth
-            semantic_threshold: semantic similarity threshold
-            max_expanded: maximum number of memories to expand
-
-        Returns:
-            List[(memory_id, relevance_score)]
-        """
-        if not initial_memory_ids or query_embedding is None:
-            return []
-
-        try:
-            visited_memories = set(initial_memory_ids)
-            expanded_memories: dict[str, float] = {}
-
-            current_level = initial_memory_ids
-
-            for depth in range(max_depth):
-                next_level = []
-
-                for memory_id in current_level:
-                    memory = self.graph_store.get_memory_by_id(memory_id)
-                    if not memory:
-                        continue
-
-                    for node in memory.nodes:
-                        if not node.has_embedding():
-                            continue
-
-                        try:
-                            neighbors = list(self.graph_store.graph.neighbors(node.id))
-                        except Exception:
-                            continue
-
-                        for neighbor_id in neighbors:
-                            neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
-                            if not neighbor_node_data:
-                                continue
-
-                            neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
-                            if neighbor_vector_data is None:
-                                continue
-
-                            neighbor_embedding = neighbor_vector_data.get("embedding")
-                            if neighbor_embedding is None:
-                                continue
-
-                            # Compute semantic similarity
-                            semantic_sim = self._cosine_similarity(
-                                query_embedding,
-                                neighbor_embedding
-                            )
-
-                            # Fetch the edge weight
-                            try:
-                                edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
-                                edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                            except Exception:
-                                edge_importance = 0.5
-
-                            # Composite score
-                            depth_decay = 1.0 / (depth + 1)
-                            relevance_score = (
-                                semantic_sim * 0.7 +
-                                edge_importance * 0.2 +
-                                depth_decay * 0.1
-                            )
-
-                            if relevance_score < semantic_threshold:
-                                continue
-
-                            # Extract memory IDs
-                            neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
-                            if isinstance(neighbor_memory_ids, str):
-                                import orjson
-                                try:
-                                    neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
-                                except Exception:
-                                    neighbor_memory_ids = [neighbor_memory_ids]
-
-                            for neighbor_mem_id in neighbor_memory_ids:
-                                if neighbor_mem_id in visited_memories:
-                                    continue
-
-                                if neighbor_mem_id not in expanded_memories:
-                                    expanded_memories[neighbor_mem_id] = relevance_score
-                                    visited_memories.add(neighbor_mem_id)
-                                    next_level.append(neighbor_mem_id)
-                                else:
-                                    expanded_memories[neighbor_mem_id] = max(
-                                        expanded_memories[neighbor_mem_id],
-                                        relevance_score
-                                    )
-
-                if not next_level or len(expanded_memories) >= max_expanded:
-                    break
-
-                current_level = next_level[:max_expanded]
-
-            sorted_results = sorted(
-                expanded_memories.items(),
-                key=lambda x: x[1],
-                reverse=True
-            )[:max_expanded]
-
-            return sorted_results
-
-        except Exception as e:
-            logger.error(f"Graph expansion failed: {e}", exc_info=True)
-            return []
-
-    def _cosine_similarity(self, vec1, vec2) -> float:
-        """Compute cosine similarity"""
-        try:
-            import numpy as np
-
-            if not isinstance(vec1, np.ndarray):
-                vec1 = np.array(vec1)
-            if not isinstance(vec2, np.ndarray):
-                vec2 = np.array(vec2)
-
-            vec1_norm = np.linalg.norm(vec1)
-            vec2_norm = np.linalg.norm(vec2)
-
-            if vec1_norm == 0 or vec2_norm == 0:
-                return 0.0
-
-            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
-            return float(similarity)
-
-        except Exception as e:
-            logger.warning(f"Cosine similarity computation failed: {e}")
-            return 0.0
-
     @staticmethod
     def get_all_tool_schemas() -> list[dict[str, Any]]:
         """
src/memory_graph/utils/__init__.py

@@ -3,6 +3,7 @@
 """

 from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
+from src.memory_graph.utils.similarity import cosine_similarity
 from src.memory_graph.utils.time_parser import TimeParser

-__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]
+__all__ = ["EmbeddingGenerator", "TimeParser", "cosine_similarity", "get_embedding_generator"]
src/memory_graph/utils/graph_expansion.py (new file, 156 lines)
@@ -0,0 +1,156 @@
+"""
+Graph expansion utilities
+
+Provides the memory-graph expansion algorithm for finding related memories by
+expanding along the graph structure from an initial memory set
+"""
+
+from typing import TYPE_CHECKING
+
+from src.common.logger import get_logger
+from src.memory_graph.utils.similarity import cosine_similarity
+
+if TYPE_CHECKING:
+    import numpy as np
+
+    from src.memory_graph.storage.graph_store import GraphStore
+    from src.memory_graph.storage.vector_store import VectorStore
+
+logger = get_logger(__name__)
+
+
+async def expand_memories_with_semantic_filter(
+    graph_store: "GraphStore",
+    vector_store: "VectorStore",
+    initial_memory_ids: list[str],
+    query_embedding: "np.ndarray",
+    max_depth: int = 2,
+    semantic_threshold: float = 0.5,
+    max_expanded: int = 20,
+) -> list[tuple[str, float]]:
+    """
+    Starting from an initial memory set, expand along the graph structure and filter by semantic similarity
+
+    This addresses memories that are both semantically and structurally related,
+    which a pure vector search can miss.
+
+    Args:
+        graph_store: graph storage
+        vector_store: vector storage
+        initial_memory_ids: initial memory IDs (from vector search)
+        query_embedding: query vector
+        max_depth: maximum expansion depth (1-3 recommended)
+        semantic_threshold: semantic similarity threshold (0.5 recommended)
+        max_expanded: maximum number of memories to expand
+
+    Returns:
+        List[(memory_id, relevance_score)] sorted by relevance
+    """
+    if not initial_memory_ids or query_embedding is None:
+        return []
+
+    try:
+        # Track visited memories to avoid duplicates
+        visited_memories = set(initial_memory_ids)
+        # Track expanded memories and their scores
+        expanded_memories: dict[str, float] = {}
+
+        # BFS expansion
+        current_level = initial_memory_ids
+
+        for depth in range(max_depth):
+            next_level = []
+
+            for memory_id in current_level:
+                memory = graph_store.get_memory_by_id(memory_id)
+                if not memory:
+                    continue
+
+                # Walk all nodes of this memory
+                for node in memory.nodes:
+                    if not node.has_embedding():
+                        continue
+
+                    # Fetch neighbor nodes
+                    try:
+                        neighbors = list(graph_store.graph.neighbors(node.id))
+                    except Exception:
+                        continue
+
+                    for neighbor_id in neighbors:
+                        # Fetch neighbor node data
+                        neighbor_node_data = graph_store.graph.nodes.get(neighbor_id)
+                        if not neighbor_node_data:
+                            continue
+
+                        # Fetch the neighbor's embedding (from the vector store)
+                        neighbor_vector_data = await vector_store.get_node_by_id(neighbor_id)
+                        if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None:
+                            continue
+
+                        neighbor_embedding = neighbor_vector_data["embedding"]
+
+                        # Semantic similarity to the query
+                        semantic_sim = cosine_similarity(query_embedding, neighbor_embedding)
+
+                        # Fetch the edge weight
+                        try:
+                            edge_data = graph_store.graph.get_edge_data(node.id, neighbor_id)
+                            edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
+                        except Exception:
+                            edge_importance = 0.5
+
+                        # Composite score: semantic similarity (70%) + graph structure weight (20%) + depth decay (10%)
+                        depth_decay = 1.0 / (depth + 1)  # deeper levels get lower weight
+                        relevance_score = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
+
+                        # Keep only nodes above the threshold
+                        if relevance_score < semantic_threshold:
+                            continue
+
+                        # Extract the memories the neighbor node belongs to
+                        neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
+                        if isinstance(neighbor_memory_ids, str):
+                            import json
+
+                            try:
+                                neighbor_memory_ids = json.loads(neighbor_memory_ids)
+                            except Exception:
+                                neighbor_memory_ids = [neighbor_memory_ids]
+
+                        for neighbor_mem_id in neighbor_memory_ids:
+                            if neighbor_mem_id in visited_memories:
+                                continue
+
+                            # Record this expanded memory
+                            if neighbor_mem_id not in expanded_memories:
+                                expanded_memories[neighbor_mem_id] = relevance_score
+                                visited_memories.add(neighbor_mem_id)
+                                next_level.append(neighbor_mem_id)
+                            else:
+                                # If already present, keep the highest score
+                                expanded_memories[neighbor_mem_id] = max(
+                                    expanded_memories[neighbor_mem_id], relevance_score
+                                )
+
+            # Stop early if there are no new nodes or the limit is reached
+            if not next_level or len(expanded_memories) >= max_expanded:
+                break
+
+            current_level = next_level[:max_expanded]  # cap the fan-out per level
+
+        # Sort and return
+        sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
+
+        logger.info(
+            f"Graph expansion done: {len(initial_memory_ids)} initial → "
+            f"{len(sorted_results)} newly expanded memories "
+            f"(depth={max_depth}, threshold={semantic_threshold:.2f})"
+        )
+
+        return sorted_results
+
+    except Exception as e:
+        logger.error(f"Semantic graph expansion failed: {e}", exc_info=True)
+        return []
+
+
+__all__ = ["expand_memories_with_semantic_filter"]
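Editor's note: a minimal call sketch for the shared function above, assuming an async context; graph_store, vector_store, and query_vec are hypothetical placeholders, not part of this commit.

# Hypothetical wiring of the shared expansion function.
from src.memory_graph.utils.graph_expansion import expand_memories_with_semantic_filter


async def demo(graph_store, vector_store, query_vec):
    # Seed IDs would normally come from a vector search over the query.
    expanded = await expand_memories_with_semantic_filter(
        graph_store=graph_store,
        vector_store=vector_store,
        initial_memory_ids=["mem-a", "mem-b"],
        query_embedding=query_vec,
        max_depth=2,            # 1-3 recommended per the docstring
        semantic_threshold=0.5,
        max_expanded=20,
    )
    for memory_id, score in expanded:
        print(memory_id, round(score, 3))  # results arrive sorted by relevance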
src/memory_graph/utils/similarity.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+"""
+Similarity utilities
+
+Provides a unified vector similarity function
+"""
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import numpy as np
+
+
+def cosine_similarity(vec1: "np.ndarray", vec2: "np.ndarray") -> float:
+    """
+    Compute the cosine similarity of two vectors
+
+    Args:
+        vec1: first vector
+        vec2: second vector
+
+    Returns:
+        Cosine similarity (0.0-1.0)
+    """
+    try:
+        import numpy as np
+
+        # Ensure numpy arrays
+        if not isinstance(vec1, np.ndarray):
+            vec1 = np.array(vec1)
+        if not isinstance(vec2, np.ndarray):
+            vec2 = np.array(vec2)
+
+        # Norms
+        vec1_norm = np.linalg.norm(vec1)
+        vec2_norm = np.linalg.norm(vec2)
+
+        if vec1_norm == 0 or vec2_norm == 0:
+            return 0.0
+
+        # Cosine similarity
+        similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
+
+        # Clamp into [0, 1] (guards against floating-point error)
+        return float(np.clip(similarity, 0.0, 1.0))
+
+    except Exception:
+        return 0.0
+
+
+__all__ = ["cosine_similarity"]
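Editor's note: a quick illustrative check of the new helper's behavior. Note that because of the clip, anti-parallel vectors report 0.0 rather than -1.0.

# Illustrative check of cosine_similarity, including the clamp and zero-vector guard.
import numpy as np

from src.memory_graph.utils.similarity import cosine_similarity

a = np.array([1.0, 0.0])
print(cosine_similarity(a, np.array([1.0, 0.0])))   # 1.0  (identical direction)
print(cosine_similarity(a, np.array([0.0, 1.0])))   # 0.0  (orthogonal)
print(cosine_similarity(a, np.array([-1.0, 0.0])))  # 0.0  (-1 clipped to 0.0)
print(cosine_similarity(a, np.array([0.0, 0.0])))   # 0.0  (zero vector guard)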