feat(memory): add graph expansion with semantic-similarity filtering and depth-based exploration
@@ -438,6 +438,7 @@ class MemoryManager:
         include_forgotten: bool = False,
         optimize_query: bool = True,
         use_multi_query: bool = True,
+        expand_depth: int = 1,
         context: Optional[Dict[str, Any]] = None,
     ) -> List[Memory]:
         """
@@ -447,6 +448,8 @@ class MemoryManager:
        For example, "how does 杰瑞喵 evaluate the new memory system" is decomposed into multiple sub-queries,
        ensuring that both key concepts, "杰瑞喵" and "the new memory system", are matched.
 
+        Graph expansion is also supported: starting from the initial retrieval hits, semantically related neighbor memories are collected along the graph structure.
+
         Args:
             query: the search query
             top_k: number of results to return
@@ -456,6 +459,7 @@ class MemoryManager:
             include_forgotten: whether to include forgotten memories
             optimize_query: whether to optimize the query with a small model (deprecated; superseded by use_multi_query)
             use_multi_query: whether to use the multi-query strategy (recommended, default True)
+            expand_depth: graph expansion depth (0 = disabled, 1 = recommended, 2-3 = deep exploration)
             context: query context (used for optimization)
 
         Returns:
@@ -470,6 +474,7 @@ class MemoryManager:
             "query": query,
             "top_k": top_k,
             "use_multi_query": use_multi_query,
+            "expand_depth": expand_depth,  # pass the graph expansion depth through
             "context": context,
         }
 
@@ -644,7 +649,7 @@ class MemoryManager:
 
     def _get_related_memories(self, memory_id: str, max_depth: int = 1) -> List[str]:
         """
-        Get the list of related memory IDs
+        Get the list of related memory IDs (legacy version, kept for activation propagation)
 
         Args:
             memory_id: the memory ID
@@ -675,6 +680,176 @@ class MemoryManager:
 
         return list(related_ids)
 
+    async def expand_memories_with_semantic_filter(
+        self,
+        initial_memory_ids: List[str],
+        query_embedding: "np.ndarray",
+        max_depth: int = 2,
+        semantic_threshold: float = 0.5,
+        max_expanded: int = 20
+    ) -> List[Tuple[str, float]]:
+        """
+        Starting from an initial set of memories, expand along the graph structure and filter by semantic similarity.
+
+        This method recovers memories that are both semantically related and graph-structurally related, which pure vector search can miss.
+
+        Args:
+            initial_memory_ids: initial memory IDs (from vector search)
+            query_embedding: the query vector
+            max_depth: maximum expansion depth (1-3 recommended)
+            semantic_threshold: semantic similarity threshold (0.5 recommended)
+            max_expanded: maximum number of memories to expand
+
+        Returns:
+            List[(memory_id, relevance_score)] sorted by relevance
+        """
+        if not initial_memory_ids or query_embedding is None:
+            return []
+
+        try:
+            # Track visited memories to avoid duplicates
+            visited_memories = set(initial_memory_ids)
+            # Expanded memories and their scores
+            expanded_memories: Dict[str, float] = {}
+
+            # BFS expansion
+            current_level = initial_memory_ids
+
+            for depth in range(max_depth):
+                next_level = []
+
+                for memory_id in current_level:
+                    memory = self.graph_store.get_memory_by_id(memory_id)
+                    if not memory:
+                        continue
+
+                    # Walk every node of this memory
+                    for node in memory.nodes:
+                        if not node.has_embedding():
+                            continue
+
+                        # Fetch the neighbor nodes
+                        try:
+                            neighbors = list(self.graph_store.graph.neighbors(node.id))
+                        except Exception:
+                            continue
+
+                        for neighbor_id in neighbors:
+                            # Look up the neighbor's node data
+                            neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
+                            if not neighbor_node_data:
+                                continue
+
+                            # Fetch the neighbor's embedding (from the vector store)
+                            neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
+                            if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None:
+                                continue
+
+                            neighbor_embedding = neighbor_vector_data["embedding"]
+
+                            # Semantic similarity to the query
+                            semantic_sim = self._cosine_similarity(
+                                query_embedding,
+                                neighbor_embedding
+                            )
+
+                            # Edge weight
+                            try:
+                                edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
+                                edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
+                            except Exception:
+                                edge_importance = 0.5
+
+                            # Composite score: semantic similarity (70%) + graph weight (20%) + depth decay (10%)
+                            depth_decay = 1.0 / (depth + 1)  # deeper hops carry less weight
+                            relevance_score = (
+                                semantic_sim * 0.7 +
+                                edge_importance * 0.2 +
+                                depth_decay * 0.1
+                            )
+
+                            # Keep only nodes above the threshold
+                            if relevance_score < semantic_threshold:
+                                continue
+
+                            # Extract the memories this neighbor node belongs to
+                            neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
+                            if isinstance(neighbor_memory_ids, str):
+                                import json
+                                try:
+                                    neighbor_memory_ids = json.loads(neighbor_memory_ids)
+                                except Exception:
+                                    neighbor_memory_ids = [neighbor_memory_ids]
+
+                            for neighbor_mem_id in neighbor_memory_ids:
+                                if neighbor_mem_id in expanded_memories:
+                                    # Re-encountered via another path: keep the highest score
+                                    expanded_memories[neighbor_mem_id] = max(
+                                        expanded_memories[neighbor_mem_id],
+                                        relevance_score
+                                    )
+                                    continue
+
+                                if neighbor_mem_id in visited_memories:
+                                    continue
+
+                                # Record this newly expanded memory
+                                expanded_memories[neighbor_mem_id] = relevance_score
+                                visited_memories.add(neighbor_mem_id)
+                                next_level.append(neighbor_mem_id)
+
+                # Terminate early if there are no new nodes or the cap is reached
+                if not next_level or len(expanded_memories) >= max_expanded:
+                    break
+
+                current_level = next_level[:max_expanded]  # cap the fan-out per level
+
+            # Sort and return
+            sorted_results = sorted(
+                expanded_memories.items(),
+                key=lambda x: x[1],
+                reverse=True
+            )[:max_expanded]
+
+            logger.info(
+                f"Graph expansion complete: {len(initial_memory_ids)} initial → "
+                f"{len(sorted_results)} new memories expanded "
+                f"(depth={max_depth}, threshold={semantic_threshold:.2f})"
+            )
+
+            return sorted_results
+
+        except Exception as e:
+            logger.error(f"Semantic graph expansion failed: {e}", exc_info=True)
+            return []
+
+    def _cosine_similarity(self, vec1: "np.ndarray", vec2: "np.ndarray") -> float:
+        """Compute the cosine similarity of two vectors."""
+        try:
+            import numpy as np
+
+            # Coerce to numpy arrays
+            if not isinstance(vec1, np.ndarray):
+                vec1 = np.array(vec1)
+            if not isinstance(vec2, np.ndarray):
+                vec2 = np.array(vec2)
+
+            # Norms
+            vec1_norm = np.linalg.norm(vec1)
+            vec2_norm = np.linalg.norm(vec2)
+
+            if vec1_norm == 0 or vec2_norm == 0:
+                return 0.0
+
+            # Cosine similarity
+            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
+            return float(similarity)
+
+        except Exception as e:
+            logger.warning(f"Failed to compute cosine similarity: {e}")
+            return 0.0
+
     async def forget_memory(self, memory_id: str) -> bool:
         """
         Forget a memory (mark it as forgotten instead of deleting it)
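
For intuition, the composite relevance score used during expansion can be reproduced in isolation. A minimal sketch under the same 70/20/10 weighting; the helper name score_expansion is illustrative and not part of this codebase:

    def score_expansion(semantic_sim: float, edge_importance: float, depth: int) -> float:
        # Semantic similarity (70%) + edge weight (20%) + depth decay (10%)
        depth_decay = 1.0 / (depth + 1)  # 1.0 at depth 0, 0.5 at depth 1, ~0.33 at depth 2
        return semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1

    # A first-hop neighbor (depth 0) with similarity 0.6 and the default edge weight 0.5:
    # 0.6 * 0.7 + 0.5 * 0.2 + 1.0 * 0.1 = 0.62, which clears the default threshold of 0.5.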
@@ -465,8 +465,10 @@ class MemoryTools:
                 # Traditional single-query strategy
                 similar_nodes = await self._single_query_search(query, top_k)
 
-            # 2. Extract memory IDs
-            memory_ids = set()
+            # 2. Extract the initial memory IDs (from vector search)
+            initial_memory_ids = set()
+            memory_scores = {}  # best initial score per memory
 
             for node_id, similarity, metadata in similar_nodes:
                 if "memory_ids" in metadata:
                     ids = metadata["memory_ids"]
@@ -478,16 +480,105 @@ class MemoryTools:
                     except Exception:
                         ids = [ids]
                 if isinstance(ids, list):
-                    memory_ids.update(ids)
+                    for mem_id in ids:
+                        initial_memory_ids.add(mem_id)
+                        # Keep the highest similarity seen for this memory
+                        if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
+                            memory_scores[mem_id] = similarity
 
-            # 3. Fetch the full memories
-            memories = []
-            for memory_id in list(memory_ids)[:top_k]:
+            # 3. Graph expansion (when enabled via expand_depth)
+            expanded_memory_scores = {}
+            if expand_depth > 0 and initial_memory_ids:
+                logger.info(f"Starting graph expansion: {len(initial_memory_ids)} initial memories, depth={expand_depth}")
+
+                # Embed the query for semantic filtering
+                if self.builder.embedding_generator:
+                    try:
+                        query_embedding = await self.builder.embedding_generator.generate(query)
+
+                        # Use the local expansion logic directly (avoids a circular dependency)
+                        expanded_results = await self._expand_with_semantic_filter(
+                            initial_memory_ids=list(initial_memory_ids),
+                            query_embedding=query_embedding,
+                            max_depth=expand_depth,
+                            semantic_threshold=0.5,
+                            max_expanded=top_k * 2
+                        )
+
+                        # Legacy path (if the Manager should be used instead):
+                        # from src.memory_graph.manager import MemoryManager
+                        # manager = MemoryManager.get_instance()
+                        # expanded_results = await manager.expand_memories_with_semantic_filter(
+                        #     initial_memory_ids=list(initial_memory_ids),
+                        #     query_embedding=query_embedding,
+                        #     max_depth=expand_depth,
+                        #     semantic_threshold=0.5,
+                        #     max_expanded=top_k * 2
+                        # )
+
+                        # Merge the expansion results
+                        for mem_id, score in expanded_results:
+                            expanded_memory_scores[mem_id] = score
+
+                        logger.info(f"Graph expansion complete: {len(expanded_memory_scores)} additional related memories")
+
+                    except Exception as e:
+                        logger.warning(f"Graph expansion failed: {e}")
+
+            # 4. Merge initial and expanded memories
+            all_memory_ids = set(initial_memory_ids) | set(expanded_memory_scores.keys())
+
+            # Final scores: initial memories keep their vector scores, expanded ones use expansion scores
+            final_scores = {}
+            for mem_id in all_memory_ids:
+                if mem_id in memory_scores:
+                    # Initial memory: vector similarity score
+                    final_scores[mem_id] = memory_scores[mem_id]
+                elif mem_id in expanded_memory_scores:
+                    # Expanded memory: expansion score, slightly down-weighted
+                    final_scores[mem_id] = expanded_memory_scores[mem_id] * 0.8
+
+            # Sort by score
+            sorted_memory_ids = sorted(
+                final_scores.keys(),
+                key=lambda x: final_scores[x],
+                reverse=True
+            )[:top_k * 2]  # keep 2x the count for downstream filtering
+
+            # 5. Fetch the full memories and compute the final ranking
+            memories_with_scores = []
+            for memory_id in sorted_memory_ids:
+                memory = self.graph_store.get_memory_by_id(memory_id)
+                if memory:
+                    # Composite score: similarity (60%) + importance (30%) + recency (10%)
+                    similarity_score = final_scores[memory_id]
+                    importance_score = memory.importance
+
+                    # Recency score (newer memories score higher)
+                    from datetime import datetime, timezone
+                    now = datetime.now(timezone.utc)
+                    # Ensure memory.created_at is timezone-aware
+                    if memory.created_at.tzinfo is None:
+                        memory_time = memory.created_at.replace(tzinfo=timezone.utc)
+                    else:
+                        memory_time = memory.created_at
+                    age_days = (now - memory_time).total_seconds() / 86400
+                    recency_score = 1.0 / (1.0 + age_days / 30)  # 30-day decay
+
+                    # Composite score
+                    final_score = (
+                        similarity_score * 0.6 +
+                        importance_score * 0.3 +
+                        recency_score * 0.1
+                    )
+
+                    memories_with_scores.append((memory, final_score))
+
+            # Sort by composite score
+            memories_with_scores.sort(key=lambda x: x[1], reverse=True)
+            memories = [mem for mem, _ in memories_with_scores[:top_k]]
+
-            # 4. Format the results
+            # 6. Format the results
             results = []
             for memory in memories:
                 result = {
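
The final ranking blend can likewise be checked by hand. A sketch with illustrative numbers, not taken from the repository:

    def final_rank_score(similarity: float, importance: float, age_days: float) -> float:
        # Similarity (60%) + importance (30%) + recency (10%), with a 30-day decay
        recency = 1.0 / (1.0 + age_days / 30)  # 1.0 when fresh, 0.5 at 30 days old
        return similarity * 0.6 + importance * 0.3 + recency * 0.1

    # An expanded memory enters with its expansion score down-weighted by 0.8:
    # similarity = 0.62 * 0.8 = 0.496; with importance 0.7 and an age of 30 days,
    # 0.496 * 0.6 + 0.7 * 0.3 + 0.5 * 0.1 = 0.5576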
@@ -498,7 +589,11 @@ class MemoryTools:
                 }
                 results.append(result)
 
-            logger.info(f"Search complete: found {len(results)} memories")
+            logger.info(
+                f"Search complete: {len(initial_memory_ids)} initial → "
+                f"{len(expanded_memory_scores)} expanded → "
+                f"returning {len(results)} memories"
+            )
 
             return {
                 "success": True,
@@ -506,6 +601,8 @@ class MemoryTools:
                 "total": len(results),
                 "query": query,
                 "strategy": "multi_query" if use_multi_query else "single_query",
+                "expanded_count": len(expanded_memory_scores),
+                "expand_depth": expand_depth,
             }
 
         except Exception as e:
@@ -726,6 +823,153 @@ class MemoryTools:
 
         return f"{subject} - {memory_type}: {topic}"
 
+    async def _expand_with_semantic_filter(
+        self,
+        initial_memory_ids: List[str],
+        query_embedding,
+        max_depth: int = 2,
+        semantic_threshold: float = 0.5,
+        max_expanded: int = 20
+    ) -> List[Tuple[str, float]]:
+        """
+        Starting from an initial set of memories, expand along the graph structure and filter by semantic similarity.
+
+        Args:
+            initial_memory_ids: initial memory IDs
+            query_embedding: the query vector
+            max_depth: maximum expansion depth
+            semantic_threshold: semantic similarity threshold
+            max_expanded: maximum number of memories to expand
+
+        Returns:
+            List[(memory_id, relevance_score)]
+        """
+        if not initial_memory_ids or query_embedding is None:
+            return []
+
+        try:
+            visited_memories = set(initial_memory_ids)
+            expanded_memories: Dict[str, float] = {}
+
+            current_level = initial_memory_ids
+
+            for depth in range(max_depth):
+                next_level = []
+
+                for memory_id in current_level:
+                    memory = self.graph_store.get_memory_by_id(memory_id)
+                    if not memory:
+                        continue
+
+                    for node in memory.nodes:
+                        if not node.has_embedding():
+                            continue
+
+                        try:
+                            neighbors = list(self.graph_store.graph.neighbors(node.id))
+                        except Exception:
+                            continue
+
+                        for neighbor_id in neighbors:
+                            neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
+                            if not neighbor_node_data:
+                                continue
+
+                            neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
+                            if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None:
+                                continue
+
+                            neighbor_embedding = neighbor_vector_data["embedding"]
+
+                            # Semantic similarity to the query
+                            semantic_sim = self._cosine_similarity(
+                                query_embedding,
+                                neighbor_embedding
+                            )
+
+                            # Edge weight
+                            try:
+                                edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
+                                edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
+                            except Exception:
+                                edge_importance = 0.5
+
+                            # Composite score
+                            depth_decay = 1.0 / (depth + 1)
+                            relevance_score = (
+                                semantic_sim * 0.7 +
+                                edge_importance * 0.2 +
+                                depth_decay * 0.1
+                            )
+
+                            if relevance_score < semantic_threshold:
+                                continue
+
+                            # Extract the memory IDs
+                            neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
+                            if isinstance(neighbor_memory_ids, str):
+                                import json
+                                try:
+                                    neighbor_memory_ids = json.loads(neighbor_memory_ids)
+                                except Exception:
+                                    neighbor_memory_ids = [neighbor_memory_ids]
+
+                            for neighbor_mem_id in neighbor_memory_ids:
+                                if neighbor_mem_id in expanded_memories:
+                                    # Re-encountered via another path: keep the highest score
+                                    expanded_memories[neighbor_mem_id] = max(
+                                        expanded_memories[neighbor_mem_id],
+                                        relevance_score
+                                    )
+                                    continue
+
+                                if neighbor_mem_id in visited_memories:
+                                    continue
+
+                                expanded_memories[neighbor_mem_id] = relevance_score
+                                visited_memories.add(neighbor_mem_id)
+                                next_level.append(neighbor_mem_id)
+
+                if not next_level or len(expanded_memories) >= max_expanded:
+                    break
+
+                current_level = next_level[:max_expanded]
+
+            sorted_results = sorted(
+                expanded_memories.items(),
+                key=lambda x: x[1],
+                reverse=True
+            )[:max_expanded]
+
+            return sorted_results
+
+        except Exception as e:
+            logger.error(f"Graph expansion failed: {e}", exc_info=True)
+            return []
+
+    def _cosine_similarity(self, vec1, vec2) -> float:
+        """Compute the cosine similarity of two vectors."""
+        try:
+            import numpy as np
+
+            if not isinstance(vec1, np.ndarray):
+                vec1 = np.array(vec1)
+            if not isinstance(vec2, np.ndarray):
+                vec2 = np.array(vec2)
+
+            vec1_norm = np.linalg.norm(vec1)
+            vec2_norm = np.linalg.norm(vec2)
+
+            if vec1_norm == 0 or vec2_norm == 0:
+                return 0.0
+
+            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
+            return float(similarity)
+
+        except Exception as e:
+            logger.warning(f"Failed to compute cosine similarity: {e}")
+            return 0.0
+
     @staticmethod
     def get_all_tool_schemas() -> List[Dict[str, Any]]:
         """
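
Taken together, callers opt into the new behavior through the added expand_depth parameter. A hypothetical invocation; the receiver memory_manager and the method name search are assumptions based on the signature above, and the argument values are illustrative:

    memories = await memory_manager.search(
        query="how does 杰瑞喵 evaluate the new memory system",
        top_k=5,
        use_multi_query=True,
        expand_depth=1,  # 0 = disabled, 1 = recommended, 2-3 = deep exploration
    )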