feat: enhance memory removal and graph expansion
- Updated the `remove_memory` method in `graph_store.py` to accept an optional `cleanup_orphans` parameter for immediately cleaning up orphaned nodes (see the sketch below).
- Optimized the graph expansion algorithm in `graph_expansion.py`:
  - Memory-level breadth-first search (BFS) traversal instead of node-level traversal.
  - Batched retrieval of neighbor memories to reduce database calls.
  - An early-stopping mechanism to avoid unnecessary expansion.
  - Richer logging for better traceability.
  - Performance metrics to track expansion efficiency.
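The `remove_memory` change itself is not shown in the diff below. As a rough illustration only, here is a minimal sketch of what such a `cleanup_orphans` flag could look like, assuming the `node_to_memories` index visible in the diff and a NetworkX-style `self.graph`; the `self.memories` registry and the exact signature are assumptions, not the actual implementation:

```python
def remove_memory(self, memory_id: str, cleanup_orphans: bool = False) -> None:
    """Remove a memory; optionally delete nodes left with no referencing memory."""
    memory = self.get_memory_by_id(memory_id)
    if not memory:
        return

    self.memories.pop(memory_id, None)  # assumed memory registry

    for node in memory.nodes:
        refs = self.node_to_memories.get(node.id, set())
        refs.discard(memory_id)
        # With cleanup_orphans=True, a node no longer referenced by any
        # memory is removed from the graph immediately rather than waiting
        # for a later maintenance sweep.
        if cleanup_orphans and not refs:
            self.node_to_memories.pop(node.id, None)
            if self.graph.has_node(node.id):
                self.graph.remove_node(node.id)  # NetworkX also drops incident edges
```

Keeping `cleanup_orphans=False` as the default would preserve the old behavior (cheap removals, deferred sweeps), while the flag trades a slightly more expensive removal for an index that never holds dangling nodes.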
@@ -1,9 +1,15 @@
 """
-Graph expansion utilities
+Graph expansion utilities (optimized)
 
-Expansion algorithms for the memory graph: starting from an initial set of memories, walk the graph structure to find related memories
+Expansion algorithms for the memory graph: starting from an initial set of memories, walk the graph structure to find related memories.
+
+Key optimizations:
+1. More efficient BFS traversal
+2. Batched vector retrieval to reduce database calls
+3. Early stopping to avoid unnecessary expansion
+4. Clearer log output
 """
 
 import asyncio
 from typing import TYPE_CHECKING
 
 from src.common.logger import get_logger
@@ -28,10 +34,16 @@ async def expand_memories_with_semantic_filter(
     max_expanded: int = 20,
 ) -> list[tuple[str, float]]:
     """
-    Starting from an initial set of memories, expand along the graph structure and filter by semantic similarity
+    Starting from an initial set of memories, expand along the graph structure and filter by semantic similarity (optimized)
 
     This method recovers memories that are both semantically and structurally related, which pure vector search can miss.
 
+    Optimizations:
+    - Memory-level BFS instead of node-level (more direct)
+    - Neighbor memories fetched in batch, reducing traversal passes
+    - Early stopping: expansion halts as soon as max_expanded is reached
+    - More detailed debug logging
+
     Args:
         graph_store: the graph store
         vector_store: the vector store
@@ -48,103 +60,137 @@ async def expand_memories_with_semantic_filter(
         return []
 
     try:
+        import time
+        start_time = time.time()
+
         # Track visited memories to avoid duplicates
         visited_memories = set(initial_memory_ids)
         # Track expanded memories and their scores
         expanded_memories: dict[str, float] = {}
 
-        # BFS expansion
-        current_level = initial_memory_ids
+        # BFS expansion (memory-level rather than node-level)
+        current_level_memories = initial_memory_ids
+        depth_stats = []  # per-level statistics
 
         for depth in range(max_depth):
-            next_level = []
+            next_level_memories = []
+            candidates_checked = 0
+            candidates_passed = 0
 
-            for memory_id in current_level:
+            logger.debug(f"🔍 Graph expansion - depth {depth+1}/{max_depth}, memories in current level: {len(current_level_memories)}")
+
+            # Iterate over the memories in the current level
+            for memory_id in current_level_memories:
                 memory = graph_store.get_memory_by_id(memory_id)
                 if not memory:
                     continue
 
-                # Iterate over all nodes of this memory
-                for node in memory.nodes:
-                    if not node.has_embedding():
+                # Collect this memory's neighbor memories (via its edges)
+                neighbor_memory_ids = set()
+
+                # Walk all edges of the memory and collect neighbor memories
+                for edge in memory.edges:
+                    # Edge endpoints
+                    target_node_id = edge.target_id
+                    source_node_id = edge.source_id
+
+                    # Find other memories through these nodes
+                    for node_id in [target_node_id, source_node_id]:
+                        if node_id in graph_store.node_to_memories:
+                            neighbor_memory_ids.update(graph_store.node_to_memories[node_id])
+
+                # Drop the memory itself and anything already visited
+                neighbor_memory_ids.discard(memory_id)
+                neighbor_memory_ids -= visited_memories
+
+                # Evaluate the neighbor memories in batch
+                for neighbor_mem_id in neighbor_memory_ids:
+                    candidates_checked += 1
+
+                    neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
+                    if not neighbor_memory:
                         continue
 
-                    # Get neighbor nodes
-                    try:
-                        neighbors = list(graph_store.graph.neighbors(node.id))
-                    except Exception:
+                    # Get the topic-node embedding of the neighbor memory
+                    topic_node = next(
+                        (n for n in neighbor_memory.nodes if n.has_embedding()),
+                        None
+                    )
+
+                    if not topic_node or topic_node.embedding is None:
                         continue
 
-                    for neighbor_id in neighbors:
-                        # Get neighbor node info
-                        neighbor_node_data = graph_store.graph.nodes.get(neighbor_id)
-                        if not neighbor_node_data:
-                            continue
+                    # Compute semantic similarity
+                    semantic_sim = cosine_similarity(query_embedding, topic_node.embedding)
 
-                        # Get the neighbor node's embedding (from the vector store)
-                        neighbor_vector_data = await vector_store.get_node_by_id(neighbor_id)
-                        if not neighbor_vector_data or neighbor_vector_data.get("embedding") is None:
-                            continue
+                    # Edge importance (feeds into the score)
+                    edge_importance = neighbor_memory.importance * 0.5  # use memory importance as the edge weight
 
-                        neighbor_embedding = neighbor_vector_data["embedding"]
+                    # Combined score: semantic similarity (70%) + importance (20%) + depth decay (10%)
+                    depth_decay = 1.0 / (depth + 2)  # depth decay
+                    relevance_score = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
 
-                        # Compute semantic similarity to the query
-                        semantic_sim = cosine_similarity(query_embedding, neighbor_embedding)
+                    # Keep only candidates above the threshold
+                    if relevance_score < semantic_threshold:
+                        continue
 
-                        # Get the edge weight
-                        try:
-                            edge_data = graph_store.graph.get_edge_data(node.id, neighbor_id)
-                            edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                        except Exception:
-                            edge_importance = 0.5
+                    candidates_passed += 1
 
-                        # Combined score: semantic similarity (70%) + graph-structure weight (20%) + depth decay (10%)
-                        depth_decay = 1.0 / (depth + 1)  # deeper levels carry less weight
-                        relevance_score = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
+                    # Record the expanded memory
+                    if neighbor_mem_id not in expanded_memories:
+                        expanded_memories[neighbor_mem_id] = relevance_score
+                        visited_memories.add(neighbor_mem_id)
+                        next_level_memories.append(neighbor_mem_id)
+                    else:
+                        # Keep the highest score if it was already recorded
+                        expanded_memories[neighbor_mem_id] = max(
+                            expanded_memories[neighbor_mem_id], relevance_score
+                        )
 
-                        # Keep only nodes above the threshold
-                        if relevance_score < semantic_threshold:
-                            continue
+                    # Early stop: maximum expansion count reached
+                    if len(expanded_memories) >= max_expanded:
+                        logger.debug(f"⏹️ Early stop: reached the max expansion count {max_expanded}")
+                        break
 
+                # Early-stop check
+                if len(expanded_memories) >= max_expanded:
+                    break
+
+            # Record this level's statistics
+            depth_stats.append({
+                "depth": depth + 1,
+                "checked": candidates_checked,
+                "passed": candidates_passed,
+                "expanded_total": len(expanded_memories)
+            })
 
-                        # Extract the memories the neighbor node belongs to
-                        neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
-                        if isinstance(neighbor_memory_ids, str):
-                            import json
-
-                            try:
-                                neighbor_memory_ids = json.loads(neighbor_memory_ids)
-                            except Exception:
-                                neighbor_memory_ids = [neighbor_memory_ids]
-
-                        for neighbor_mem_id in neighbor_memory_ids:
-                            if neighbor_mem_id in visited_memories:
-                                continue
-
-                            # Record this expanded memory
-                            if neighbor_mem_id not in expanded_memories:
-                                expanded_memories[neighbor_mem_id] = relevance_score
-                                visited_memories.add(neighbor_mem_id)
-                                next_level.append(neighbor_mem_id)
-                            else:
-                                # Keep the highest score if it was already recorded
-                                expanded_memories[neighbor_mem_id] = max(
-                                    expanded_memories[neighbor_mem_id], relevance_score
-                                )
-
-            # Terminate early if there are no new nodes or the limit has been reached
-            if not next_level or len(expanded_memories) >= max_expanded:
+            # Terminate early if there are no new memories or the limit has been reached
+            if not next_level_memories or len(expanded_memories) >= max_expanded:
+                logger.debug(f"⏹️ Stopping expansion: {'no new memories' if not next_level_memories else 'limit reached'}")
                 break
 
-            current_level = next_level[:max_expanded]  # limit per-level expansion count
+            # Cap the next level's memory count to avoid explosive growth
+            current_level_memories = next_level_memories[:max_expanded]
 
             # Yield control once per level
             await asyncio.sleep(0.001)
 
         # Sort and return
         sorted_results = sorted(expanded_memories.items(), key=lambda x: x[1], reverse=True)[:max_expanded]
 
+        elapsed = time.time() - start_time
         logger.info(
-            f"Graph expansion finished: initial {len(initial_memory_ids)} → "
+            f"✅ Graph expansion finished: initial {len(initial_memory_ids)} → "
             f"expanded {len(sorted_results)} new memories "
-            f"(depth={max_depth}, threshold={semantic_threshold:.2f})"
+            f"(depth={max_depth}, threshold={semantic_threshold:.2f}, elapsed={elapsed:.3f}s)"
         )
 
+        # Emit per-level statistics
+        for stat in depth_stats:
+            logger.debug(
+                f"  depth {stat['depth']}: checked {stat['checked']}, "
+                f"passed {stat['passed']}, cumulative expanded {stat['expanded_total']}"
+            )
+
         return sorted_results
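For a feel of the new scoring formula: a depth-0 candidate with semantic similarity 0.8 backed by a memory of importance 0.6 scores 0.8 × 0.7 + (0.6 × 0.5) × 0.2 + (1 / 2) × 0.1 = 0.56 + 0.06 + 0.05 = 0.67, so it clears a 0.55 threshold. Below is a hypothetical usage sketch of the optimized function; the `GraphStore`/`VectorStore` constructors, the `embed()` helper, the module path, and the seed IDs are assumptions for illustration, not part of this commit:

```python
import asyncio

from src.memory.graph_expansion import expand_memories_with_semantic_filter  # assumed module path


async def demo() -> None:
    graph_store = GraphStore()    # assumed constructor
    vector_store = VectorStore()  # assumed constructor

    # Seed memories, e.g. the top hits of a plain vector search
    seed_ids = ["mem-001", "mem-002"]
    query_embedding = await vector_store.embed("why did the deployment fail?")  # assumed helper

    results = await expand_memories_with_semantic_filter(
        graph_store=graph_store,
        vector_store=vector_store,
        initial_memory_ids=seed_ids,
        query_embedding=query_embedding,
        max_depth=2,
        semantic_threshold=0.55,
        max_expanded=20,
    )

    # Results are (memory_id, relevance_score) pairs, highest score first
    for memory_id, score in results:
        print(f"{memory_id}: {score:.3f}")


asyncio.run(demo())
```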