feat: Enhance embedding generation and logging for memory nodes

- Added embedding generation for SUBJECT and VALUE node types in MemoryBuilder, creating embeddings only for nodes with sufficient content.
- Improved MemoryTools logging to provide detailed insight during the initial vector search, including warnings for low-recall cases.
- Adjusted the scoring weights for the different memory types to emphasize similarity and importance, improving retrieval quality.
- Raised the vector-search limit from 2x to 5x to improve initial recall.
- Introduced a new script that generates missing embeddings for existing nodes, with batch processing and improved indexing.
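
The retrieval changes described above (wider initial recall plus low-recall warnings) live in files whose diffs are not shown below. A rough, hypothetical sketch of the idea only (fetch_candidates, search_fn and query_embedding are made-up names, not the project's API):

# Hypothetical sketch only -- not the actual MemoryTools code; `search_fn` stands in
# for whatever performs the vector-store query.
import logging

logger = logging.getLogger(__name__)

def fetch_candidates(search_fn, query_embedding, top_k: int):
    # Was top_k * 2; a wider candidate pool improves initial recall before re-scoring.
    candidates = search_fn(query_embedding, limit=top_k * 5)
    if len(candidates) < top_k:
        logger.warning("Low recall: %d candidates for top_k=%d", len(candidates), top_k)
    return candidates
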
@@ -7,9 +7,10 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Optional
+from collections import defaultdict
 
 import orjson
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, HTTPException, Request, Query
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.templating import Jinja2Templates
 
@@ -227,6 +228,242 @@ async def get_full_graph():
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


@router.get("/api/graph/summary")
async def get_graph_summary():
    """Get summary information for the graph (statistics only, no nodes or edges)"""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()

        if memory_manager and memory_manager._initialized:
            stats = memory_manager.get_statistics()
            return JSONResponse(content={"success": True, "data": {
                "stats": {
                    "total_nodes": stats.get("total_nodes", 0),
                    "total_edges": stats.get("total_edges", 0),
                    "total_memories": stats.get("total_memories", 0),
                },
                "current_file": "memory_manager (实时数据)",
            }})
        else:
            data = load_graph_data_from_file()
            return JSONResponse(content={"success": True, "data": {
                "stats": data.get("stats", {}),
                "current_file": data.get("current_file", ""),
            }})
    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


@router.get("/api/graph/paginated")
async def get_paginated_graph(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
    min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
    node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
):
    """Fetch graph data page by page, with optional importance filtering"""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()

        # Fetch the full data set
        if memory_manager and memory_manager._initialized:
            full_data = _format_graph_data_from_manager(memory_manager)
        else:
            full_data = load_graph_data_from_file()

        nodes = full_data.get("nodes", [])
        edges = full_data.get("edges", [])

        # Filter by node type
        if node_types:
            allowed_types = set(node_types.split(","))
            nodes = [n for n in nodes if n.get("group") in allowed_types]

        # Rank nodes by importance
        nodes_with_importance = []
        for node in nodes:
            # Node importance = number of connected edges, normalized by the total edge count
            edge_count = sum(1 for e in edges if e.get("from") == node["id"] or e.get("to") == node["id"])
            importance = edge_count / max(len(edges), 1)
            if importance >= min_importance:
                node["importance"] = importance
                nodes_with_importance.append(node)

        # Sort by importance, descending
        nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)

        # Paginate
        total_nodes = len(nodes_with_importance)
        total_pages = (total_nodes + page_size - 1) // page_size
        start_idx = (page - 1) * page_size
        end_idx = min(start_idx + page_size, total_nodes)

        paginated_nodes = nodes_with_importance[start_idx:end_idx]
        node_ids = set(n["id"] for n in paginated_nodes)

        # Keep only edges that connect nodes on this page
        paginated_edges = [
            e for e in edges
            if e.get("from") in node_ids and e.get("to") in node_ids
        ]

        return JSONResponse(content={"success": True, "data": {
            "nodes": paginated_nodes,
            "edges": paginated_edges,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_nodes": total_nodes,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_prev": page > 1,
            },
            "stats": {
                "total_nodes": total_nodes,
                "total_edges": len(paginated_edges),
                "total_memories": full_data.get("stats", {}).get("total_memories", 0),
            },
        }})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
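
A client-side usage sketch for the pagination endpoint above; the host, port and mount prefix are assumptions, while the query parameters mirror the route's signature:

import requests  # any HTTP client works; requests is used here for brevity

resp = requests.get(
    "http://localhost:8000/api/graph/paginated",  # assumed host/port and prefix
    params={"page": 1, "page_size": 500, "min_importance": 0.1, "node_types": "SUBJECT,VALUE"},
)
payload = resp.json()["data"]
print(payload["pagination"], len(payload["nodes"]), len(payload["edges"]))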


@router.get("/api/graph/clustered")
async def get_clustered_graph(
    max_nodes: int = Query(300, ge=50, le=1000, description="最大节点数"),
    cluster_threshold: int = Query(10, ge=2, le=50, description="聚类阈值")
):
    """Fetch graph data simplified by clustering"""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        memory_manager = get_memory_manager()

        # Fetch the full data set
        if memory_manager and memory_manager._initialized:
            full_data = _format_graph_data_from_manager(memory_manager)
        else:
            full_data = load_graph_data_from_file()

        nodes = full_data.get("nodes", [])
        edges = full_data.get("edges", [])

        # If the node count is already within max_nodes, return the graph unclustered
        if len(nodes) <= max_nodes:
            return JSONResponse(content={"success": True, "data": {
                "nodes": nodes,
                "edges": edges,
                "stats": full_data.get("stats", {}),
                "clustered": False,
            }})

        # Run clustering
        clustered_data = _cluster_graph_data(nodes, edges, max_nodes, cluster_threshold)

        return JSONResponse(content={"success": True, "data": {
            **clustered_data,
            "stats": {
                "original_nodes": len(nodes),
                "original_edges": len(edges),
                "clustered_nodes": len(clustered_data["nodes"]),
                "clustered_edges": len(clustered_data["edges"]),
                "total_memories": full_data.get("stats", {}).get("total_memories", 0),
            },
            "clustered": True,
        }})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
    """Simple graph clustering: group nodes by type and connectivity"""
    # Build the adjacency table
    adjacency = defaultdict(set)
    for edge in edges:
        adjacency[edge["from"]].add(edge["to"])
        adjacency[edge["to"]].add(edge["from"])

    # Group nodes by type
    type_groups = defaultdict(list)
    for node in nodes:
        type_groups[node.get("group", "UNKNOWN")].append(node)

    clustered_nodes = []
    clustered_edges = []
    node_mapping = {}  # original node ID -> cluster node ID

    for node_type, type_nodes in type_groups.items():
        # If this type has no more nodes than the threshold, keep them all
        if len(type_nodes) <= cluster_threshold:
            for node in type_nodes:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]
        else:
            # Sort by degree and keep the most connected nodes
            node_importance = []
            for node in type_nodes:
                importance = len(adjacency[node["id"]])
                node_importance.append((node, importance))

            node_importance.sort(key=lambda x: x[1], reverse=True)

            # Keep the top N most important nodes
            keep_count = min(len(type_nodes), max_nodes // len(type_groups))
            for node, importance in node_importance[:keep_count]:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]

            # Collapse the remaining nodes into a single super node
            if len(node_importance) > keep_count:
                clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
                cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
                cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"

                clustered_nodes.append({
                    "id": cluster_id,
                    "label": cluster_label,
                    "group": node_type,
                    "title": f"包含 {len(clustered_node_ids)} 个{node_type}节点",
                    "is_cluster": True,
                    "cluster_size": len(clustered_node_ids),
                    "clustered_nodes": clustered_node_ids[:10],  # keep only the first 10 for display
                })

                for node_id in clustered_node_ids:
                    node_mapping[node_id] = cluster_id

    # Rebuild the edges (deduplicated)
    edge_set = set()
    for edge in edges:
        from_id = node_mapping.get(edge["from"])
        to_id = node_mapping.get(edge["to"])

        if from_id and to_id and from_id != to_id:
            edge_key = tuple(sorted([from_id, to_id]))
            if edge_key not in edge_set:
                edge_set.add(edge_key)
                clustered_edges.append({
                    "id": f"{from_id}_{to_id}",
                    "from": from_id,
                    "to": to_id,
                    "label": edge.get("label", ""),
                    "arrows": "to",
                })

    return {
        "nodes": clustered_nodes,
        "edges": clustered_edges,
    }
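
A toy call to the helper above as a sanity check; the data is made up, but the node and edge fields match what the function reads:

# 30 chained nodes of one type: with max_nodes=5 and a single type group,
# keep_count = 5 // 1 = 5, so 5 nodes are kept and the other 25 collapse
# into one super node (6 nodes total in the result).
toy_nodes = [{"id": f"n{i}", "group": "SUBJECT", "label": f"n{i}"} for i in range(30)]
toy_edges = [{"from": f"n{i}", "to": f"n{i + 1}", "label": ""} for i in range(29)]
result = _cluster_graph_data(toy_nodes, toy_edges, max_nodes=5, cluster_threshold=10)
assert len(result["nodes"]) == 6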


@router.get("/api/files")
async def list_files_api():
    """List all available data files"""

[File diff suppressed because it is too large]