Files
Mofox-Core/src/api/memory_visualizer_router.py
tt-P607 659a8e0d78 refactor(api, chat): 改进异步处理并解决并发问题
记忆图可视化器 API 端点之前在异步路由中执行同步的阻塞操作(文件 I/O、数据处理)。在处理大型图文件时,这可能导致服务器冻结。现在,这些任务已被移至 ThreadPoolExecutor,API 不再阻塞事件循环,服务器的响应能力显著提升。

在聊天消息管理器中,竞争条件可能导致消息处理重叠或中断后数据流停滞。此提交引入了:
- 并发锁(`is_chatter_processing`)以防止流循环同时运行多个 chatter 实例。
- 故障保护机制,在中断时重置处理状态,确保数据流能够恢复并正确继续。
2025-12-02 14:40:58 +08:00

662 lines
23 KiB
Python

"""
记忆图可视化 - API 路由模块
提供 Web API 用于可视化记忆图数据
"""
import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any
import orjson
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
# Project root: three directory levels above this file (the module lives in src/api/).
# NOTE(review): Path(__file__) is not resolved, so this may be a relative path.
project_root = Path(__file__).parent.parent.parent
# Directory scanned for memory-graph JSON files (file-based fallback data source).
data_dir = project_root / "data" / "memory_graph"
# Cache: processed graph data for `current_data_file` (None = not loaded yet).
graph_data_cache: dict[str, Any] | None = None
current_data_file: Path | None = None
# Small thread pool used to run blocking file reads and data processing off the event loop.
_executor = ThreadPoolExecutor(max_workers=2)
# FastAPI router exposing the visualizer page and its JSON API.
router = APIRouter()
# Jinja2 template engine; templates are read from the `templates` directory next to this module.
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
return files
possible_files = ["graph_store.json", "memory_graph.json", "graph_data.json"]
for filename in possible_files:
file_path = data_dir / filename
if file_path.exists():
files.append(file_path)
for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]:
for backup_file in data_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
backups_dir = data_dir / "backups"
if backups_dir.exists():
for backup_file in backups_dir.glob("**/*.json"):
if backup_file not in files:
files.append(backup_file)
backup_dir = data_dir.parent / "backup"
if backup_dir.exists():
for pattern in ["**/graph_*.json", "**/memory_*.json"]:
for backup_file in backup_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
    """Load and process the memory graph from disk without blocking the event loop.

    Args:
        file_path: File to load. If it differs from the currently selected file,
            the cache is invalidated and it becomes the current file. If ``None``,
            the current file is used, or the newest available data file when no
            file has been selected yet.

    Returns:
        The processed graph dict (cached after a successful load), or a dict with
        an ``"error"`` key and empty ``nodes``/``edges``/``stats`` when no data
        file can be found or the selected file is missing.

    Raises:
        HTTPException: status 500 if reading or processing the file fails.
    """
    global graph_data_cache, current_data_file
    if file_path and file_path != current_data_file:
        graph_data_cache = None
        current_data_file = file_path
    if graph_data_cache:
        return graph_data_cache
    try:
        loop = asyncio.get_running_loop()
        graph_file = current_data_file
        if not graph_file:
            # Directory scanning and stat() are blocking I/O too; keep them off the loop.
            available_files = await loop.run_in_executor(_executor, find_available_data_files)
            if not available_files:
                return {"error": "未找到数据文件", "nodes": [], "edges": [], "stats": {}}
            graph_file = available_files[0]
            # Don't clobber a file another request selected while we were scanning.
            if current_data_file is None:
                current_data_file = graph_file
        if not graph_file.exists():
            return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
        # Read and parse in the thread pool so large files don't freeze the event loop.
        data = await loop.run_in_executor(_executor, _sync_load_json_file, graph_file)
        nodes = data.get("nodes", [])
        edges = data.get("edges", [])
        metadata = data.get("metadata", {})
        # Data transformation is also done in the thread pool.
        processed = await loop.run_in_executor(
            _executor, _process_graph_data, nodes, edges, metadata, graph_file
        )
        # Another request (e.g. /select_file) may have switched files while we were
        # awaiting; only publish to the shared cache if this file is still current.
        if current_data_file == graph_file:
            graph_data_cache = processed
        return processed
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}") from e
def _sync_load_json_file(file_path: Path) -> dict:
    """Read and parse a JSON file synchronously (meant to run in the thread pool).

    The raw bytes go straight to ``orjson.loads``, which accepts ``bytes`` and
    validates UTF-8 itself. Decoding to ``str`` first only made orjson re-encode
    the text, doubling work and peak memory for large graph files.

    Raises:
        OSError: if the file cannot be read.
        orjson.JSONDecodeError: (a ``ValueError`` subclass) if the content is not
            valid UTF-8 JSON.
    """
    return orjson.loads(file_path.read_bytes())
def _process_graph_data(nodes: list, edges: list, metadata: dict, graph_file: Path) -> dict:
    """Convert raw graph-file nodes/edges into the visualizer format (thread-pool worker).

    Args:
        nodes: Raw node dicts. Entries without a truthy ``id`` are skipped; a later
            node with the same id replaces an earlier one.
        edges: Raw edge dicts. Entries without a truthy ``id`` are skipped, and only
            the first edge for each id is kept. Endpoints are read from
            ``source``/``target``, falling back to ``source_id``/``target_id``.
        metadata: The file's ``metadata`` section. ``statistics.total_memories`` is
            read from it. ``None`` or a null ``statistics`` entry counts as empty.
        graph_file: Path of the loaded file, used for size and mtime reporting.

    Returns:
        Dict with ``nodes``, ``edges``, ``memories`` (always empty for file data),
        ``stats``, ``current_file``, ``file_size`` and ``file_modified``
        (ISO-8601, naive local time).
    """
    nodes_dict: dict = {}
    for node in nodes:
        if not node.get("id"):
            continue
        node_type = node.get("node_type", "")
        content = node.get("content", "")
        entry = {
            **node,
            "label": content,
            "group": node_type,
            "title": f"{node_type}: {content}",
        }
        # The live memory-manager format exposes "type"; mirror it for file data so
        # type-based statistics work, without overriding a "type" already present.
        if "type" not in entry and "node_type" in node:
            entry["type"] = node["node_type"]
        nodes_dict[node["id"]] = entry

    edges_list = []
    seen_edge_ids = set()
    for edge in edges:
        edge_id = edge.get("id")
        if not edge_id or edge_id in seen_edge_ids:
            continue
        seen_edge_ids.add(edge_id)
        edges_list.append(
            {
                **edge,
                "from": edge.get("source", edge.get("source_id")),
                "to": edge.get("target", edge.get("target_id")),
                "label": edge.get("relation", ""),
                "arrows": "to",
            }
        )

    # Tolerate "metadata": null / "statistics": null in the file.
    stats = (metadata or {}).get("statistics") or {}
    total_memories = stats.get("total_memories", 0)
    # One stat() call, so size and mtime describe the same file state.
    file_stat = graph_file.stat()
    return {
        "nodes": list(nodes_dict.values()),
        "edges": edges_list,
        "memories": [],
        "stats": {
            "total_nodes": len(nodes_dict),
            "total_edges": len(edges_list),
            "total_memories": total_memories,
        },
        "current_file": str(graph_file),
        "file_size": file_stat.st_size,
        "file_modified": datetime.fromtimestamp(file_stat.st_mtime).isoformat(),
    }
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the visualizer's single HTML page."""
    page_context = {"request": request}
    return templates.TemplateResponse("visualizer.html", page_context)
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
    """Build visualizer graph data from a live MemoryManager.

    Synchronous; callers run it in the thread pool. Nodes and edges are
    de-duplicated by id, and the first occurrence wins. Each edge records the id of
    the first memory it was seen in. The totals in ``stats`` come from
    ``memory_manager.get_statistics()``, not from the returned lists.
    """
    store = memory_manager.graph_store
    if not store:
        return {"nodes": [], "edges": [], "memories": [], "stats": {}}

    seen_nodes: dict = {}
    seen_edges: dict = {}
    memories: list = []
    for mem in store.get_all_memories():
        memories.append({
            "id": mem.id,
            "type": mem.memory_type.value,
            "importance": mem.importance,
            "text": mem.to_text(),
        })
        for n in mem.nodes:
            if n.id in seen_nodes:
                continue
            seen_nodes[n.id] = {
                "id": n.id,
                "label": n.content,
                "type": n.node_type.value,
                "group": n.node_type.name,
                "title": f"{n.node_type.value}: {n.content}",
            }
        for e in mem.edges:
            if e.id in seen_edges:
                continue
            seen_edges[e.id] = {
                "id": e.id,
                "from": e.source_id,
                "to": e.target_id,
                "label": e.relation,
                "arrows": "to",
                "memory_id": mem.id,
            }

    totals = memory_manager.get_statistics()
    summary = {
        "total_nodes": totals.get("total_nodes", 0),
        "total_edges": totals.get("total_edges", 0),
        "total_memories": totals.get("total_memories", 0),
    }
    return {
        "nodes": list(seen_nodes.values()),
        "edges": list(seen_edges.values()),
        "memories": memories,
        "stats": summary,
        "current_file": "memory_manager (实时数据)",
    }
@router.get("/api/graph/full")
async def get_full_graph():
    """Return the complete memory graph.

    Uses the live memory manager when it is initialized; otherwise it falls back to
    the (cached) on-disk graph file. Any failure is reported as
    ``{"success": False, "error": ...}`` with HTTP 500.
    """
    try:
        from src.memory_graph.manager_singleton import get_memory_manager
        memory_manager = get_memory_manager()
        if memory_manager and memory_manager._initialized:
            # Formatting walks every memory; run it in the pool so the loop stays responsive.
            loop = asyncio.get_running_loop()
            data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
        else:
            # Memory manager unavailable: load from the data file instead.
            data = await load_graph_data_from_file()
        return JSONResponse(content={"success": True, "data": data})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/graph/summary")
async def get_graph_summary():
    """Return only the graph statistics and data source (no nodes or edges)."""
    try:
        from src.memory_graph.manager_singleton import get_memory_manager

        manager = get_memory_manager()
        if not (manager and manager._initialized):
            file_data = await load_graph_data_from_file()
            summary = {
                "stats": file_data.get("stats", {}),
                "current_file": file_data.get("current_file", ""),
            }
        else:
            live = manager.get_statistics()
            summary = {
                "stats": {
                    "total_nodes": live.get("total_nodes", 0),
                    "total_edges": live.get("total_edges", 0),
                    "total_memories": live.get("total_memories", 0),
                },
                "current_file": "memory_manager (实时数据)",
            }
        return JSONResponse(content={"success": True, "data": summary})
    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/graph/paginated")
async def get_paginated_graph(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
    min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
    node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
):
    """Return one page of graph nodes ranked by importance, with optional filters.

    Args:
        page: 1-based page number.
        page_size: Nodes per page (100-2000).
        min_importance: Nodes whose importance is below this threshold are dropped.
        node_types: Optional comma-separated list of node groups to keep.

    Returns:
        ``{"success": True, "data": <page>}``, or ``{"success": False, "error": ...}``
        with HTTP 500 on failure.
    """
    try:
        from src.memory_graph.manager_singleton import get_memory_manager
        memory_manager = get_memory_manager()
        loop = asyncio.get_running_loop()
        if memory_manager and memory_manager._initialized:
            full_data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
        else:
            full_data = await load_graph_data_from_file()
        # Filtering and sorting can be expensive on large graphs; keep them off the loop.
        result = await loop.run_in_executor(
            _executor,
            _process_pagination,
            full_data, page, page_size, min_importance, node_types
        )
        return JSONResponse(content={"success": True, "data": result})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
def _process_pagination(full_data: dict, page: int, page_size: int, min_importance: float, node_types: str | None) -> dict:
"""处理分页逻辑(在线程池中执行)"""
nodes = full_data.get("nodes", [])
edges = full_data.get("edges", [])
# 过滤节点类型
if node_types:
allowed_types = set(node_types.split(","))
nodes = [n for n in nodes if n.get("group") in allowed_types]
# 构建边的索引以加速查找
edge_count_map = {}
for e in edges:
from_id = e.get("from")
to_id = e.get("to")
edge_count_map[from_id] = edge_count_map.get(from_id, 0) + 1
edge_count_map[to_id] = edge_count_map.get(to_id, 0) + 1
# 按重要性排序
nodes_with_importance = []
total_edges = max(len(edges), 1)
for node in nodes:
edge_count = edge_count_map.get(node["id"], 0)
importance = edge_count / total_edges
if importance >= min_importance:
node["importance"] = importance
nodes_with_importance.append(node)
# 按重要性降序排序
nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
# 分页
total_nodes = len(nodes_with_importance)
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
return {
"nodes": paginated_nodes,
"edges": paginated_edges,
"pagination": {
"page": page,
"page_size": page_size,
"total_nodes": total_nodes,
"total_pages": total_pages,
"has_next": page < total_pages,
"has_prev": page > 1,
},
"stats": {
"total_nodes": total_nodes,
"total_edges": len(paginated_edges),
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
},
}
@router.get("/api/graph/clustered")
async def get_clustered_graph(
    max_nodes: int = Query(300, ge=50, le=1000, description="最大节点数"),
    cluster_threshold: int = Query(10, ge=2, le=50, description="聚类阈值")
):
    """Return the graph, simplified by clustering when it has more than ``max_nodes`` nodes.

    Args:
        max_nodes: Node budget above which clustering is applied.
        cluster_threshold: Node types with at most this many nodes are never clustered.

    Returns:
        ``{"success": True, "data": {...}}``, where ``data["clustered"]`` says whether
        clustering was applied; on failure ``{"success": False, "error": ...}`` with
        HTTP 500.
    """
    try:
        from src.memory_graph.manager_singleton import get_memory_manager
        memory_manager = get_memory_manager()
        loop = asyncio.get_running_loop()
        if memory_manager and memory_manager._initialized:
            full_data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
        else:
            full_data = await load_graph_data_from_file()
        nodes = full_data.get("nodes", [])
        edges = full_data.get("edges", [])
        # Small enough to render directly: return the graph unchanged.
        if len(nodes) <= max_nodes:
            return JSONResponse(content={"success": True, "data": {
                "nodes": nodes,
                "edges": edges,
                "stats": full_data.get("stats", {}),
                "clustered": False,
            }})
        # Clustering is CPU-bound; run it in the thread pool.
        clustered_data = await loop.run_in_executor(
            _executor, _cluster_graph_data, nodes, edges, max_nodes, cluster_threshold
        )
        return JSONResponse(content={"success": True, "data": {
            **clustered_data,
            "stats": {
                "original_nodes": len(nodes),
                "original_edges": len(edges),
                "clustered_nodes": len(clustered_data["nodes"]),
                "clustered_edges": len(clustered_data["edges"]),
                "total_memories": full_data.get("stats", {}).get("total_memories", 0),
            },
            "clustered": True,
        }})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
    """Simplify a large graph by collapsing weakly connected nodes of each type into cluster nodes.

    Nodes are grouped by their ``group`` value (``"UNKNOWN"`` when absent). A group
    with at most ``cluster_threshold`` nodes is kept unchanged. For a larger group,
    the ``max_nodes // <number of groups>`` nodes with the most distinct neighbours
    are kept, and all remaining nodes of that group are replaced by one cluster node
    (``is_cluster=True``) that lists up to 10 of the merged ids.

    Edges are then remapped onto the kept/cluster nodes. Edges that become
    self-loops are dropped, as are edges with an unmapped (or falsy) endpoint.
    Parallel edges are de-duplicated regardless of direction; the first one wins.

    Note: fully kept groups and the added cluster nodes are not deducted from the
    per-group budget, so the result can contain more than ``max_nodes`` nodes.

    Returns:
        ``{"nodes": [...], "edges": [...]}``.
    """
    # Build an undirected adjacency map: node id -> set of distinct neighbour ids.
    # NOTE(review): edge["from"] / edge["to"] raise KeyError if an edge lacks them.
    adjacency = defaultdict(set)
    for edge in edges:
        adjacency[edge["from"]].add(edge["to"])
        adjacency[edge["to"]].add(edge["from"])
    # Group nodes by their "group" (type); nodes without one go to "UNKNOWN".
    type_groups = defaultdict(list)
    for node in nodes:
        type_groups[node.get("group", "UNKNOWN")].append(node)
    clustered_nodes = []
    clustered_edges = []
    node_mapping = {}  # original node id -> id of the node representing it in the output
    for node_type, type_nodes in type_groups.items():
        # Groups no larger than the threshold are kept unchanged.
        if len(type_nodes) <= cluster_threshold:
            for node in type_nodes:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]
        else:
            # Rank this group's nodes by neighbour count, best connected first.
            node_importance = []
            for node in type_nodes:
                importance = len(adjacency[node["id"]])
                node_importance.append((node, importance))
            node_importance.sort(key=lambda x: x[1], reverse=True)
            # Keep the top N; the node budget is split evenly across all type groups.
            keep_count = min(len(type_nodes), max_nodes // len(type_groups))
            for node, importance in node_importance[:keep_count]:
                clustered_nodes.append(node)
                node_mapping[node["id"]] = node["id"]
            # Merge all remaining nodes of this type into a single cluster node.
            if len(node_importance) > keep_count:
                clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
                # The numeric suffix is the number of nodes emitted so far.
                cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
                cluster_label = f"{node_type} 集群 ({len(clustered_node_ids)}个节点)"
                clustered_nodes.append({
                    "id": cluster_id,
                    "label": cluster_label,
                    "group": node_type,
                    "title": f"包含 {len(clustered_node_ids)}{node_type}节点",
                    "is_cluster": True,
                    "cluster_size": len(clustered_node_ids),
                    "clustered_nodes": clustered_node_ids[:10],  # keep only the first 10 ids for display
                })
                for node_id in clustered_node_ids:
                    node_mapping[node_id] = cluster_id
    # Rebuild edges on the mapped ids and de-duplicate them. The key ignores
    # direction, so a->b and b->a collapse into whichever is seen first.
    # NOTE(review): sorted() requires mutually comparable ids (e.g. all str).
    edge_set = set()
    for edge in edges:
        from_id = node_mapping.get(edge["from"])
        to_id = node_mapping.get(edge["to"])
        # Skip unmapped endpoints and edges that became internal to one cluster.
        if from_id and to_id and from_id != to_id:
            edge_key = tuple(sorted([from_id, to_id]))
            if edge_key not in edge_set:
                edge_set.add(edge_key)
                clustered_edges.append({
                    "id": f"{from_id}_{to_id}",
                    "from": from_id,
                    "to": to_id,
                    "label": edge.get("label", ""),
                    "arrows": "to",
                })
    return {
        "nodes": clustered_nodes,
        "edges": clustered_edges,
    }
@router.get("/api/files")
async def list_files_api():
    """List the available graph data files, newest first.

    The directory scan and per-file ``stat()`` run in the thread pool so they
    don't block the event loop. A file deleted between discovery and ``stat()``
    is omitted rather than failing the whole request.
    """
    current = current_data_file

    def _collect() -> list[dict[str, Any]]:
        # Thread-pool worker: scan the data directories and stat each file once.
        entries = []
        for f in find_available_data_files():
            try:
                stat = f.stat()
            except OSError:
                continue
            modified = datetime.fromtimestamp(stat.st_mtime)
            entries.append(
                {
                    "path": str(f),
                    "name": f.name,
                    "size": stat.st_size,
                    "size_kb": round(stat.st_size / 1024, 2),
                    "modified": modified.isoformat(),
                    "modified_readable": modified.strftime("%Y-%m-%d %H:%M:%S"),
                    "is_current": current is not None and str(f) == str(current),
                }
            )
        return entries

    try:
        file_list = await asyncio.get_running_loop().run_in_executor(_executor, _collect)
        return JSONResponse(
            content={
                "success": True,
                "files": file_list,
                "count": len(file_list),
                "current_file": str(current) if current else None,
            }
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.post("/select_file")
async def select_file(request: Request):
    """Select which data file the file-based fallback should load.

    Expects a JSON object body ``{"file_path": "<path>"}``. The path comes from the
    client, so it is only accepted if it resolves to one of the files returned by
    ``find_available_data_files()`` (the list ``/api/files`` exposes). Any other
    path is rejected instead of being read from the server's disk.

    Returns:
        200 ``{"success": True, "message": ..., "stats": ...}`` on success.
        Otherwise ``{"success": False, "error": ...}`` with one of these statuses:
        400 if the body is bad or ``file_path`` is missing; 403 if the path is not
        an available data file; 404 if the file vanished; 500 if loading fails.
    """
    global graph_data_cache, current_data_file

    def _match_available(requested: Path) -> Path | None:
        # Thread-pool worker: scanning and resolve() both touch the filesystem.
        target = requested.resolve()
        for candidate in find_available_data_files():
            if candidate.resolve() == target:
                return candidate
        return None

    try:
        try:
            data = await request.json()
        except ValueError:
            raise HTTPException(status_code=400, detail="请求体不是有效的 JSON") from None
        file_path = data.get("file_path") if isinstance(data, dict) else None
        if not file_path or not isinstance(file_path, str):
            raise HTTPException(status_code=400, detail="未提供文件路径")
        loop = asyncio.get_running_loop()
        file_to_load = await loop.run_in_executor(_executor, _match_available, Path(file_path))
        if file_to_load is None:
            raise HTTPException(status_code=403, detail=f"不允许加载该文件: {file_path}")
        if not file_to_load.exists():
            raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
        graph_data_cache = None
        current_data_file = file_to_load
        graph_data = await load_graph_data_from_file(file_to_load)
        return JSONResponse(
            content={
                "success": True,
                "message": f"已切换到文件: {file_to_load.name}",
                "stats": graph_data.get("stats", {}),
            }
        )
    except HTTPException as e:
        # Keep the intended status code; the generic handler below used to turn
        # the 400/404 raised above into 500s.
        return JSONResponse(content={"success": False, "error": e.detail}, status_code=e.status_code)
    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/reload")
async def reload_data():
    """Drop the cached file graph and load the current data file again."""
    global graph_data_cache
    graph_data_cache = None
    reloaded = await load_graph_data_from_file()
    stats = reloaded.get("stats", {})
    return JSONResponse(content={"success": True, "message": "数据已重新加载", "stats": stats})
@router.get("/api/search")
async def search_memories(q: str, limit: int = 50):
    """Case-insensitive substring search over memory texts.

    With an initialized memory manager, every memory's ``to_text()`` is matched
    against ``q``. Because this walks all memories, it runs in the thread pool.
    Each hit includes the ids of its nodes. Otherwise the file-loaded data's
    ``memories`` list is searched; those entries may lack ``node_ids``.
    NOTE(review): the file loader in this module currently emits an empty
    ``memories`` list, so the fallback finds nothing.

    Args:
        q: Search string.
        limit: Maximum number of results returned (negative values act as 0).
            ``count`` always reports the total number of matches.
    """
    try:
        needle = q.lower()
        from src.memory_graph.manager_singleton import get_memory_manager
        memory_manager = get_memory_manager()
        if memory_manager and memory_manager._initialized and memory_manager.graph_store:

            def _scan() -> list[dict[str, Any]]:
                # Thread-pool worker: to_text() is computed once per memory.
                hits = []
                for memory in memory_manager.graph_store.get_all_memories():
                    text = memory.to_text()
                    if needle in text.lower():
                        hits.append(
                            {
                                "id": memory.id,
                                "type": memory.memory_type.value,
                                "importance": memory.importance,
                                "text": text,
                                "node_ids": [node.id for node in memory.nodes],
                            }
                        )
                return hits

            results = await asyncio.get_running_loop().run_in_executor(_executor, _scan)
        else:
            # Fallback: search the file-loaded data; node_ids may be absent here.
            data = await load_graph_data_from_file()
            results = [
                memory
                for memory in data.get("memories", [])
                if needle in memory.get("text", "").lower()
            ]
        return JSONResponse(
            content={
                "success": True,
                "data": {
                    "results": results[: max(limit, 0)],
                    "count": len(results),
                },
            }
        )
    except Exception as e:
        # Always include the "data" field, even on failure, so the frontend can rely on it.
        return JSONResponse(
            content={"success": False, "error": str(e), "data": {"results": [], "count": 0}},
            status_code=500,
        )
@router.get("/api/stats")
async def get_statistics():
    """Return the file-loaded graph's statistics plus per-type node and memory counts.

    NOTE(review): unlike the other endpoints, this always uses the file fallback
    and never the live memory manager. Confirm whether that is intended.
    """
    try:
        data = await load_graph_data_from_file()
        node_types: dict = {}
        for node in data.get("nodes", []):
            # Formatted nodes carry "type"; raw file nodes only carry "node_type".
            node_type = node.get("type", node.get("node_type", "Unknown"))
            node_types[node_type] = node_types.get(node_type, 0) + 1
        memory_types: dict = {}
        for memory in data.get("memories", []):
            mem_type = memory.get("type", "Unknown")
            memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
        # Copy: data["stats"] belongs to the shared graph cache and must not gain
        # node_types/memory_types keys as a side effect of this request.
        stats = dict(data.get("stats", {}))
        stats["node_types"] = node_types
        stats["memory_types"] = memory_types
        return JSONResponse(content={"success": True, "data": stats})
    except Exception as e:
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)