feat:增强记忆节点的嵌入生成和日志记录- 在 MemoryBuilder 中为 SUBJECT 和 VALUE 节点类型添加了嵌入生成,确保仅为内容足够的节点创建嵌入。- 改进了 MemoryTools 的日志记录,在初始向量搜索期间提供详细见解,包括低召回情况的警告。- 调整了不同记忆类型的评分权重,以强调相似性和重要性,提高记忆检索的质量。- 将向量搜索限制从 2 倍提高到 5 倍,以改善初始召回率。- 引入了一个新脚本,用于为现有节点生成缺失的嵌入,支持批量处理并改进索引。
This commit is contained in:
268
scripts/generate_missing_embeddings.py
Normal file
268
scripts/generate_missing_embeddings.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
为现有节点生成嵌入向量
|
||||
|
||||
批量为图存储中缺少嵌入向量的节点生成并索引嵌入向量
|
||||
|
||||
使用场景:
|
||||
1. 历史记忆节点没有嵌入向量
|
||||
2. 嵌入生成器之前未配置,现在需要补充生成
|
||||
3. 向量索引损坏需要重建
|
||||
|
||||
使用方法:
|
||||
python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50]
|
||||
|
||||
参数说明:
|
||||
--node-types: 需要生成嵌入的节点类型,默认为 TOPIC,OBJECT
|
||||
--batch-size: 批量处理大小,默认为 50
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
async def generate_missing_embeddings(
|
||||
target_node_types: List[str] = None,
|
||||
batch_size: int = 50,
|
||||
):
|
||||
"""
|
||||
为缺失嵌入向量的节点生成嵌入
|
||||
|
||||
Args:
|
||||
target_node_types: 需要处理的节点类型列表(如 ["主题", "客体"])
|
||||
batch_size: 批处理大小
|
||||
"""
|
||||
from src.common.logger import get_logger
|
||||
from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager
|
||||
from src.memory_graph.models import NodeType
|
||||
|
||||
logger = get_logger("generate_missing_embeddings")
|
||||
|
||||
if target_node_types is None:
|
||||
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"🔧 为节点生成嵌入向量")
|
||||
print(f"{'='*80}\n")
|
||||
print(f"目标节点类型: {', '.join(target_node_types)}")
|
||||
print(f"批处理大小: {batch_size}\n")
|
||||
|
||||
# 1. 初始化记忆管理器
|
||||
print(f"🔧 正在初始化记忆管理器...")
|
||||
await initialize_memory_manager()
|
||||
manager = get_memory_manager()
|
||||
|
||||
if manager is None:
|
||||
print("❌ 记忆管理器初始化失败")
|
||||
return
|
||||
|
||||
print(f"✅ 记忆管理器已初始化\n")
|
||||
|
||||
# 2. 获取已索引的节点ID
|
||||
print(f"🔍 检查现有向量索引...")
|
||||
existing_node_ids = set()
|
||||
try:
|
||||
vector_count = manager.vector_store.collection.count()
|
||||
if vector_count > 0:
|
||||
# 分批获取所有已索引的ID
|
||||
batch_size_check = 1000
|
||||
for offset in range(0, vector_count, batch_size_check):
|
||||
limit = min(batch_size_check, vector_count - offset)
|
||||
result = manager.vector_store.collection.get(
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
if result and "ids" in result:
|
||||
existing_node_ids.update(result["ids"])
|
||||
|
||||
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取已索引节点ID失败: {e}")
|
||||
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||
|
||||
# 3. 收集需要生成嵌入的节点
|
||||
print(f"🔍 扫描需要生成嵌入的节点...")
|
||||
all_memories = manager.graph_store.get_all_memories()
|
||||
|
||||
nodes_to_process = []
|
||||
total_target_nodes = 0
|
||||
type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types}
|
||||
|
||||
for memory in all_memories:
|
||||
for node in memory.nodes:
|
||||
if node.node_type.value in target_node_types:
|
||||
total_target_nodes += 1
|
||||
type_stats[node.node_type.value]["total"] += 1
|
||||
|
||||
# 检查是否已在向量索引中
|
||||
if node.id in existing_node_ids:
|
||||
type_stats[node.node_type.value]["already_indexed"] += 1
|
||||
continue
|
||||
|
||||
if not node.has_embedding():
|
||||
nodes_to_process.append({
|
||||
"node": node,
|
||||
"memory_id": memory.id,
|
||||
})
|
||||
type_stats[node.node_type.value]["need_emb"] += 1
|
||||
|
||||
print(f"\n📊 扫描结果:")
|
||||
for node_type in target_node_types:
|
||||
stats = type_stats[node_type]
|
||||
already_ok = stats["already_indexed"]
|
||||
coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0
|
||||
print(f" - {node_type}: {stats['total']} 个节点, {stats['need_emb']} 个缺失嵌入, "
|
||||
f"{already_ok} 个已索引 (覆盖率: {coverage:.1f}%)")
|
||||
|
||||
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
||||
|
||||
if len(nodes_to_process) == 0:
|
||||
print(f"✅ 所有节点已有嵌入向量,无需生成")
|
||||
return
|
||||
|
||||
# 3. 批量生成嵌入
|
||||
print(f"🚀 开始生成嵌入向量...\n")
|
||||
|
||||
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
indexed_count = 0
|
||||
|
||||
for i in range(0, len(nodes_to_process), batch_size):
|
||||
batch = nodes_to_process[i : i + batch_size]
|
||||
batch_num = i // batch_size + 1
|
||||
|
||||
print(f"📦 批次 {batch_num}/{total_batches} ({len(batch)} 个节点)...")
|
||||
|
||||
try:
|
||||
# 提取文本内容
|
||||
texts = [item["node"].content for item in batch]
|
||||
|
||||
# 批量生成嵌入
|
||||
embeddings = await manager.embedding_generator.generate_batch(texts)
|
||||
|
||||
# 为节点设置嵌入并索引
|
||||
batch_nodes_for_index = []
|
||||
|
||||
for j, (item, embedding) in enumerate(zip(batch, embeddings)):
|
||||
node = item["node"]
|
||||
|
||||
if embedding is not None:
|
||||
# 设置嵌入向量
|
||||
node.embedding = embedding
|
||||
batch_nodes_for_index.append(node)
|
||||
success_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
logger.warning(f" ⚠️ 节点 {node.id[:8]}... '{node.content[:30]}' 嵌入生成失败")
|
||||
|
||||
# 批量索引到向量数据库
|
||||
if batch_nodes_for_index:
|
||||
try:
|
||||
await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
|
||||
indexed_count += len(batch_nodes_for_index)
|
||||
print(f" ✅ 成功: {len(batch_nodes_for_index)}/{len(batch)} 个节点已生成并索引")
|
||||
except Exception as e:
|
||||
# 如果批量失败,尝试逐个添加(跳过重复)
|
||||
logger.warning(f" 批量索引失败,尝试逐个添加: {e}")
|
||||
individual_success = 0
|
||||
for node in batch_nodes_for_index:
|
||||
try:
|
||||
await manager.vector_store.add_node(node)
|
||||
individual_success += 1
|
||||
indexed_count += 1
|
||||
except Exception as e2:
|
||||
if "Expected IDs to be unique" in str(e2):
|
||||
logger.debug(f" 跳过已存在节点: {node.id}")
|
||||
else:
|
||||
logger.error(f" 节点 {node.id} 索引失败: {e2}")
|
||||
print(f" ⚠️ 逐个索引: {individual_success}/{len(batch_nodes_for_index)} 个成功")
|
||||
|
||||
except Exception as e:
|
||||
failed_count += len(batch)
|
||||
logger.error(f"批次 {batch_num} 处理失败", exc_info=True)
|
||||
print(f" ❌ 批次处理失败: {e}")
|
||||
|
||||
# 显示进度
|
||||
total_processed = min(i + batch_size, len(nodes_to_process))
|
||||
progress = total_processed / len(nodes_to_process) * 100
|
||||
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
||||
|
||||
# 4. 保存图数据(更新节点的 embedding 字段)
|
||||
print(f"💾 保存图数据...")
|
||||
try:
|
||||
await manager.persistence.save_graph_store(manager.graph_store)
|
||||
print(f"✅ 图数据已保存\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图数据失败", exc_info=True)
|
||||
print(f"❌ 保存失败: {e}\n")
|
||||
|
||||
# 5. 验证结果
|
||||
print(f"🔍 验证向量索引...")
|
||||
final_vector_count = manager.vector_store.collection.count()
|
||||
stats = manager.graph_store.get_statistics()
|
||||
total_nodes = stats["total_nodes"]
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"📊 生成完成")
|
||||
print(f"{'='*80}")
|
||||
print(f"处理节点数: {len(nodes_to_process)}")
|
||||
print(f"成功生成: {success_count}")
|
||||
print(f"失败数量: {failed_count}")
|
||||
print(f"成功索引: {indexed_count}")
|
||||
print(f"向量索引节点数: {final_vector_count}")
|
||||
print(f"图存储节点数: {total_nodes}")
|
||||
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
||||
|
||||
# 6. 测试搜索
|
||||
print(f"🧪 测试搜索功能...")
|
||||
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
||||
|
||||
for query in test_queries:
|
||||
results = await manager.search_memories(query=query, top_k=3)
|
||||
if results:
|
||||
print(f"\n✅ 查询 '{query}' 找到 {len(results)} 条记忆:")
|
||||
for i, memory in enumerate(results[:2], 1):
|
||||
subject_node = memory.get_subject_node()
|
||||
# 获取主题节点(遍历所有节点找TOPIC类型)
|
||||
from src.memory_graph.models import NodeType
|
||||
topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
|
||||
subject = subject_node.content if subject_node else "?"
|
||||
topic = topic_nodes[0].content if topic_nodes else "?"
|
||||
print(f" {i}. {subject} - {topic} (重要性: {memory.importance:.2f})")
|
||||
else:
|
||||
print(f"\n⚠️ 查询 '{query}' 返回 0 条结果")
|
||||
|
||||
|
||||
async def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="为节点生成嵌入向量")
|
||||
parser.add_argument(
|
||||
"--node-types",
|
||||
type=str,
|
||||
default="主题,客体",
|
||||
help="需要生成嵌入的节点类型,逗号分隔(默认:主题,客体)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=50,
|
||||
help="批处理大小(默认:50)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
target_types = [t.strip() for t in args.node_types.split(",")]
|
||||
await generate_missing_embeddings(
|
||||
target_node_types=target_types,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -7,9 +7,10 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi import APIRouter, HTTPException, Request, Query
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
@@ -227,6 +228,242 @@ async def get_full_graph():
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.get("/api/graph/summary")
|
||||
async def get_graph_summary():
|
||||
"""获取图的摘要信息(仅统计数据,不包含节点和边)"""
|
||||
try:
|
||||
from src.memory_graph.manager_singleton import get_memory_manager
|
||||
|
||||
memory_manager = get_memory_manager()
|
||||
|
||||
if memory_manager and memory_manager._initialized:
|
||||
stats = memory_manager.get_statistics()
|
||||
return JSONResponse(content={"success": True, "data": {
|
||||
"stats": {
|
||||
"total_nodes": stats.get("total_nodes", 0),
|
||||
"total_edges": stats.get("total_edges", 0),
|
||||
"total_memories": stats.get("total_memories", 0),
|
||||
},
|
||||
"current_file": "memory_manager (实时数据)",
|
||||
}})
|
||||
else:
|
||||
data = load_graph_data_from_file()
|
||||
return JSONResponse(content={"success": True, "data": {
|
||||
"stats": data.get("stats", {}),
|
||||
"current_file": data.get("current_file", ""),
|
||||
}})
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.get("/api/graph/paginated")
|
||||
async def get_paginated_graph(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
|
||||
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
|
||||
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
|
||||
):
|
||||
"""分页获取图数据,支持重要性过滤"""
|
||||
try:
|
||||
from src.memory_graph.manager_singleton import get_memory_manager
|
||||
|
||||
memory_manager = get_memory_manager()
|
||||
|
||||
# 获取完整数据
|
||||
if memory_manager and memory_manager._initialized:
|
||||
full_data = _format_graph_data_from_manager(memory_manager)
|
||||
else:
|
||||
full_data = load_graph_data_from_file()
|
||||
|
||||
nodes = full_data.get("nodes", [])
|
||||
edges = full_data.get("edges", [])
|
||||
|
||||
# 过滤节点类型
|
||||
if node_types:
|
||||
allowed_types = set(node_types.split(","))
|
||||
nodes = [n for n in nodes if n.get("group") in allowed_types]
|
||||
|
||||
# 按重要性排序(如果有importance字段)
|
||||
nodes_with_importance = []
|
||||
for node in nodes:
|
||||
# 计算节点重要性(连接的边数)
|
||||
edge_count = sum(1 for e in edges if e.get("from") == node["id"] or e.get("to") == node["id"])
|
||||
importance = edge_count / max(len(edges), 1)
|
||||
if importance >= min_importance:
|
||||
node["importance"] = importance
|
||||
nodes_with_importance.append(node)
|
||||
|
||||
# 按重要性降序排序
|
||||
nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
|
||||
|
||||
# 分页
|
||||
total_nodes = len(nodes_with_importance)
|
||||
total_pages = (total_nodes + page_size - 1) // page_size
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_nodes)
|
||||
|
||||
paginated_nodes = nodes_with_importance[start_idx:end_idx]
|
||||
node_ids = set(n["id"] for n in paginated_nodes)
|
||||
|
||||
# 只保留连接分页节点的边
|
||||
paginated_edges = [
|
||||
e for e in edges
|
||||
if e.get("from") in node_ids and e.get("to") in node_ids
|
||||
]
|
||||
|
||||
return JSONResponse(content={"success": True, "data": {
|
||||
"nodes": paginated_nodes,
|
||||
"edges": paginated_edges,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_nodes": total_nodes,
|
||||
"total_pages": total_pages,
|
||||
"has_next": page < total_pages,
|
||||
"has_prev": page > 1,
|
||||
},
|
||||
"stats": {
|
||||
"total_nodes": total_nodes,
|
||||
"total_edges": len(paginated_edges),
|
||||
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
|
||||
},
|
||||
}})
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.get("/api/graph/clustered")
|
||||
async def get_clustered_graph(
|
||||
max_nodes: int = Query(300, ge=50, le=1000, description="最大节点数"),
|
||||
cluster_threshold: int = Query(10, ge=2, le=50, description="聚类阈值")
|
||||
):
|
||||
"""获取聚类简化后的图数据"""
|
||||
try:
|
||||
from src.memory_graph.manager_singleton import get_memory_manager
|
||||
|
||||
memory_manager = get_memory_manager()
|
||||
|
||||
# 获取完整数据
|
||||
if memory_manager and memory_manager._initialized:
|
||||
full_data = _format_graph_data_from_manager(memory_manager)
|
||||
else:
|
||||
full_data = load_graph_data_from_file()
|
||||
|
||||
nodes = full_data.get("nodes", [])
|
||||
edges = full_data.get("edges", [])
|
||||
|
||||
# 如果节点数小于阈值,直接返回
|
||||
if len(nodes) <= max_nodes:
|
||||
return JSONResponse(content={"success": True, "data": {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"stats": full_data.get("stats", {}),
|
||||
"clustered": False,
|
||||
}})
|
||||
|
||||
# 执行聚类
|
||||
clustered_data = _cluster_graph_data(nodes, edges, max_nodes, cluster_threshold)
|
||||
|
||||
return JSONResponse(content={"success": True, "data": {
|
||||
**clustered_data,
|
||||
"stats": {
|
||||
"original_nodes": len(nodes),
|
||||
"original_edges": len(edges),
|
||||
"clustered_nodes": len(clustered_data["nodes"]),
|
||||
"clustered_edges": len(clustered_data["edges"]),
|
||||
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
|
||||
},
|
||||
"clustered": True,
|
||||
}})
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
|
||||
"""简单的图聚类算法:按类型和连接度聚类"""
|
||||
# 构建邻接表
|
||||
adjacency = defaultdict(set)
|
||||
for edge in edges:
|
||||
adjacency[edge["from"]].add(edge["to"])
|
||||
adjacency[edge["to"]].add(edge["from"])
|
||||
|
||||
# 按类型分组
|
||||
type_groups = defaultdict(list)
|
||||
for node in nodes:
|
||||
type_groups[node.get("group", "UNKNOWN")].append(node)
|
||||
|
||||
clustered_nodes = []
|
||||
clustered_edges = []
|
||||
node_mapping = {} # 原始节点ID -> 聚类节点ID
|
||||
|
||||
for node_type, type_nodes in type_groups.items():
|
||||
# 如果该类型节点少于阈值,直接保留
|
||||
if len(type_nodes) <= cluster_threshold:
|
||||
for node in type_nodes:
|
||||
clustered_nodes.append(node)
|
||||
node_mapping[node["id"]] = node["id"]
|
||||
else:
|
||||
# 按连接度排序,保留最重要的节点
|
||||
node_importance = []
|
||||
for node in type_nodes:
|
||||
importance = len(adjacency[node["id"]])
|
||||
node_importance.append((node, importance))
|
||||
|
||||
node_importance.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 保留前N个重要节点
|
||||
keep_count = min(len(type_nodes), max_nodes // len(type_groups))
|
||||
for node, importance in node_importance[:keep_count]:
|
||||
clustered_nodes.append(node)
|
||||
node_mapping[node["id"]] = node["id"]
|
||||
|
||||
# 其余节点聚合为一个超级节点
|
||||
if len(node_importance) > keep_count:
|
||||
clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
|
||||
cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
|
||||
cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"
|
||||
|
||||
clustered_nodes.append({
|
||||
"id": cluster_id,
|
||||
"label": cluster_label,
|
||||
"group": node_type,
|
||||
"title": f"包含 {len(clustered_node_ids)} 个{node_type}节点",
|
||||
"is_cluster": True,
|
||||
"cluster_size": len(clustered_node_ids),
|
||||
"clustered_nodes": clustered_node_ids[:10], # 只保留前10个用于展示
|
||||
})
|
||||
|
||||
for node_id in clustered_node_ids:
|
||||
node_mapping[node_id] = cluster_id
|
||||
|
||||
# 重建边(去重)
|
||||
edge_set = set()
|
||||
for edge in edges:
|
||||
from_id = node_mapping.get(edge["from"])
|
||||
to_id = node_mapping.get(edge["to"])
|
||||
|
||||
if from_id and to_id and from_id != to_id:
|
||||
edge_key = tuple(sorted([from_id, to_id]))
|
||||
if edge_key not in edge_set:
|
||||
edge_set.add(edge_key)
|
||||
clustered_edges.append({
|
||||
"id": f"{from_id}_{to_id}",
|
||||
"from": from_id,
|
||||
"to": to_id,
|
||||
"label": edge.get("label", ""),
|
||||
"arrows": "to",
|
||||
})
|
||||
|
||||
return {
|
||||
"nodes": clustered_nodes,
|
||||
"edges": clustered_edges,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/files")
|
||||
async def list_files_api():
|
||||
"""列出所有可用的数据文件"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -185,12 +185,19 @@ class MemoryBuilder:
|
||||
logger.debug(f"复用已存在的主体节点: {existing.id}")
|
||||
return existing
|
||||
|
||||
# 为主体和值节点生成嵌入向量(用于人名/实体和重要描述检索)
|
||||
embedding = None
|
||||
if node_type in (NodeType.SUBJECT, NodeType.VALUE):
|
||||
# 只为有足够内容的节点生成嵌入(避免浪费)
|
||||
if len(content.strip()) >= 2:
|
||||
embedding = await self._generate_embedding(content)
|
||||
|
||||
# 创建新节点
|
||||
node = MemoryNode(
|
||||
id=self._generate_node_id(),
|
||||
content=content,
|
||||
node_type=node_type,
|
||||
embedding=None, # 主体和属性不需要嵌入
|
||||
embedding=embedding, # 主体、值需要嵌入,属性不需要
|
||||
metadata={"memory_ids": [memory_id]},
|
||||
)
|
||||
|
||||
|
||||
@@ -516,6 +516,22 @@ class MemoryTools:
|
||||
# 记录最高分数
|
||||
if mem_id not in memory_scores or similarity > memory_scores[mem_id]:
|
||||
memory_scores[mem_id] = similarity
|
||||
|
||||
# 🔥 详细日志:检查初始召回情况
|
||||
logger.info(
|
||||
f"初始向量搜索: 返回{len(similar_nodes)}个节点 → "
|
||||
f"提取{len(initial_memory_ids)}条记忆"
|
||||
)
|
||||
if len(initial_memory_ids) == 0:
|
||||
logger.warning(
|
||||
f"⚠️ 向量搜索未找到任何记忆!"
|
||||
f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
||||
)
|
||||
# 输出相似节点的详细信息用于调试
|
||||
if similar_nodes:
|
||||
logger.debug(f"向量搜索返回的节点元数据样例: {similar_nodes[0][2] if len(similar_nodes) > 0 else 'None'}")
|
||||
elif len(initial_memory_ids) < 3:
|
||||
logger.warning(f"⚠️ 初始召回记忆数量较少({len(initial_memory_ids)}条),可能影响结果质量")
|
||||
|
||||
# 3. 图扩展(如果启用且有expand_depth)
|
||||
expanded_memory_scores = {}
|
||||
@@ -609,42 +625,37 @@ class MemoryTools:
|
||||
if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
|
||||
# 事实性记忆(如文档地址、配置信息):语义相似度最重要
|
||||
weights = {
|
||||
"similarity": 0.65, # 语义相似度 65% ⬆️
|
||||
"importance": 0.20, # 重要性 20%
|
||||
"recency": 0.05, # 时效性 5% ⬇️(事实不随时间失效)
|
||||
"activation": 0.10 # 激活度 10% ⬇️(避免冷门信息被压制)
|
||||
"similarity": 0.70, # 语义相似度 70% ⬆️
|
||||
"importance": 0.25, # 重要性 25% ⬆️
|
||||
"recency": 0.05, # 时效性 5%(事实不随时间失效)
|
||||
}
|
||||
elif memory_type in ["CONVERSATION", "EPISODIC"] or dominant_node_type == "EVENT":
|
||||
# 对话/事件记忆:时效性和激活度更重要
|
||||
# 对话/事件记忆:时效性更重要
|
||||
weights = {
|
||||
"similarity": 0.45, # 语义相似度 45%
|
||||
"importance": 0.15, # 重要性 15%
|
||||
"recency": 0.20, # 时效性 20% ⬆️
|
||||
"activation": 0.20 # 激活度 20%
|
||||
"similarity": 0.55, # 语义相似度 55% ⬆️
|
||||
"importance": 0.20, # 重要性 20% ⬆️
|
||||
"recency": 0.25, # 时效性 25% ⬆️
|
||||
}
|
||||
elif dominant_node_type == "ENTITY" or memory_type == "SEMANTIC":
|
||||
# 实体/语义记忆:平衡各项
|
||||
weights = {
|
||||
"similarity": 0.50, # 语义相似度 50%
|
||||
"importance": 0.25, # 重要性 25%
|
||||
"similarity": 0.60, # 语义相似度 60% ⬆️
|
||||
"importance": 0.30, # 重要性 30% ⬆️
|
||||
"recency": 0.10, # 时效性 10%
|
||||
"activation": 0.15 # 激活度 15%
|
||||
}
|
||||
else:
|
||||
# 默认权重(保守策略,偏向语义)
|
||||
weights = {
|
||||
"similarity": 0.55, # 语义相似度 55%
|
||||
"importance": 0.20, # 重要性 20%
|
||||
"similarity": 0.65, # 语义相似度 65% ⬆️
|
||||
"importance": 0.25, # 重要性 25% ⬆️
|
||||
"recency": 0.10, # 时效性 10%
|
||||
"activation": 0.15 # 激活度 15%
|
||||
}
|
||||
|
||||
# 综合分数计算
|
||||
# 综合分数计算(🔥 移除激活度影响)
|
||||
final_score = (
|
||||
similarity_score * weights["similarity"] +
|
||||
importance_score * weights["importance"] +
|
||||
recency_score * weights["recency"] +
|
||||
activation_score * weights["activation"]
|
||||
recency_score * weights["recency"]
|
||||
)
|
||||
|
||||
# 🆕 节点类型加权:对REFERENCE/ATTRIBUTE节点额外加分(促进事实性信息召回)
|
||||
@@ -943,11 +954,16 @@ class MemoryTools:
|
||||
logger.warning("嵌入生成失败,跳过节点搜索")
|
||||
return []
|
||||
|
||||
# 向量搜索
|
||||
# 向量搜索(增加返回数量以提高召回率)
|
||||
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||
query_embedding=query_embedding,
|
||||
limit=top_k * 2, # 多取一些,后续过滤
|
||||
limit=top_k * 5, # 🔥 从2倍提升到5倍,提高初始召回率
|
||||
min_similarity=0.0, # 不在这里过滤,交给后续评分
|
||||
)
|
||||
|
||||
logger.debug(f"单查询向量搜索: 查询='{query}', 返回节点数={len(similar_nodes)}")
|
||||
if similar_nodes:
|
||||
logger.debug(f"Top 3相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:3]]}")
|
||||
|
||||
return similar_nodes
|
||||
|
||||
@@ -1003,11 +1019,13 @@ class MemoryTools:
|
||||
similar_nodes = await self.vector_store.search_with_multiple_queries(
|
||||
query_embeddings=query_embeddings,
|
||||
query_weights=query_weights,
|
||||
limit=top_k * 2, # 多取一些,后续过滤
|
||||
limit=top_k * 5, # 🔥 从2倍提升到5倍,提高初始召回率
|
||||
fusion_strategy="weighted_max",
|
||||
)
|
||||
|
||||
logger.info(f"多查询检索完成: {len(similar_nodes)} 个节点 (偏好类型: {prefer_node_types})")
|
||||
if similar_nodes:
|
||||
logger.debug(f"Top 5融合相似度: {[f'{sim:.3f}' for _, sim, _ in similar_nodes[:5]]}")
|
||||
|
||||
return similar_nodes, prefer_node_types
|
||||
|
||||
|
||||
Reference in New Issue
Block a user