refactor(api, chat): improve async handling and resolve concurrency issues
The memory visualizer API endpoints used to run synchronous, blocking operations (file I/O, data processing) inside async routes, which could freeze the server when handling large graph files. These tasks are now offloaded to a ThreadPoolExecutor, keeping the API non-blocking and noticeably more responsive.

In the chat message manager, a race condition could cause message processing to overlap, or leave the stream stalled after an interruption. This commit introduces:

- a concurrency guard (`is_chatter_processing`) that prevents the stream loop from running more than one chatter instance at a time;
- a fail-safe that resets the processing state after an interruption, so the stream can recover and continue correctly.
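For reference, the non-blocking pattern applied across the API endpoints looks roughly like the minimal sketch below. It is an illustration only, not code from this repository: the route path, `GRAPH_FILE`, and `load_graph_file` are assumed names.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import orjson
from fastapi import FastAPI

app = FastAPI()
_executor = ThreadPoolExecutor(max_workers=2)
GRAPH_FILE = Path("data/memory_graph/graph.json")  # example path, not the real one


def load_graph_file(path: Path) -> dict:
    """Synchronous read + parse; runs inside the thread pool, not the event loop."""
    with open(path, "rb") as f:
        return orjson.loads(f.read())


@app.get("/api/graph")
async def get_graph() -> dict:
    loop = asyncio.get_running_loop()
    # The await yields control while a worker thread does the blocking I/O,
    # so the event loop keeps serving other requests.
    data = await loop.run_in_executor(_executor, load_graph_file, GRAPH_FILE)
    return {"success": True, "node_count": len(data.get("nodes", []))}

The sketch uses `asyncio.get_running_loop()`; inside a running coroutine it behaves the same as the `asyncio.get_event_loop()` call seen in the diff below.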
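The concurrency guard and the interrupt fail-safe follow the shape of this second sketch. The class and function names (`StreamContext`, `stream_loop`, `handle_messages`, `restart_after_interrupt`) are assumptions made for illustration; only the `is_chatter_processing` flag itself comes from the actual change.

import asyncio


class StreamContext:
    def __init__(self) -> None:
        self.is_chatter_processing = False
        self.unread: list[str] = []


async def handle_messages(ctx: StreamContext) -> None:
    await asyncio.sleep(2.0)  # stand-in for the real chatter work
    ctx.unread.clear()


async def stream_loop(ctx: StreamContext) -> None:
    while True:
        if ctx.unread:
            if ctx.is_chatter_processing:
                # Guard: a previous run is still active, skip this round.
                await asyncio.sleep(1.0)
                continue
            ctx.is_chatter_processing = True
            try:
                await handle_messages(ctx)  # may be cancelled by an interrupt
            finally:
                ctx.is_chatter_processing = False  # normal release
        await asyncio.sleep(0.5)


async def restart_after_interrupt(ctx: StreamContext, old_task: asyncio.Task) -> asyncio.Task:
    old_task.cancel()
    # Fail-safe: the cancelled task's finally block may not have run yet,
    # so force-reset the flag before starting a new loop.
    if ctx.is_chatter_processing:
        ctx.is_chatter_processing = False
    return asyncio.create_task(stream_loop(ctx))

Because the guard check and the forced reset act on the same single flag, a cancelled run cannot leave the loop permanently stuck.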
@@ -4,7 +4,9 @@
 提供 Web API 用于可视化记忆图数据
 """
 
+import asyncio
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -23,6 +25,9 @@ data_dir = project_root / "data" / "memory_graph"
 graph_data_cache = None
 current_data_file = None
 
+# 线程池用于异步文件读取
+_executor = ThreadPoolExecutor(max_workers=2)
+
 # FastAPI 路由
 router = APIRouter()
 
@@ -64,7 +69,7 @@ def find_available_data_files() -> list[Path]:
 
 
 async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
-    """从磁盘加载图数据"""
+    """从磁盘加载图数据(异步,不阻塞主线程)"""
     global graph_data_cache, current_data_file
 
     if file_path and file_path != current_data_file:
@@ -86,65 +91,81 @@ async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str,
         if not graph_file.exists():
             return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
 
-        with open(graph_file, encoding="utf-8") as f:
-            data = orjson.loads(f.read())
+        # 在线程池中异步读取文件,避免阻塞主事件循环
+        loop = asyncio.get_event_loop()
+        data = await loop.run_in_executor(_executor, _sync_load_json_file, graph_file)
 
         nodes = data.get("nodes", [])
         edges = data.get("edges", [])
         metadata = data.get("metadata", {})
 
-        nodes_dict = {
-            node["id"]: {
-                **node,
-                "label": node.get("content", ""),
-                "group": node.get("node_type", ""),
-                "title": f"{node.get('node_type', '')}: {node.get('content', '')}",
-            }
-            for node in nodes
-            if node.get("id")
-        }
-
-        edges_list = []
-        seen_edge_ids = set()
-        for edge in edges:
-            edge_id = edge.get("id")
-            if edge_id and edge_id not in seen_edge_ids:
-                edges_list.append(
-                    {
-                        **edge,
-                        "from": edge.get("source", edge.get("source_id")),
-                        "to": edge.get("target", edge.get("target_id")),
-                        "label": edge.get("relation", ""),
-                        "arrows": "to",
-                    }
-                )
-                seen_edge_ids.add(edge_id)
-
-        stats = metadata.get("statistics", {})
-        total_memories = stats.get("total_memories", 0)
-
-        graph_data_cache = {
-            "nodes": list(nodes_dict.values()),
-            "edges": edges_list,
-            "memories": [],
-            "stats": {
-                "total_nodes": len(nodes_dict),
-                "total_edges": len(edges_list),
-                "total_memories": total_memories,
-            },
-            "current_file": str(graph_file),
-            "file_size": graph_file.stat().st_size,
-            "file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
-        }
+        # 在线程池中处理数据转换
+        processed = await loop.run_in_executor(
+            _executor, _process_graph_data, nodes, edges, metadata, graph_file
+        )
+
+        graph_data_cache = processed
         return graph_data_cache
 
     except Exception as e:
         import traceback
 
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}")
 
 
+def _sync_load_json_file(file_path: Path) -> dict:
+    """同步加载 JSON 文件(在线程池中执行)"""
+    with open(file_path, encoding="utf-8") as f:
+        return orjson.loads(f.read())
+
+
+def _process_graph_data(nodes: list, edges: list, metadata: dict, graph_file: Path) -> dict:
+    """处理图数据(在线程池中执行)"""
+    nodes_dict = {
+        node["id"]: {
+            **node,
+            "label": node.get("content", ""),
+            "group": node.get("node_type", ""),
+            "title": f"{node.get('node_type', '')}: {node.get('content', '')}",
+        }
+        for node in nodes
+        if node.get("id")
+    }
+
+    edges_list = []
+    seen_edge_ids = set()
+    for edge in edges:
+        edge_id = edge.get("id")
+        if edge_id and edge_id not in seen_edge_ids:
+            edges_list.append(
+                {
+                    **edge,
+                    "from": edge.get("source", edge.get("source_id")),
+                    "to": edge.get("target", edge.get("target_id")),
+                    "label": edge.get("relation", ""),
+                    "arrows": "to",
+                }
+            )
+            seen_edge_ids.add(edge_id)
+
+    stats = metadata.get("statistics", {})
+    total_memories = stats.get("total_memories", 0)
+
+    return {
+        "nodes": list(nodes_dict.values()),
+        "edges": edges_list,
+        "memories": [],
+        "stats": {
+            "total_nodes": len(nodes_dict),
+            "total_edges": len(edges_list),
+            "total_memories": total_memories,
+        },
+        "current_file": str(graph_file),
+        "file_size": graph_file.stat().st_size,
+        "file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
+    }
+
+
 @router.get("/", response_class=HTMLResponse)
 async def index(request: Request):
     """主页面"""
@@ -152,7 +173,7 @@ async def index(request: Request):
 
 
 def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
-    """从 MemoryManager 提取并格式化图数据"""
+    """从 MemoryManager 提取并格式化图数据(同步版本,需在线程池中调用)"""
     if not memory_manager.graph_store:
         return {"nodes": [], "edges": [], "memories": [], "stats": {}}
 
@@ -216,7 +237,9 @@ async def get_full_graph():
 
         data = {}
         if memory_manager and memory_manager._initialized:
-            data = _format_graph_data_from_manager(memory_manager)
+            # 在线程池中执行,避免阻塞主事件循环
+            loop = asyncio.get_event_loop()
+            data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
         else:
             # 如果内存管理器不可用,则从文件加载
             data = await load_graph_data_from_file()
@@ -270,71 +293,93 @@ async def get_paginated_graph(
 
         memory_manager = get_memory_manager()
 
-        # 获取完整数据
+        # 获取完整数据(已经是异步的)
         if memory_manager and memory_manager._initialized:
-            full_data = _format_graph_data_from_manager(memory_manager)
+            loop = asyncio.get_event_loop()
+            full_data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
         else:
             full_data = await load_graph_data_from_file()
 
-        nodes = full_data.get("nodes", [])
-        edges = full_data.get("edges", [])
+        # 在线程池中处理分页逻辑
+        loop = asyncio.get_event_loop()
+        result = await loop.run_in_executor(
+            _executor,
+            _process_pagination,
+            full_data, page, page_size, min_importance, node_types
+        )
 
-        # 过滤节点类型
-        if node_types:
-            allowed_types = set(node_types.split(","))
-            nodes = [n for n in nodes if n.get("group") in allowed_types]
-
-        # 按重要性排序(如果有importance字段)
-        nodes_with_importance = []
-        for node in nodes:
-            # 计算节点重要性(连接的边数)
-            edge_count = sum(1 for e in edges if e.get("from") == node["id"] or e.get("to") == node["id"])
-            importance = edge_count / max(len(edges), 1)
-            if importance >= min_importance:
-                node["importance"] = importance
-                nodes_with_importance.append(node)
-
-        # 按重要性降序排序
-        nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
-
-        # 分页
-        total_nodes = len(nodes_with_importance)
-        total_pages = (total_nodes + page_size - 1) // page_size
-        start_idx = (page - 1) * page_size
-        end_idx = min(start_idx + page_size, total_nodes)
-
-        paginated_nodes = nodes_with_importance[start_idx:end_idx]
-        node_ids = set(n["id"] for n in paginated_nodes)
-
-        # 只保留连接分页节点的边
-        paginated_edges = [
-            e for e in edges
-            if e.get("from") in node_ids and e.get("to") in node_ids
-        ]
-
-        return JSONResponse(content={"success": True, "data": {
-            "nodes": paginated_nodes,
-            "edges": paginated_edges,
-            "pagination": {
-                "page": page,
-                "page_size": page_size,
-                "total_nodes": total_nodes,
-                "total_pages": total_pages,
-                "has_next": page < total_pages,
-                "has_prev": page > 1,
-            },
-            "stats": {
-                "total_nodes": total_nodes,
-                "total_edges": len(paginated_edges),
-                "total_memories": full_data.get("stats", {}).get("total_memories", 0),
-            },
-        }})
+        return JSONResponse(content={"success": True, "data": result})
     except Exception as e:
         import traceback
         traceback.print_exc()
         return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
 
 
+def _process_pagination(full_data: dict, page: int, page_size: int, min_importance: float, node_types: str | None) -> dict:
+    """处理分页逻辑(在线程池中执行)"""
+    nodes = full_data.get("nodes", [])
+    edges = full_data.get("edges", [])
+
+    # 过滤节点类型
+    if node_types:
+        allowed_types = set(node_types.split(","))
+        nodes = [n for n in nodes if n.get("group") in allowed_types]
+
+    # 构建边的索引以加速查找
+    edge_count_map = {}
+    for e in edges:
+        from_id = e.get("from")
+        to_id = e.get("to")
+        edge_count_map[from_id] = edge_count_map.get(from_id, 0) + 1
+        edge_count_map[to_id] = edge_count_map.get(to_id, 0) + 1
+
+    # 按重要性排序
+    nodes_with_importance = []
+    total_edges = max(len(edges), 1)
+    for node in nodes:
+        edge_count = edge_count_map.get(node["id"], 0)
+        importance = edge_count / total_edges
+        if importance >= min_importance:
+            node["importance"] = importance
+            nodes_with_importance.append(node)
+
+    # 按重要性降序排序
+    nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
+
+    # 分页
+    total_nodes = len(nodes_with_importance)
+    total_pages = (total_nodes + page_size - 1) // page_size
+    start_idx = (page - 1) * page_size
+    end_idx = min(start_idx + page_size, total_nodes)
+
+    paginated_nodes = nodes_with_importance[start_idx:end_idx]
+    node_ids = set(n["id"] for n in paginated_nodes)
+
+    # 只保留连接分页节点的边
+    paginated_edges = [
+        e for e in edges
+        if e.get("from") in node_ids and e.get("to") in node_ids
+    ]
+
+    return {
+        "nodes": paginated_nodes,
+        "edges": paginated_edges,
+        "pagination": {
+            "page": page,
+            "page_size": page_size,
+            "total_nodes": total_nodes,
+            "total_pages": total_pages,
+            "has_next": page < total_pages,
+            "has_prev": page > 1,
+        },
+        "stats": {
+            "total_nodes": total_nodes,
+            "total_edges": len(paginated_edges),
+            "total_memories": full_data.get("stats", {}).get("total_memories", 0),
+        },
+    }
+
+
 @router.get("/api/graph/clustered")
 async def get_clustered_graph(
     max_nodes: int = Query(300, ge=50, le=1000, description="最大节点数"),
@@ -346,9 +391,10 @@ async def get_clustered_graph(
 
         memory_manager = get_memory_manager()
 
-        # 获取完整数据
+        # 获取完整数据(异步)
         if memory_manager and memory_manager._initialized:
-            full_data = _format_graph_data_from_manager(memory_manager)
+            loop = asyncio.get_event_loop()
+            full_data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
         else:
             full_data = await load_graph_data_from_file()
 
@@ -364,8 +410,11 @@ async def get_clustered_graph(
                 "clustered": False,
             }})
 
-        # 执行聚类
-        clustered_data = _cluster_graph_data(nodes, edges, max_nodes, cluster_threshold)
+        # 在线程池中执行聚类
+        loop = asyncio.get_event_loop()
+        clustered_data = await loop.run_in_executor(
+            _executor, _cluster_graph_data, nodes, edges, max_nodes, cluster_threshold
+        )
 
         return JSONResponse(content={"success": True, "data": {
             **clustered_data,
@@ -318,6 +318,15 @@ class StreamLoopManager:
             has_messages = force_dispatch or await self._has_messages_to_process(context)
 
             if has_messages:
+                # 🔒 并发保护:如果 Chatter 正在处理中,跳过本轮
+                # 这可能发生在:1) 打断后重启循环 2) 处理时间超过轮询间隔
+                if context.is_chatter_processing:
+                    logger.debug(f"🔒 [流工作器] stream={stream_id[:8]}, Chatter正在处理中,跳过本轮")
+                    # 不打印"开始处理"日志,直接进入下一轮等待
+                    # 使用较短的等待时间,等待当前处理完成
+                    await asyncio.sleep(1.0)
+                    continue
+
                 if force_dispatch:
                     logger.info(f"⚡ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 未读消息 {unread_count} 条,触发强制分发")
                 else:
@@ -477,10 +486,11 @@ class StreamLoopManager:
            logger.warning(f"Chatter管理器未设置: {stream_id}")
            return False
 
-        # 🔒 防止并发处理:如果已经在处理中,直接返回
+        # 🔒 二次并发保护(防御性检查)
+        # 正常情况下不应该触发,如果触发说明有竞态条件
         if context.is_chatter_processing:
-            logger.debug(f"🔒 [并发保护] stream={stream_id[:8]}, Chatter 正在处理中,跳过本次处理请求")
-            return True  # 返回 True,这是正常的保护机制,不是失败
+            logger.warning(f"🔒 [并发保护] stream={stream_id[:8]}, Chatter正在处理中(二次检查触发,可能存在竞态)")
+            return False
 
         # 设置处理状态为正在处理
         self._set_stream_processing_status(stream_id, True)
@@ -720,8 +730,8 @@ class StreamLoopManager:
             chat_manager = get_chat_manager()
             chat_stream = await chat_manager.get_stream(stream_id)
             if chat_stream and not chat_stream.group_info:
-                # 私聊:有消息时几乎立即响应,空转时稍微等待
-                min_interval = 0.1 if has_messages else 3.0
+                # 私聊:有消息时快速响应,空转时稍微等待
+                min_interval = 0.5 if has_messages else 5.0
                 logger.debug(f"流 {stream_id} 私聊模式,使用最小间隔: {min_interval:.2f}s")
                 return min_interval
         except Exception as e:
@@ -370,12 +370,18 @@ class MessageManager:
 
             logger.info(f"🚀 打断后重新创建流循环任务: {stream_id}")
 
-            # 等待一小段时间确保当前消息已经添加到未读消息中
-            await asyncio.sleep(0.1)
-
             # 获取当前的stream context
             context = chat_stream.context
 
+            # 🔒 重要:确保 is_chatter_processing 被重置
+            # 被取消的任务的 finally 块可能还没执行完,这里强制重置
+            if context.is_chatter_processing:
+                logger.debug(f"打断后强制重置 is_chatter_processing: {stream_id}")
+                context.is_chatter_processing = False
+
+            # 等待一小段时间确保当前消息已经添加到未读消息中
+            await asyncio.sleep(0.1)
+
             # 确保有未读消息需要处理
             unread_messages = context.get_unread_messages()
             if not unread_messages: