refactor(api, chat): improve async handling and resolve concurrency issues

The memory visualizer API endpoints previously ran synchronous, blocking operations (file I/O, data processing) inside async routes, which could freeze the server when handling large graph files. Those tasks are now offloaded to a ThreadPoolExecutor, making the API non-blocking and noticeably more responsive.
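
The offloading pattern is roughly the following minimal sketch (assuming a module-level ThreadPoolExecutor; `_load_blob` and `read_file_async` are illustrative names, not the actual helpers added in this commit):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

_executor = ThreadPoolExecutor(max_workers=2)  # shared pool for blocking work

def _load_blob(path: Path) -> bytes:
    """Synchronous, blocking read; runs inside a worker thread."""
    return path.read_bytes()

async def read_file_async(path: Path) -> bytes:
    # Hand the blocking call to the pool so the event loop stays free
    # to serve other requests while the read is in progress.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(_executor, _load_blob, path)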

In the chat message manager, a race condition could cause message processing to overlap, or leave the stream stalled after an interruption. This commit introduces:
- A concurrency lock (`is_chatter_processing`) that prevents the stream loop from running multiple chatter instances at once.
- A fail-safe that resets the processing state after an interruption so the stream can recover and continue correctly (see the sketch after this list).
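
A minimal sketch of the guard-and-reset idea (a simplified stand-in for the real StreamLoopManager/MessageManager code; `StreamContext`, `process_once`, and `reset_after_interrupt` are illustrative names, not APIs from this commit):

import asyncio

class StreamContext:
    def __init__(self) -> None:
        self.is_chatter_processing = False  # concurrency guard flag

async def process_once(context: StreamContext) -> None:
    if context.is_chatter_processing:
        return  # a chatter run is already in flight; skip this round
    context.is_chatter_processing = True
    try:
        await asyncio.sleep(0.1)  # placeholder for the actual chatter work
    finally:
        context.is_chatter_processing = False  # always release the guard

def reset_after_interrupt(context: StreamContext) -> None:
    # Fail-safe: a cancelled task's finally block may not have run yet,
    # so force-clear the flag before restarting the stream loop.
    context.is_chatter_processing = False
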
This commit is contained in:
tt-P607
2025-12-02 14:40:58 +08:00
parent 1027c5abf7
commit 659a8e0d78
4 changed files with 860 additions and 115 deletions

View File

@@ -4,7 +4,9 @@
提供 Web API 用于可视化记忆图数据
"""
import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any
@@ -23,6 +25,9 @@ data_dir = project_root / "data" / "memory_graph"
graph_data_cache = None
current_data_file = None
# 线程池用于异步文件读取
_executor = ThreadPoolExecutor(max_workers=2)
# FastAPI 路由
router = APIRouter()
@@ -64,7 +69,7 @@ def find_available_data_files() -> list[Path]:
async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
"""从磁盘加载图数据(异步,不阻塞主线程)"""
global graph_data_cache, current_data_file
if file_path and file_path != current_data_file:
@@ -86,65 +91,81 @@ async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str,
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
# 在线程池中异步读取文件,避免阻塞主事件循环
loop = asyncio.get_event_loop()
data = await loop.run_in_executor(_executor, _sync_load_json_file, graph_file)
nodes = data.get("nodes", [])
edges = data.get("edges", [])
metadata = data.get("metadata", {})
nodes_dict = {
node["id"]: {
**node,
"label": node.get("content", ""),
"group": node.get("node_type", ""),
"title": f"{node.get('node_type', '')}: {node.get('content', '')}",
}
for node in nodes
if node.get("id")
}
edges_list = []
seen_edge_ids = set()
for edge in edges:
edge_id = edge.get("id")
if edge_id and edge_id not in seen_edge_ids:
edges_list.append(
{
**edge,
"from": edge.get("source", edge.get("source_id")),
"to": edge.get("target", edge.get("target_id")),
"label": edge.get("relation", ""),
"arrows": "to",
}
)
seen_edge_ids.add(edge_id)
stats = metadata.get("statistics", {})
total_memories = stats.get("total_memories", 0)
graph_data_cache = {
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": [],
"stats": {
"total_nodes": len(nodes_dict),
"total_edges": len(edges_list),
"total_memories": total_memories,
},
"current_file": str(graph_file),
"file_size": graph_file.stat().st_size,
"file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
}
# 在线程池中处理数据转换
processed = await loop.run_in_executor(
_executor, _process_graph_data, nodes, edges, metadata, graph_file
)
graph_data_cache = processed
return graph_data_cache
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}")
def _sync_load_json_file(file_path: Path) -> dict:
"""同步加载 JSON 文件(在线程池中执行)"""
with open(file_path, encoding="utf-8") as f:
return orjson.loads(f.read())
def _process_graph_data(nodes: list, edges: list, metadata: dict, graph_file: Path) -> dict:
"""处理图数据(在线程池中执行)"""
nodes_dict = {
node["id"]: {
**node,
"label": node.get("content", ""),
"group": node.get("node_type", ""),
"title": f"{node.get('node_type', '')}: {node.get('content', '')}",
}
for node in nodes
if node.get("id")
}
edges_list = []
seen_edge_ids = set()
for edge in edges:
edge_id = edge.get("id")
if edge_id and edge_id not in seen_edge_ids:
edges_list.append(
{
**edge,
"from": edge.get("source", edge.get("source_id")),
"to": edge.get("target", edge.get("target_id")),
"label": edge.get("relation", ""),
"arrows": "to",
}
)
seen_edge_ids.add(edge_id)
stats = metadata.get("statistics", {})
total_memories = stats.get("total_memories", 0)
return {
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": [],
"stats": {
"total_nodes": len(nodes_dict),
"total_edges": len(edges_list),
"total_memories": total_memories,
},
"current_file": str(graph_file),
"file_size": graph_file.stat().st_size,
"file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
}
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""主页面"""
@@ -152,7 +173,7 @@ async def index(request: Request):
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
"""从 MemoryManager 提取并格式化图数据(同步版本,需在线程池中调用)"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -216,7 +237,9 @@ async def get_full_graph():
data = {}
if memory_manager and memory_manager._initialized:
data = _format_graph_data_from_manager(memory_manager)
# 在线程池中执行,避免阻塞主事件循环
loop = asyncio.get_event_loop()
data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
else:
# 如果内存管理器不可用,则从文件加载
data = await load_graph_data_from_file()
@@ -270,71 +293,93 @@ async def get_paginated_graph(
memory_manager = get_memory_manager()
# 获取完整数据
# 获取完整数据(已经是异步的)
if memory_manager and memory_manager._initialized:
full_data = _format_graph_data_from_manager(memory_manager)
loop = asyncio.get_event_loop()
full_data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
else:
full_data = await load_graph_data_from_file()
nodes = full_data.get("nodes", [])
edges = full_data.get("edges", [])
# 在线程池中处理分页逻辑
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
_executor,
_process_pagination,
full_data, page, page_size, min_importance, node_types
)
# 过滤节点类型
if node_types:
allowed_types = set(node_types.split(","))
nodes = [n for n in nodes if n.get("group") in allowed_types]
# 按重要性排序如果有importance字段
nodes_with_importance = []
for node in nodes:
# 计算节点重要性(连接的边数)
edge_count = sum(1 for e in edges if e.get("from") == node["id"] or e.get("to") == node["id"])
importance = edge_count / max(len(edges), 1)
if importance >= min_importance:
node["importance"] = importance
nodes_with_importance.append(node)
# 按重要性降序排序
nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
# 分页
total_nodes = len(nodes_with_importance)
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
return JSONResponse(content={"success": True, "data": {
"nodes": paginated_nodes,
"edges": paginated_edges,
"pagination": {
"page": page,
"page_size": page_size,
"total_nodes": total_nodes,
"total_pages": total_pages,
"has_next": page < total_pages,
"has_prev": page > 1,
},
"stats": {
"total_nodes": total_nodes,
"total_edges": len(paginated_edges),
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
},
}})
return JSONResponse(content={"success": True, "data": result})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
def _process_pagination(full_data: dict, page: int, page_size: int, min_importance: float, node_types: str | None) -> dict:
"""处理分页逻辑(在线程池中执行)"""
nodes = full_data.get("nodes", [])
edges = full_data.get("edges", [])
# 过滤节点类型
if node_types:
allowed_types = set(node_types.split(","))
nodes = [n for n in nodes if n.get("group") in allowed_types]
# 构建边的索引以加速查找
edge_count_map = {}
for e in edges:
from_id = e.get("from")
to_id = e.get("to")
edge_count_map[from_id] = edge_count_map.get(from_id, 0) + 1
edge_count_map[to_id] = edge_count_map.get(to_id, 0) + 1
# 按重要性排序
nodes_with_importance = []
total_edges = max(len(edges), 1)
for node in nodes:
edge_count = edge_count_map.get(node["id"], 0)
importance = edge_count / total_edges
if importance >= min_importance:
node["importance"] = importance
nodes_with_importance.append(node)
# 按重要性降序排序
nodes_with_importance.sort(key=lambda x: x.get("importance", 0), reverse=True)
# 分页
total_nodes = len(nodes_with_importance)
total_pages = (total_nodes + page_size - 1) // page_size
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_nodes)
paginated_nodes = nodes_with_importance[start_idx:end_idx]
node_ids = set(n["id"] for n in paginated_nodes)
# 只保留连接分页节点的边
paginated_edges = [
e for e in edges
if e.get("from") in node_ids and e.get("to") in node_ids
]
return {
"nodes": paginated_nodes,
"edges": paginated_edges,
"pagination": {
"page": page,
"page_size": page_size,
"total_nodes": total_nodes,
"total_pages": total_pages,
"has_next": page < total_pages,
"has_prev": page > 1,
},
"stats": {
"total_nodes": total_nodes,
"total_edges": len(paginated_edges),
"total_memories": full_data.get("stats", {}).get("total_memories", 0),
},
}
@router.get("/api/graph/clustered")
async def get_clustered_graph(
max_nodes: int = Query(300, ge=50, le=1000, description="最大节点数"),
@@ -346,9 +391,10 @@ async def get_clustered_graph(
memory_manager = get_memory_manager()
# 获取完整数据
# 获取完整数据(异步)
if memory_manager and memory_manager._initialized:
full_data = _format_graph_data_from_manager(memory_manager)
loop = asyncio.get_event_loop()
full_data = await loop.run_in_executor(_executor, _format_graph_data_from_manager, memory_manager)
else:
full_data = await load_graph_data_from_file()
@@ -364,8 +410,11 @@ async def get_clustered_graph(
"clustered": False,
}})
# 执行聚类
clustered_data = _cluster_graph_data(nodes, edges, max_nodes, cluster_threshold)
# 在线程池中执行聚类
loop = asyncio.get_event_loop()
clustered_data = await loop.run_in_executor(
_executor, _cluster_graph_data, nodes, edges, max_nodes, cluster_threshold
)
return JSONResponse(content={"success": True, "data": {
**clustered_data,

View File

@@ -318,6 +318,15 @@ class StreamLoopManager:
has_messages = force_dispatch or await self._has_messages_to_process(context)
if has_messages:
# 🔒 并发保护:如果 Chatter 正在处理中,跳过本轮
# 这可能发生在:1) 打断后重启循环 2) 处理时间超过轮询间隔
if context.is_chatter_processing:
logger.debug(f"🔒 [流工作器] stream={stream_id[:8]}, Chatter正在处理中跳过本轮")
# 不打印"开始处理"日志,直接进入下一轮等待
# 使用较短的等待时间,等待当前处理完成
await asyncio.sleep(1.0)
continue
if force_dispatch:
logger.info(f"⚡ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 未读消息 {unread_count} 条,触发强制分发")
else:
@@ -477,10 +486,11 @@ class StreamLoopManager:
logger.warning(f"Chatter管理器未设置: {stream_id}")
return False
# 🔒 防止并发处理:如果已经在处理中,直接返回
# 🔒 二次并发保护(防御性检查)
# 正常情况下不应该触发,如果触发说明有竞态条件
if context.is_chatter_processing:
logger.debug(f"🔒 [并发保护] stream={stream_id[:8]}, Chatter 正在处理中,跳过本次处理请求")
return True # 返回 True这是正常的保护机制不是失败
logger.warning(f"🔒 [并发保护] stream={stream_id[:8]}, Chatter正在处理中(二次检查触发,可能存在竞态)")
return False
# 设置处理状态为正在处理
self._set_stream_processing_status(stream_id, True)
@@ -720,8 +730,8 @@ class StreamLoopManager:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and not chat_stream.group_info:
# 私聊:有消息时几乎立即响应,空转时稍微等待
min_interval = 0.1 if has_messages else 3.0
# 私聊:有消息时快速响应,空转时稍微等待
min_interval = 0.5 if has_messages else 5.0
logger.debug(f"{stream_id} 私聊模式,使用最小间隔: {min_interval:.2f}s")
return min_interval
except Exception as e:

View File

@@ -370,12 +370,18 @@ class MessageManager:
logger.info(f"🚀 打断后重新创建流循环任务: {stream_id}")
# 等待一小段时间确保当前消息已经添加到未读消息中
await asyncio.sleep(0.1)
# 获取当前的stream context
context = chat_stream.context
# 🔒 重要:确保 is_chatter_processing 被重置
# 被取消的任务的 finally 块可能还没执行完,这里强制重置
if context.is_chatter_processing:
logger.debug(f"打断后强制重置 is_chatter_processing: {stream_id}")
context.is_chatter_processing = False
# 等待一小段时间确保当前消息已经添加到未读消息中
await asyncio.sleep(0.1)
# 确保有未读消息需要处理
unread_messages = context.get_unread_messages()
if not unread_messages: