feat: improve embedding generation and logging for memory nodes

- Added embedding generation for SUBJECT and VALUE node types in MemoryBuilder, ensuring embeddings are only created for nodes with sufficient content.
- Improved MemoryTools logging to give detailed insight during the initial vector search, including warnings for low-recall situations.
- Adjusted scoring weights for the different memory types to emphasise similarity and importance, improving retrieval quality.
- Raised the vector search limit from 2x to 5x to improve initial recall.
- Introduced a new script that generates missing embeddings for existing nodes, with batch processing and improved indexing.

Windpicker-owo
2025-11-11 19:25:03 +08:00
parent 28c0f764ea
commit e2236f5bc1
5 changed files with 1296 additions and 189 deletions
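
As context for the retrieval-side changes described in the commit message (the 2x to 5x initial vector search limit and the low-recall warnings), here is a minimal sketch of what that logic typically looks like. It is illustrative only: `vector_store`, `search_memories`, and `RECALL_MULTIPLIER` are hypothetical names, not this repository's actual API.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical constant: fetch 5x the requested count during the initial
# vector search (previously 2x), so later re-ranking has more candidates.
RECALL_MULTIPLIER = 5


def search_memories(vector_store, query_embedding, top_k: int = 10):
    """Sketch of an over-fetching vector search with low-recall logging.

    `vector_store.search` is an assumed interface, not an API of this repository.
    """
    fetch_limit = top_k * RECALL_MULTIPLIER
    candidates = vector_store.search(query_embedding, limit=fetch_limit)

    logger.debug("initial vector search: requested=%d, returned=%d", fetch_limit, len(candidates))
    if len(candidates) < top_k:
        # Low-recall situation: fewer hits than the caller asked for.
        logger.warning("low recall: %d candidates for top_k=%d", len(candidates), top_k)

    # Scoring by similarity and importance would happen here before truncating.
    return candidates[:top_k]
```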


@@ -7,9 +7,10 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from collections import defaultdict
import orjson
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, HTTPException, Request, Query
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
@@ -227,6 +228,242 @@ async def get_full_graph():
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


@router.get("/api/graph/summary")
async def get_graph_summary():
    """Return summary information for the graph (statistics only, without nodes or edges)."""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()

        if memory_manager and memory_manager._initialized:
            stats = memory_manager.get_statistics()
            return JSONResponse(content={"success": True, "data": {
                "stats": {
                    "total_nodes": stats.get("total_nodes", 0),
                    "total_edges": stats.get("total_edges", 0),
                    "total_memories": stats.get("total_memories", 0),
                },
                "current_file": "memory_manager (live data)",
            }})
        else:
            data = load_graph_data_from_file()
            return JSONResponse(content={"success": True, "data": {
                "stats": data.get("stats", {}),
                "current_file": data.get("current_file", ""),
            }})

    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/graph/paginated")
async def get_paginated_graph(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
):
"""分页获取图数据,支持重要性过滤"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
# 获取完整数据
if memory_manager and memory_manager._initialized:
full_data = _format_graph_data_from_manager(memory_manager)
else:
full_data = load_graph_data_from_file()
nodes = full_data.get("nodes", [])
edges = full_data.get("edges", [])
# 过滤节点类型
if node_types:
allowed_types = set(node_types.split(","))
nodes = [n for n in nodes if n.get("group") in allowed_types]
# 按重要性排序如果有importance字段
nodes_with_importance = []
for node in nodes:
# 计算节点重要性(连接的边数)
edge_count = sum(1 for e in edges if e.get("from") == node["id"] or e.get("to") == node["id"])
importance = edge_count / max(len(edges), 1)
if importance >= min_importance:
node["importance"] = importance
nodes_with_importance.append(node)
# 按重要性降序排序
nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
# 分页
total_nodes = len(nodes_with_importance)
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
return JSONResponse(content={"success": True, "data": {
"nodes": paginated_nodes,
"edges": paginated_edges,
"pagination": {
"page": page,
"page_size": page_size,
"total_nodes": total_nodes,
"total_pages": total_pages,
"has_next": page < total_pages,
"has_prev": page > 1,
},
"stats": {
"total_nodes": total_nodes,
"total_edges": len(paginated_edges),
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
},
}})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
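
For illustration only (not part of the diff above), a client call to the new paginated endpoint might look like the following; the base URL and the use of `requests` are assumptions:

```python
import requests  # any HTTP client works; requests is just an example

BASE_URL = "http://127.0.0.1:8000"  # hypothetical local instance

resp = requests.get(
    f"{BASE_URL}/api/graph/paginated",
    params={
        "page": 1,
        "page_size": 500,               # accepted range: 100-2000
        "min_importance": 0.05,         # drop weakly connected nodes
        "node_types": "SUBJECT,VALUE",  # comma-separated type filter
    },
    timeout=10,
)
data = resp.json()["data"]
print(data["pagination"])  # page, total_pages, has_next, has_prev
print(len(data["nodes"]), len(data["edges"]))
```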
@router.get("/api/graph/clustered")
async def get_clustered_graph(
max_nodes: int = Query(300, ge=50, le=1000, description="最大节点数"),
cluster_threshold: int = Query(10, ge=2, le=50, description="聚类阈值")
):
"""获取聚类简化后的图数据"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
# 获取完整数据
if memory_manager and memory_manager._initialized:
full_data = _format_graph_data_from_manager(memory_manager)
else:
full_data = load_graph_data_from_file()
nodes = full_data.get("nodes", [])
edges = full_data.get("edges", [])
# 如果节点数小于阈值,直接返回
if len(nodes) <= max_nodes:
return JSONResponse(content={"success": True, "data": {
"nodes": nodes,
"edges": edges,
"stats": full_data.get("stats", {}),
"clustered": False,
}})
# 执行聚类
clustered_data = _cluster_graph_data(nodes, edges, max_nodes, cluster_threshold)
return JSONResponse(content={"success": True, "data": {
**clustered_data,
"stats": {
"original_nodes": len(nodes),
"original_edges": len(edges),
"clustered_nodes": len(clustered_data["nodes"]),
"clustered_edges": len(clustered_data["edges"]),
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
},
"clustered": True,
}})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
    """Simple graph clustering: group nodes by type and connectivity."""
    # Build the adjacency table
    adjacency = defaultdict(set)
    for edge in edges:
        adjacency[edge["from"]].add(edge["to"])
        adjacency[edge["to"]].add(edge["from"])

    # Group nodes by type
    type_groups = defaultdict(list)
    for node in nodes:
        type_groups[node.get("group", "UNKNOWN")].append(node)

    clustered_nodes = []
    clustered_edges = []
    node_mapping = {}  # original node id -> clustered node id

    for node_type, type_nodes in type_groups.items():
        # If this type has no more nodes than the threshold, keep them all
        if len(type_nodes) <= cluster_threshold:
            for node in type_nodes:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]
        else:
            # Rank nodes of this type by degree and keep the most important ones
            node_importance = []
            for node in type_nodes:
                importance = len(adjacency[node["id"]])
                node_importance.append((node, importance))
            node_importance.sort(key=lambda x: x[1], reverse=True)

            # Keep the top N nodes
            keep_count = min(len(type_nodes), max_nodes // len(type_groups))
            for node, importance in node_importance[:keep_count]:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]

            # Collapse the remaining nodes into a single super node
            if len(node_importance) > keep_count:
                clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
                cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
                cluster_label = f"{node_type} cluster ({len(clustered_node_ids)} nodes)"

                clustered_nodes.append({
                    "id": cluster_id,
                    "label": cluster_label,
                    "group": node_type,
                    "title": f"Contains {len(clustered_node_ids)} {node_type} nodes",
                    "is_cluster": True,
                    "cluster_size": len(clustered_node_ids),
                    "clustered_nodes": clustered_node_ids[:10],  # keep only the first 10 for display
                })

                for node_id in clustered_node_ids:
                    node_mapping[node_id] = cluster_id

    # Rebuild edges against the cluster mapping (deduplicated)
    edge_set = set()
    for edge in edges:
        from_id = node_mapping.get(edge["from"])
        to_id = node_mapping.get(edge["to"])

        if from_id and to_id and from_id != to_id:
            edge_key = tuple(sorted([from_id, to_id]))
            if edge_key not in edge_set:
                edge_set.add(edge_key)
                clustered_edges.append({
                    "id": f"{from_id}_{to_id}",
                    "from": from_id,
                    "to": to_id,
                    "label": edge.get("label", ""),
                    "arrows": "to",
                })

    return {
        "nodes": clustered_nodes,
        "edges": clustered_edges,
    }
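
A self-contained sketch of how `_cluster_graph_data` behaves on toy data (illustrative only, not part of the diff; the node and edge dicts use the same shape as above, and the helper is assumed to be importable from this module):

```python
# Toy graph: four SUBJECT nodes, two VALUE nodes, a handful of edges (made-up data).
nodes = (
    [{"id": f"s{i}", "label": f"subject {i}", "group": "SUBJECT"} for i in range(4)]
    + [{"id": f"v{i}", "label": f"value {i}", "group": "VALUE"} for i in range(2)]
)
edges = [
    {"from": "s0", "to": "s1", "label": "knows"},
    {"from": "s0", "to": "v0", "label": "has"},
    {"from": "s2", "to": "v1", "label": "has"},
]

# With cluster_threshold=2 the four SUBJECT nodes exceed the threshold, so only the
# best-connected SUBJECT nodes survive and the rest collapse into one
# "cluster_SUBJECT_*" super node; the two VALUE nodes are kept as-is.
result = _cluster_graph_data(nodes, edges, max_nodes=3, cluster_threshold=2)

print([n["id"] for n in result["nodes"]])
print([(e["from"], e["to"]) for e in result["edges"]])
```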
@router.get("/api/files")
async def list_files_api():
"""列出所有可用的数据文件"""

File diff suppressed because it is too large