feat(memory_tools): 添加优先节点类型支持,优化多查询生成与记忆扩展逻辑
This commit is contained in:
@@ -88,23 +88,46 @@ async def expand_memories_with_semantic_filter(
|
||||
# 获取该记忆的邻居记忆(通过边关系)
|
||||
neighbor_memory_ids = set()
|
||||
|
||||
# 遍历记忆的所有边,收集邻居记忆
|
||||
# 🆕 遍历记忆的所有边,收集邻居记忆(带边类型权重)
|
||||
edge_weights = {} # 记录通过不同边类型到达的记忆的权重
|
||||
|
||||
for edge in memory.edges:
|
||||
# 获取边的目标节点
|
||||
target_node_id = edge.target_id
|
||||
source_node_id = edge.source_id
|
||||
|
||||
# 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边)
|
||||
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
|
||||
if edge_type_str == "REFERENCE":
|
||||
edge_weight = 1.3 # REFERENCE边权重最高(引用关系)
|
||||
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
|
||||
edge_weight = 1.2 # 属性边次之
|
||||
elif edge_type_str == "TEMPORAL":
|
||||
edge_weight = 0.7 # 时间关系降权(避免扩展到无关时间点)
|
||||
elif edge_type_str == "RELATION":
|
||||
edge_weight = 0.9 # 一般关系适中降权
|
||||
else:
|
||||
edge_weight = 1.0 # 默认权重
|
||||
|
||||
# 通过节点找到其他记忆
|
||||
for node_id in [target_node_id, source_node_id]:
|
||||
if node_id in graph_store.node_to_memories:
|
||||
neighbor_memory_ids.update(graph_store.node_to_memories[node_id])
|
||||
for neighbor_id in graph_store.node_to_memories[node_id]:
|
||||
if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
|
||||
edge_weights[neighbor_id] = edge_weight
|
||||
|
||||
# 将权重高的邻居记忆加入候选
|
||||
for neighbor_id, edge_weight in edge_weights.items():
|
||||
neighbor_memory_ids.add((neighbor_id, edge_weight))
|
||||
|
||||
# 过滤掉已访问的和自己
|
||||
neighbor_memory_ids.discard(memory_id)
|
||||
neighbor_memory_ids -= visited_memories
|
||||
filtered_neighbors = []
|
||||
for neighbor_id, edge_weight in neighbor_memory_ids:
|
||||
if neighbor_id != memory_id and neighbor_id not in visited_memories:
|
||||
filtered_neighbors.append((neighbor_id, edge_weight))
|
||||
|
||||
# 批量评估邻居记忆
|
||||
for neighbor_mem_id in neighbor_memory_ids:
|
||||
for neighbor_mem_id, edge_weight in filtered_neighbors:
|
||||
candidates_checked += 1
|
||||
|
||||
neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
|
||||
@@ -123,12 +146,17 @@ async def expand_memories_with_semantic_filter(
|
||||
# 计算语义相似度
|
||||
semantic_sim = cosine_similarity(query_embedding, topic_node.embedding)
|
||||
|
||||
# 计算边的重要性(影响评分)
|
||||
edge_importance = neighbor_memory.importance * 0.5 # 使用记忆重要性作为边权重
|
||||
# 🆕 计算边的重要性(结合边类型权重和记忆重要性)
|
||||
edge_importance = neighbor_memory.importance * edge_weight * 0.5
|
||||
|
||||
# 综合评分:语义相似度(70%) + 重要性(20%) + 深度衰减(10%)
|
||||
# 🆕 综合评分:语义相似度(60%) + 边权重(20%) + 重要性(10%) + 深度衰减(10%)
|
||||
depth_decay = 1.0 / (depth + 2) # 深度衰减
|
||||
relevance_score = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
|
||||
relevance_score = (
|
||||
semantic_sim * 0.60 + # 语义相似度主导 ⬆️
|
||||
edge_weight * 0.20 + # 边类型权重 🆕
|
||||
edge_importance * 0.10 + # 重要性降权 ⬇️
|
||||
depth_decay * 0.10 # 深度衰减
|
||||
)
|
||||
|
||||
# 只保留超过阈值的
|
||||
if relevance_score < semantic_threshold:
|
||||
|
||||
Reference in New Issue
Block a user