feat: improve embedding generation and logging for memory nodes
- Add embedding generation for SUBJECT and VALUE node types in MemoryBuilder, creating embeddings only for nodes with enough content.
- Improve MemoryTools logging to give detailed insight into the initial vector search, including warnings for low-recall cases.
- Adjust the scoring weights per memory type to emphasize similarity and importance and improve retrieval quality.
- Raise the vector search limit from 2x to 5x top_k to improve initial recall.
- Add a new script that backfills missing embeddings for existing nodes, with batch processing and improved indexing.
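For reference, the new backfill script is invoked as described in its module docstring (both flags are optional):

    python scripts/generate_missing_embeddings.py --node-types TOPIC,OBJECT --batch-size 50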

Windpicker-owo
2025-11-11 19:25:03 +08:00
parent 28c0f764ea
commit e2236f5bc1
5 changed files with 1296 additions and 189 deletions


@@ -0,0 +1,268 @@
"""
为现有节点生成嵌入向量
批量为图存储中缺少嵌入向量的节点生成并索引嵌入向量
使用场景:
1. 历史记忆节点没有嵌入向量
2. 嵌入生成器之前未配置,现在需要补充生成
3. 向量索引损坏需要重建
使用方法:
python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50]
参数说明:
--node-types: 需要生成嵌入的节点类型,默认为 TOPIC,OBJECT
--batch-size: 批量处理大小,默认为 50
"""
import asyncio
import sys
from pathlib import Path
from typing import List, Optional

# Add the project root to the import path
sys.path.insert(0, str(Path(__file__).parent.parent))


async def generate_missing_embeddings(
    target_node_types: Optional[List[str]] = None,
    batch_size: int = 50,
):
    """
    Generate embeddings for nodes that are missing them.

    Args:
        target_node_types: node types to process (e.g. ["主题", "客体"])
        batch_size: batch size
    """
    from src.common.logger import get_logger
    from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager
    from src.memory_graph.models import NodeType

    logger = get_logger("generate_missing_embeddings")

    if target_node_types is None:
        target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]

    print(f"\n{'='*80}")
    print("🔧 Generating embeddings for nodes")
    print(f"{'='*80}\n")
    print(f"Target node types: {', '.join(target_node_types)}")
    print(f"Batch size: {batch_size}\n")

    # 1. Initialize the memory manager
    print("🔧 Initializing memory manager...")
    await initialize_memory_manager()
    manager = get_memory_manager()
    if manager is None:
        print("❌ Failed to initialize memory manager")
        return
    print("✅ Memory manager initialized\n")

    # 2. Collect the IDs of nodes that are already in the vector index
    print("🔍 Checking the existing vector index...")
    existing_node_ids = set()
    try:
        vector_count = manager.vector_store.collection.count()
        if vector_count > 0:
            # Fetch all indexed IDs in batches
            batch_size_check = 1000
            for offset in range(0, vector_count, batch_size_check):
                limit = min(batch_size_check, vector_count - offset)
                result = manager.vector_store.collection.get(
                    limit=limit,
                    offset=offset,
                )
                if result and "ids" in result:
                    existing_node_ids.update(result["ids"])
        print(f"✅ Found {len(existing_node_ids)} indexed nodes\n")
    except Exception as e:
        logger.warning(f"Failed to fetch indexed node IDs: {e}")
        print("⚠️ Could not read the index; duplicates will be skipped during insertion\n")

    # 3. Collect the nodes that need embeddings
    print("🔍 Scanning for nodes that need embeddings...")
    all_memories = manager.graph_store.get_all_memories()
    nodes_to_process = []
    total_target_nodes = 0
    type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types}

    for memory in all_memories:
        for node in memory.nodes:
            if node.node_type.value in target_node_types:
                total_target_nodes += 1
                type_stats[node.node_type.value]["total"] += 1
                # Skip nodes that are already in the vector index
                if node.id in existing_node_ids:
                    type_stats[node.node_type.value]["already_indexed"] += 1
                    continue
                if not node.has_embedding():
                    nodes_to_process.append({
                        "node": node,
                        "memory_id": memory.id,
                    })
                    type_stats[node.node_type.value]["need_emb"] += 1

    print("\n📊 Scan results:")
    for node_type in target_node_types:
        stats = type_stats[node_type]
        already_ok = stats["already_indexed"]
        coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0
        print(f"  - {node_type}: {stats['total']} nodes, {stats['need_emb']} missing embeddings, "
              f"{already_ok} already indexed (coverage: {coverage:.1f}%)")
    print(f"\n  Total: {total_target_nodes} target nodes, {len(nodes_to_process)} need embeddings\n")

    if len(nodes_to_process) == 0:
        print("✅ All nodes already have embeddings; nothing to do")
        return

    # 4. Generate embeddings in batches
    print("🚀 Generating embeddings...\n")
    total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
    success_count = 0
    failed_count = 0
    indexed_count = 0

    for i in range(0, len(nodes_to_process), batch_size):
        batch = nodes_to_process[i : i + batch_size]
        batch_num = i // batch_size + 1
        print(f"📦 Batch {batch_num}/{total_batches} ({len(batch)} nodes)...")
        try:
            # Extract the text content
            texts = [item["node"].content for item in batch]
            # Generate embeddings for the whole batch
            embeddings = await manager.embedding_generator.generate_batch(texts)

            # Attach embeddings to the nodes and collect them for indexing
            batch_nodes_for_index = []
            for item, embedding in zip(batch, embeddings):
                node = item["node"]
                if embedding is not None:
                    node.embedding = embedding
                    batch_nodes_for_index.append(node)
                    success_count += 1
                else:
                    failed_count += 1
                    logger.warning(f"  ⚠️ Embedding generation failed for node {node.id[:8]}... '{node.content[:30]}'")

            # Index the batch into the vector store
            if batch_nodes_for_index:
                try:
                    await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
                    indexed_count += len(batch_nodes_for_index)
                    print(f"  ✅ Success: {len(batch_nodes_for_index)}/{len(batch)} nodes generated and indexed")
                except Exception as e:
                    # If the batch insert fails, fall back to adding nodes one by one (skipping duplicates)
                    logger.warning(f"  Batch indexing failed, falling back to per-node inserts: {e}")
                    individual_success = 0
                    for node in batch_nodes_for_index:
                        try:
                            await manager.vector_store.add_node(node)
                            individual_success += 1
                            indexed_count += 1
                        except Exception as e2:
                            if "Expected IDs to be unique" in str(e2):
                                logger.debug(f"  Skipping node that already exists: {node.id}")
                            else:
                                logger.error(f"  Failed to index node {node.id}: {e2}")
                    print(f"  ⚠️ Per-node indexing: {individual_success}/{len(batch_nodes_for_index)} succeeded")
        except Exception as e:
            failed_count += len(batch)
            logger.error(f"Batch {batch_num} failed", exc_info=True)
            print(f"  ❌ Batch failed: {e}")

        # Report progress
        total_processed = min(i + batch_size, len(nodes_to_process))
        progress = total_processed / len(nodes_to_process) * 100
        print(f"  📊 Overall progress: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")

    # 5. Persist the graph data (the nodes' embedding fields were updated)
    print("💾 Saving graph data...")
    try:
        await manager.persistence.save_graph_store(manager.graph_store)
        print("✅ Graph data saved\n")
    except Exception as e:
        logger.error("Failed to save graph data", exc_info=True)
        print(f"❌ Save failed: {e}\n")

    # 6. Verify the result
    print("🔍 Verifying the vector index...")
    final_vector_count = manager.vector_store.collection.count()
    stats = manager.graph_store.get_statistics()
    total_nodes = stats["total_nodes"]

    print(f"\n{'='*80}")
    print("📊 Generation finished")
    print(f"{'='*80}")
    print(f"Nodes processed: {len(nodes_to_process)}")
    print(f"Embeddings generated: {success_count}")
    print(f"Failures: {failed_count}")
    print(f"Nodes indexed: {indexed_count}")
    print(f"Nodes in vector index: {final_vector_count}")
    print(f"Nodes in graph store: {total_nodes}")
    print(f"Index coverage: {final_vector_count / total_nodes * 100:.1f}%\n")

    # 7. Smoke-test the search path
    print("🧪 Testing search...")
    test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
    for query in test_queries:
        results = await manager.search_memories(query=query, top_k=3)
        if results:
            print(f"\n✅ Query '{query}' found {len(results)} memories:")
            for rank, memory in enumerate(results[:2], 1):
                subject_node = memory.get_subject_node()
                # Find TOPIC nodes by scanning all nodes of the memory
                topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
                subject = subject_node.content if subject_node else "?"
                topic = topic_nodes[0].content if topic_nodes else "?"
                print(f"  {rank}. {subject} - {topic} (importance: {memory.importance:.2f})")
        else:
            print(f"\n⚠️ Query '{query}' returned no results")


async def main():
    import argparse

    parser = argparse.ArgumentParser(description="Generate embedding vectors for nodes")
    parser.add_argument(
        "--node-types",
        type=str,
        default="主题,客体",
        help="Node types to generate embeddings for, comma-separated (default: 主题,客体)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Batch size (default: 50)",
    )
    args = parser.parse_args()

    target_types = [t.strip() for t in args.node_types.split(",")]
    await generate_missing_embeddings(
        target_node_types=target_types,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    asyncio.run(main())


@@ -7,9 +7,10 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Optional
+from collections import defaultdict
 import orjson
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, HTTPException, Request, Query
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.templating import Jinja2Templates
@@ -227,6 +228,242 @@ async def get_full_graph():
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


@router.get("/api/graph/summary")
async def get_graph_summary():
    """Return summary information for the graph (statistics only, no nodes or edges)."""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()
        if memory_manager and memory_manager._initialized:
            stats = memory_manager.get_statistics()
            return JSONResponse(content={"success": True, "data": {
                "stats": {
                    "total_nodes": stats.get("total_nodes", 0),
                    "total_edges": stats.get("total_edges", 0),
                    "total_memories": stats.get("total_memories", 0),
                },
                "current_file": "memory_manager (live data)",
            }})
        else:
            data = load_graph_data_from_file()
            return JSONResponse(content={"success": True, "data": {
                "stats": data.get("stats", {}),
                "current_file": data.get("current_file", ""),
            }})
    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


@router.get("/api/graph/paginated")
async def get_paginated_graph(
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(500, ge=100, le=2000, description="Nodes per page"),
    min_importance: float = Query(0.0, ge=0.0, le=1.0, description="Minimum importance threshold"),
    node_types: Optional[str] = Query(None, description="Node type filter, comma-separated"),
):
    """Return graph data page by page, with optional importance filtering."""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()
        # Load the full data set
        if memory_manager and memory_manager._initialized:
            full_data = _format_graph_data_from_manager(memory_manager)
        else:
            full_data = load_graph_data_from_file()

        nodes = full_data.get("nodes", [])
        edges = full_data.get("edges", [])

        # Filter by node type
        if node_types:
            allowed_types = set(node_types.split(","))
            nodes = [n for n in nodes if n.get("group") in allowed_types]

        # Rank nodes by importance (derived from how many edges touch each node)
        nodes_with_importance = []
        for node in nodes:
            edge_count = sum(1 for e in edges if e.get("from") == node["id"] or e.get("to") == node["id"])
            importance = edge_count / max(len(edges), 1)
            if importance >= min_importance:
                node["importance"] = importance
                nodes_with_importance.append(node)

        # Sort by importance, descending
        nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)

        # Paginate
        total_nodes = len(nodes_with_importance)
        total_pages = (total_nodes + page_size - 1) // page_size
        start_idx = (page - 1) * page_size
        end_idx = min(start_idx + page_size, total_nodes)
        paginated_nodes = nodes_with_importance[start_idx:end_idx]
        node_ids = set(n["id"] for n in paginated_nodes)

        # Keep only edges that connect nodes on this page
        paginated_edges = [
            e for e in edges
            if e.get("from") in node_ids and e.get("to") in node_ids
        ]

        return JSONResponse(content={"success": True, "data": {
            "nodes": paginated_nodes,
            "edges": paginated_edges,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_nodes": total_nodes,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_prev": page > 1,
            },
            "stats": {
                "total_nodes": total_nodes,
                "total_edges": len(paginated_edges),
                "total_memories": full_data.get("stats", {}).get("total_memories", 0),
            },
        }})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


@router.get("/api/graph/clustered")
async def get_clustered_graph(
    max_nodes: int = Query(300, ge=50, le=1000, description="Maximum number of nodes"),
    cluster_threshold: int = Query(10, ge=2, le=50, description="Clustering threshold"),
):
    """Return a cluster-simplified view of the graph."""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()
        # Load the full data set
        if memory_manager and memory_manager._initialized:
            full_data = _format_graph_data_from_manager(memory_manager)
        else:
            full_data = load_graph_data_from_file()

        nodes = full_data.get("nodes", [])
        edges = full_data.get("edges", [])

        # If the graph is small enough, return it unchanged
        if len(nodes) <= max_nodes:
            return JSONResponse(content={"success": True, "data": {
                "nodes": nodes,
                "edges": edges,
                "stats": full_data.get("stats", {}),
                "clustered": False,
            }})

        # Run clustering
        clustered_data = _cluster_graph_data(nodes, edges, max_nodes, cluster_threshold)

        return JSONResponse(content={"success": True, "data": {
            **clustered_data,
            "stats": {
                "original_nodes": len(nodes),
                "original_edges": len(edges),
                "clustered_nodes": len(clustered_data["nodes"]),
                "clustered_edges": len(clustered_data["edges"]),
                "total_memories": full_data.get("stats", {}).get("total_memories", 0),
            },
            "clustered": True,
        }})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
    """Simple graph clustering: group nodes by type and connectivity."""
    # Build the adjacency table
    adjacency = defaultdict(set)
    for edge in edges:
        adjacency[edge["from"]].add(edge["to"])
        adjacency[edge["to"]].add(edge["from"])

    # Group nodes by type
    type_groups = defaultdict(list)
    for node in nodes:
        type_groups[node.get("group", "UNKNOWN")].append(node)

    clustered_nodes = []
    clustered_edges = []
    node_mapping = {}  # original node ID -> cluster node ID

    for node_type, type_nodes in type_groups.items():
        # Keep the whole group if it is below the clustering threshold
        if len(type_nodes) <= cluster_threshold:
            for node in type_nodes:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]
        else:
            # Rank nodes by connectivity and keep the most connected ones
            node_importance = []
            for node in type_nodes:
                importance = len(adjacency[node["id"]])
                node_importance.append((node, importance))
            node_importance.sort(key=lambda x: x[1], reverse=True)

            # Keep the top N nodes of this type
            keep_count = min(len(type_nodes), max_nodes // len(type_groups))
            for node, importance in node_importance[:keep_count]:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]

            # Aggregate the remaining nodes into one super node
            if len(node_importance) > keep_count:
                clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
                cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
                cluster_label = f"{node_type} cluster ({len(clustered_node_ids)} nodes)"
                clustered_nodes.append({
                    "id": cluster_id,
                    "label": cluster_label,
                    "group": node_type,
                    "title": f"Contains {len(clustered_node_ids)} {node_type} nodes",
                    "is_cluster": True,
                    "cluster_size": len(clustered_node_ids),
                    "clustered_nodes": clustered_node_ids[:10],  # keep only the first 10 for display
                })
                for node_id in clustered_node_ids:
                    node_mapping[node_id] = cluster_id

    # Rebuild the edges (deduplicated)
    edge_set = set()
    for edge in edges:
        from_id = node_mapping.get(edge["from"])
        to_id = node_mapping.get(edge["to"])
        if from_id and to_id and from_id != to_id:
            edge_key = tuple(sorted([from_id, to_id]))
            if edge_key not in edge_set:
                edge_set.add(edge_key)
                clustered_edges.append({
                    "id": f"{from_id}_{to_id}",
                    "from": from_id,
                    "to": to_id,
                    "label": edge.get("label", ""),
                    "arrows": "to",
                })

    return {
        "nodes": clustered_nodes,
        "edges": clustered_edges,
    }


@router.get("/api/files")
async def list_files_api():
    """List all available data files."""

File diff suppressed because it is too large


@@ -185,12 +185,19 @@ class MemoryBuilder:
logger.debug(f"复用已存在的主体节点: {existing.id}") logger.debug(f"复用已存在的主体节点: {existing.id}")
return existing return existing
# 为主体和值节点生成嵌入向量(用于人名/实体和重要描述检索)
embedding = None
if node_type in (NodeType.SUBJECT, NodeType.VALUE):
# 只为有足够内容的节点生成嵌入(避免浪费)
if len(content.strip()) >= 2:
embedding = await self._generate_embedding(content)
# 创建新节点 # 创建新节点
node = MemoryNode( node = MemoryNode(
id=self._generate_node_id(), id=self._generate_node_id(),
content=content, content=content,
node_type=node_type, node_type=node_type,
embedding=None, # 主体属性不需要嵌入 embedding=embedding, # 主体、值需要嵌入,属性不需要
metadata={"memory_ids": [memory_id]}, metadata={"memory_ids": [memory_id]},
) )


@@ -516,6 +516,22 @@ class MemoryTools:
                 # Record the highest score
                 if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
                     memory_scores[mem_id] = similarity
+
+        # 🔥 Detailed logging: check the initial recall
+        logger.info(
+            f"Initial vector search: {len(similar_nodes)} nodes returned → "
+            f"{len(initial_memory_ids)} memories extracted"
+        )
+        if len(initial_memory_ids) == 0:
+            logger.warning(
+                "⚠️ Vector search found no memories! Possible causes: 1) the embedding model misunderstood the query "
+                "2) memory nodes were never indexed 3) the query wording differs too much from the stored content"
+            )
+            # Log sample node metadata for debugging
+            if similar_nodes:
+                logger.debug(f"Sample metadata from vector search results: {similar_nodes[0][2] if len(similar_nodes) > 0 else 'None'}")
+        elif len(initial_memory_ids) < 3:
+            logger.warning(f"⚠️ Initial recall is low ({len(initial_memory_ids)} memories); result quality may suffer")

         # 3. Graph expansion (if enabled and expand_depth is set)
         expanded_memory_scores = {}
@@ -609,42 +625,37 @@ class MemoryTools:
         if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
             # Factual memories (e.g. document links, configuration): semantic similarity matters most
             weights = {
-                "similarity": 0.65,   # semantic similarity 65% ⬆️
-                "importance": 0.20,   # importance 20%
-                "recency": 0.05,      # recency 5% ⬇️ (facts do not go stale)
-                "activation": 0.10    # activation 10% ⬇️ (avoid suppressing rarely accessed information)
+                "similarity": 0.70,   # semantic similarity 70% ⬆️
+                "importance": 0.25,   # importance 25% ⬆️
+                "recency": 0.05,      # recency 5% (facts do not go stale)
             }
         elif memory_type in ["CONVERSATION", "EPISODIC"] or dominant_node_type == "EVENT":
-            # Conversation/event memories: recency and activation matter more
+            # Conversation/event memories: recency matters more
             weights = {
-                "similarity": 0.45,   # semantic similarity 45%
-                "importance": 0.15,   # importance 15%
-                "recency": 0.20,      # recency 20% ⬆️
-                "activation": 0.20    # activation 20%
+                "similarity": 0.55,   # semantic similarity 55% ⬆️
+                "importance": 0.20,   # importance 20% ⬆️
+                "recency": 0.25,      # recency 25% ⬆️
             }
         elif dominant_node_type == "ENTITY" or memory_type == "SEMANTIC":
             # Entity/semantic memories: balance the factors
             weights = {
-                "similarity": 0.50,   # semantic similarity 50%
-                "importance": 0.25,   # importance 25%
+                "similarity": 0.60,   # semantic similarity 60% ⬆️
+                "importance": 0.30,   # importance 30% ⬆️
                 "recency": 0.10,      # recency 10%
-                "activation": 0.15    # activation 15%
             }
         else:
             # Default weights (conservative, biased toward semantics)
             weights = {
-                "similarity": 0.55,   # semantic similarity 55%
-                "importance": 0.20,   # importance 20%
+                "similarity": 0.65,   # semantic similarity 65% ⬆️
+                "importance": 0.25,   # importance 25% ⬆️
                 "recency": 0.10,      # recency 10%
-                "activation": 0.15    # activation 15%
             }

-        # Combined score
+        # Combined score (🔥 activation no longer contributes)
         final_score = (
             similarity_score * weights["similarity"] +
             importance_score * weights["importance"] +
-            recency_score * weights["recency"] +
-            activation_score * weights["activation"]
+            recency_score * weights["recency"]
         )

         # 🆕 Node-type weighting: give REFERENCE/ATTRIBUTE nodes an extra boost to help recall factual information
@@ -943,11 +954,16 @@ class MemoryTools:
logger.warning("嵌入生成失败,跳过节点搜索") logger.warning("嵌入生成失败,跳过节点搜索")
return [] return []
# 向量搜索 # 向量搜索(增加返回数量以提高召回率)
similar_nodes = await self.vector_store.search_similar_nodes( similar_nodes = await self.vector_store.search_similar_nodes(
query_embedding=query_embedding, query_embedding=query_embedding,
limit=top_k * 2, # 多取一些,后续过滤 limit=top_k * 5, # 🔥 从2倍提升到5倍提高初始召回率
min_similarity=0.0, # 不在这里过滤,交给后续评分
) )
logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
if similar_nodes:
logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")
return similar_nodes return similar_nodes
@@ -1003,11 +1019,13 @@ class MemoryTools:
         similar_nodes = await self.vector_store.search_with_multiple_queries(
             query_embeddings=query_embeddings,
             query_weights=query_weights,
-            limit=top_k * 2,  # fetch extra candidates and filter later
+            limit=top_k * 5,  # 🔥 raised from 2x to 5x to improve initial recall
             fusion_strategy="weighted_max",
         )

         logger.info(f"Multi-query retrieval finished: {len(similar_nodes)} nodes (preferred types: {prefer_node_types})")
+        if similar_nodes:
+            logger.debug(f"Top 5 fused similarities: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:5]]}")

         return similar_nodes, prefer_node_types