refactor(server): 将记忆可视化工具和统计功能整合进主服务
将原先独立的记忆可视化工具(Memory Visualizer)和LLM使用统计逻辑深度整合到项目主服务中。 主要变更包括: - **移除独立的可视化工具**: 删除了 `tools/memory_visualizer` 目录下的所有独立服务器、脚本和文档,清理了项目结构。 - **API路由整合**: 在主 FastAPI 应用中注册了记忆可视化工具的路由,使其成为核心功能的一部分,可通过 `/visualizer` 访问。 - **统计逻辑重构**: 将LLM使用统计的计算逻辑从API路由层 `statistic_router.py` 中剥离,迁移到 `src/chat/utils/statistic.py` 中,实现了逻辑的解耦和复用。API路由现在直接调用重构后的统计任务。 - **依赖清理与添加**: 添加了 `jinja2` 作为模板渲染的依赖,并清除了与独立可视化工具相关的旧依赖。 此次重构简化了项目的维护和部署,将原本分散的功能统一管理,提升了代码的内聚性和可维护性。
This commit is contained in:
361
src/api/memory_visualizer_router.py
Normal file
361
src/api/memory_visualizer_router.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
记忆图可视化 - API 路由模块
|
||||
|
||||
提供 Web API 用于可视化记忆图数据
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
# Project root: three levels up from this file (src/api/ -> project root).
project_root = Path(__file__).parent.parent.parent
data_dir = project_root / "data" / "memory_graph"

# Module-level cache: the last loaded graph payload and the file it came from.
# Invalidated when a different file is selected or /reload is hit.
graph_data_cache = None
current_data_file = None

# FastAPI router for the visualizer endpoints (mounted by the main app).
router = APIRouter()

# Jinja2 template engine; templates live in a "templates" dir next to this module.
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
|
||||
|
||||
|
||||
def find_available_data_files() -> List[Path]:
    """Locate every memory-graph data file on disk, newest first."""
    if not data_dir.exists():
        return []

    candidates: List[Path] = []

    def _add(path: Path) -> None:
        # Preserve first-seen order; skip duplicates.
        if path not in candidates:
            candidates.append(path)

    # Well-known primary file names in the data directory.
    for name in ("graph_store.json", "memory_graph.json", "graph_data.json"):
        primary = data_dir / name
        if primary.exists():
            _add(primary)

    # Timestamped siblings alongside the primary files.
    for pattern in ("graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"):
        for found in data_dir.glob(pattern):
            _add(found)

    # Nested backups under data/memory_graph/backups.
    nested_backups = data_dir / "backups"
    if nested_backups.exists():
        for found in nested_backups.glob("**/*.json"):
            _add(found)

    # Sibling data/backup directory.
    sibling_backup = data_dir.parent / "backup"
    if sibling_backup.exists():
        for pattern in ("**/graph_*.json", "**/memory_*.json"):
            for found in sibling_backup.glob(pattern):
                _add(found)

    # Most recently modified first.
    return sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
|
||||
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
    """Load graph data from disk, with module-level caching.

    When *file_path* is given and differs from the currently selected file,
    the cache is invalidated and the new file becomes current. Otherwise the
    cached payload is returned when available. When no file has been selected
    yet, the most recently modified available data file is used.

    Returns:
        A dict with "nodes", "edges", "memories", "stats" and file metadata;
        when no data file can be found (or the selected one is missing) an
        error dict with empty collections is returned instead of raising.

    Raises:
        HTTPException: 500 when reading or parsing the file fails.
    """
    global graph_data_cache, current_data_file

    # Selecting a different file invalidates the cache.
    if file_path and file_path != current_data_file:
        graph_data_cache = None
        current_data_file = file_path

    if graph_data_cache:
        return graph_data_cache

    try:
        graph_file = current_data_file
        if not graph_file:
            # No file selected yet: fall back to the newest available one.
            available_files = find_available_data_files()
            if not available_files:
                return {"error": "未找到数据文件", "nodes": [], "edges": [], "stats": {}}
            graph_file = available_files[0]
            current_data_file = graph_file

        if not graph_file.exists():
            return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}

        with open(graph_file, "r", encoding="utf-8") as f:
            data = orjson.loads(f.read())

        nodes = data.get("nodes", [])
        edges = data.get("edges", [])
        metadata = data.get("metadata", {})

        # Index nodes by id (dropping id-less entries) and add the display
        # fields the front-end graph expects (label/group/title).
        nodes_dict = {
            node["id"]: {
                **node,
                "label": node.get("content", ""),
                "group": node.get("node_type", ""),
                "title": f"{node.get('node_type', '')}: {node.get('content', '')}",
            }
            for node in nodes
            if node.get("id")
        }

        # Normalize edges to the front-end shape; "source"/"target" take
        # precedence over the "source_id"/"target_id" variants.
        edges_list = [
            {
                **edge,
                "from": edge.get("source", edge.get("source_id")),
                "to": edge.get("target", edge.get("target_id")),
                "label": edge.get("relation", ""),
                "arrows": "to",
            }
            for edge in edges
        ]

        stats = metadata.get("statistics", {})
        total_memories = stats.get("total_memories", 0)

        # NOTE: file-backed payloads carry no per-memory detail ("memories"
        # stays empty); only the live MemoryManager path fills it.
        graph_data_cache = {
            "nodes": list(nodes_dict.values()),
            "edges": edges_list,
            "memories": [],
            "stats": {
                "total_nodes": len(nodes_dict),
                "total_edges": len(edges_list),
                "total_memories": total_memories,
            },
            "current_file": str(graph_file),
            "file_size": graph_file.stat().st_size,
            "file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
        }
        return graph_data_cache

    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}")
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
"""主页面"""
|
||||
return templates.TemplateResponse("visualizer.html", {"request": request})
|
||||
|
||||
|
||||
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
    """Extract and format graph data from a live MemoryManager.

    Walks every memory in the manager's graph store, building:
    - a deduplicated node list (first occurrence of each node id wins),
    - an edge list where each edge is tagged with its owning memory id,
    - a per-memory summary list (id/type/importance/text).

    Aggregate counts come from memory_manager.get_statistics().
    Returns an all-empty payload when the manager has no graph store.
    """
    if not memory_manager.graph_store:
        return {"nodes": [], "edges": [], "memories": [], "stats": {}}

    all_memories = memory_manager.graph_store.get_all_memories()
    nodes_dict: Dict[str, Dict[str, Any]] = {}  # node id -> formatted node
    edges_list: List[Dict[str, Any]] = []
    memory_info: List[Dict[str, Any]] = []

    for memory in all_memories:
        memory_info.append(
            {
                "id": memory.id,
                "type": memory.memory_type.value,
                "importance": memory.importance,
                "text": memory.to_text(),
            }
        )
        for node in memory.nodes:
            # Nodes can be shared by several memories; keep the first seen.
            if node.id not in nodes_dict:
                nodes_dict[node.id] = {
                    "id": node.id,
                    "label": node.content,
                    "type": node.node_type.value,
                    "group": node.node_type.name,
                    "title": f"{node.node_type.value}: {node.content}",
                }
        for edge in memory.edges:
            edges_list.append(  # noqa: PERF401
                {
                    "id": edge.id,
                    "from": edge.source_id,
                    "to": edge.target_id,
                    "label": edge.relation,
                    "arrows": "to",
                    "memory_id": memory.id,
                }
            )

    stats = memory_manager.get_statistics()
    return {
        "nodes": list(nodes_dict.values()),
        "edges": edges_list,
        "memories": memory_info,
        "stats": {
            "total_nodes": stats.get("total_nodes", 0),
            "total_edges": stats.get("total_edges", 0),
            "total_memories": stats.get("total_memories", 0),
        },
        "current_file": "memory_manager (实时数据)",
    }
|
||||
|
||||
|
||||
@router.get("/api/graph/full")
|
||||
async def get_full_graph():
|
||||
"""获取完整记忆图数据"""
|
||||
try:
|
||||
from src.memory_graph.manager_singleton import get_memory_manager
|
||||
|
||||
memory_manager = get_memory_manager()
|
||||
|
||||
data = {}
|
||||
if memory_manager and memory_manager._initialized:
|
||||
data = _format_graph_data_from_manager(memory_manager)
|
||||
else:
|
||||
# 如果内存管理器不可用,则从文件加载
|
||||
data = load_graph_data_from_file()
|
||||
|
||||
return JSONResponse(content={"success": True, "data": data})
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.get("/api/files")
|
||||
async def list_files_api():
|
||||
"""列出所有可用的数据文件"""
|
||||
try:
|
||||
files = find_available_data_files()
|
||||
file_list = []
|
||||
for f in files:
|
||||
stat = f.stat()
|
||||
file_list.append(
|
||||
{
|
||||
"path": str(f),
|
||||
"name": f.name,
|
||||
"size": stat.st_size,
|
||||
"size_kb": round(stat.st_size / 1024, 2),
|
||||
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
"modified_readable": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"is_current": str(f) == str(current_data_file) if current_data_file else False,
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"files": file_list,
|
||||
"count": len(file_list),
|
||||
"current_file": str(current_data_file) if current_data_file else None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# 增加日志记录
|
||||
# logger.error(f"列出数据文件失败: {e}", exc_info=True)
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.post("/select_file")
|
||||
async def select_file(request: Request):
|
||||
"""选择要加载的数据文件"""
|
||||
global graph_data_cache, current_data_file
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
raise HTTPException(status_code=400, detail="未提供文件路径")
|
||||
|
||||
file_to_load = Path(file_path)
|
||||
if not file_to_load.exists():
|
||||
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
|
||||
|
||||
graph_data_cache = None
|
||||
current_data_file = file_to_load
|
||||
graph_data = load_graph_data_from_file(file_to_load)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": f"已切换到文件: {file_to_load.name}",
|
||||
"stats": graph_data.get("stats", {}),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.get("/reload")
|
||||
async def reload_data():
|
||||
"""重新加载数据"""
|
||||
global graph_data_cache
|
||||
graph_data_cache = None
|
||||
data = load_graph_data_from_file()
|
||||
return JSONResponse(content={"success": True, "message": "数据已重新加载", "stats": data.get("stats", {})})
|
||||
|
||||
|
||||
@router.get("/api/search")
|
||||
async def search_memories(q: str, limit: int = 50):
|
||||
"""搜索记忆"""
|
||||
try:
|
||||
from src.memory_graph.manager_singleton import get_memory_manager
|
||||
|
||||
memory_manager = get_memory_manager()
|
||||
|
||||
results = []
|
||||
if memory_manager and memory_manager._initialized and memory_manager.graph_store:
|
||||
# 从 memory_manager 搜索
|
||||
all_memories = memory_manager.graph_store.get_all_memories()
|
||||
for memory in all_memories:
|
||||
if q.lower() in memory.to_text().lower():
|
||||
results.append(
|
||||
{
|
||||
"id": memory.id,
|
||||
"type": memory.memory_type.value,
|
||||
"importance": memory.importance,
|
||||
"text": memory.to_text(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 从文件加载的数据中搜索 (降级方案)
|
||||
data = load_graph_data_from_file()
|
||||
for memory in data.get("memories", []):
|
||||
if q.lower() in memory.get("text", "").lower():
|
||||
results.append(memory)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"data": {
|
||||
"results": results[:limit],
|
||||
"count": len(results),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@router.get("/api/stats")
|
||||
async def get_statistics():
|
||||
"""获取统计信息"""
|
||||
try:
|
||||
data = load_graph_data_from_file()
|
||||
|
||||
node_types = {}
|
||||
memory_types = {}
|
||||
|
||||
for node in data["nodes"]:
|
||||
node_type = node.get("type", "Unknown")
|
||||
node_types[node_type] = node_types.get(node_type, 0) + 1
|
||||
|
||||
for memory in data.get("memories", []):
|
||||
mem_type = memory.get("type", "Unknown")
|
||||
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
|
||||
|
||||
stats = data.get("stats", {})
|
||||
stats["node_types"] = node_types
|
||||
stats["memory_types"] = memory_types
|
||||
|
||||
return JSONResponse(content={"success": True, "data": stats})
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||
@@ -4,10 +4,10 @@ from typing import Any, Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from src.common.database.compatibility import db_get
|
||||
from src.common.database.core.models import LLMUsage
|
||||
from src.chat.utils.statistic import (
|
||||
StatisticOutputTask,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
|
||||
logger = get_logger("LLM统计API")
|
||||
|
||||
@@ -37,108 +37,6 @@ COST_BY_USER = "costs_by_user"
|
||||
COST_BY_MODEL = "costs_by_model"
|
||||
COST_BY_MODULE = "costs_by_module"
|
||||
|
||||
|
||||
async def _collect_stats_in_period(start_time: datetime, end_time: datetime) -> dict[str, Any]:
    """Aggregate LLM usage statistics for records in [start_time, end_time).

    Queries LLMUsage rows in the period and accumulates request counts,
    input/output/total token counts and estimated cost, each broken down by
    request type, user, model and module (the module is the first dotted
    segment of the request type).

    Returns an empty dict when there are no records in the period.
    """
    records = await db_get(
        model_class=LLMUsage,
        filters={"timestamp": {"$gte": start_time, "$lt": end_time}},
    )
    if not records:
        return {}

    # Map each model's real identifier to its configured display name.
    model_identifier_to_name_map = {model.model_identifier: model.name for model in model_config.models}

    stats: dict[str, Any] = {
        TOTAL_REQ_CNT: 0,
        TOTAL_COST: 0.0,
        REQ_CNT_BY_TYPE: defaultdict(int),
        REQ_CNT_BY_USER: defaultdict(int),
        REQ_CNT_BY_MODEL: defaultdict(int),
        REQ_CNT_BY_MODULE: defaultdict(int),
        IN_TOK_BY_TYPE: defaultdict(int),
        IN_TOK_BY_USER: defaultdict(int),
        IN_TOK_BY_MODEL: defaultdict(int),
        IN_TOK_BY_MODULE: defaultdict(int),
        OUT_TOK_BY_TYPE: defaultdict(int),
        OUT_TOK_BY_USER: defaultdict(int),
        OUT_TOK_BY_MODEL: defaultdict(int),
        OUT_TOK_BY_MODULE: defaultdict(int),
        TOTAL_TOK_BY_TYPE: defaultdict(int),
        TOTAL_TOK_BY_USER: defaultdict(int),
        TOTAL_TOK_BY_MODEL: defaultdict(int),
        TOTAL_TOK_BY_MODULE: defaultdict(int),
        COST_BY_TYPE: defaultdict(float),
        COST_BY_USER: defaultdict(float),
        COST_BY_MODEL: defaultdict(float),
        COST_BY_MODULE: defaultdict(float),
    }

    for record in records:
        # Defensive: skip anything that is not a plain row dict.
        if not isinstance(record, dict):
            continue

        stats[TOTAL_REQ_CNT] += 1

        request_type = record.get("request_type") or "unknown"
        user_id = record.get("user_id") or "unknown"
        # The database stores the real model name (model_identifier).
        real_model_name = record.get("model_name") or "unknown"
        module_name = request_type.split(".")[0] if "." in request_type else request_type

        # Resolve the configured model name from the real identifier;
        # fall back to the identifier itself when unmapped.
        config_model_name = model_identifier_to_name_map.get(real_model_name, real_model_name)

        prompt_tokens = record.get("prompt_tokens") or 0
        completion_tokens = record.get("completion_tokens") or 0
        total_tokens = prompt_tokens + completion_tokens

        cost = 0.0
        try:
            # Look up pricing by the configured model name; prices appear to
            # be per million tokens (tokens / 1_000_000 * price).
            model_info = model_config.get_model_info(config_model_name)
            if model_info:
                input_cost = (prompt_tokens / 1000000) * model_info.price_in
                output_cost = (completion_tokens / 1000000) * model_info.price_out
                cost = round(input_cost + output_cost, 6)
        except KeyError as e:
            # Unknown model in config: keep cost at 0.0 and log the miss.
            logger.info(str(e))
            logger.warning(f"模型 '{config_model_name}' (真实名称: '{real_model_name}') 在配置中未找到,成本计算将使用默认值 0.0")

        stats[TOTAL_COST] += cost

        # Breakdown by request type.
        stats[REQ_CNT_BY_TYPE][request_type] += 1
        stats[IN_TOK_BY_TYPE][request_type] += prompt_tokens
        stats[OUT_TOK_BY_TYPE][request_type] += completion_tokens
        stats[TOTAL_TOK_BY_TYPE][request_type] += total_tokens
        stats[COST_BY_TYPE][request_type] += cost

        # Breakdown by user.
        stats[REQ_CNT_BY_USER][user_id] += 1
        stats[IN_TOK_BY_USER][user_id] += prompt_tokens
        stats[OUT_TOK_BY_USER][user_id] += completion_tokens
        stats[TOTAL_TOK_BY_USER][user_id] += total_tokens
        stats[COST_BY_USER][user_id] += cost

        # Breakdown by model (keyed by the configured name).
        stats[REQ_CNT_BY_MODEL][config_model_name] += 1
        stats[IN_TOK_BY_MODEL][config_model_name] += prompt_tokens
        stats[OUT_TOK_BY_MODEL][config_model_name] += completion_tokens
        stats[TOTAL_TOK_BY_MODEL][config_model_name] += total_tokens
        stats[COST_BY_MODEL][config_model_name] += cost

        # Breakdown by module.
        stats[REQ_CNT_BY_MODULE][module_name] += 1
        stats[IN_TOK_BY_MODULE][module_name] += prompt_tokens
        stats[OUT_TOK_BY_MODULE][module_name] += completion_tokens
        stats[TOTAL_TOK_BY_MODULE][module_name] += total_tokens
        stats[COST_BY_MODULE][module_name] += cost

    return stats
|
||||
|
||||
|
||||
@router.get("/llm/stats")
|
||||
async def get_llm_stats(
|
||||
period_type: Literal[
|
||||
@@ -179,7 +77,8 @@ async def get_llm_stats(
|
||||
if start_time is None:
|
||||
raise HTTPException(status_code=400, detail="无法确定查询的起始时间")
|
||||
|
||||
period_stats = await _collect_stats_in_period(start_time, end_time)
|
||||
stats_data = await StatisticOutputTask._collect_model_request_for_period([("custom", start_time)])
|
||||
period_stats = stats_data.get("custom", {})
|
||||
|
||||
if not period_stats:
|
||||
return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}}
|
||||
|
||||
1175
src/api/templates/visualizer.html
Normal file
1175
src/api/templates/visualizer.html
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user