refactor(server): 将记忆可视化工具和统计功能整合进主服务

将原先独立的记忆可视化工具(Memory Visualizer)和LLM使用统计逻辑深度整合到项目主服务中。

主要变更包括:
- **移除独立的可视化工具**: 删除了 `tools/memory_visualizer` 目录下的所有独立服务器、脚本和文档,清理了项目结构。
- **API路由整合**: 在主 FastAPI 应用中注册了记忆可视化工具的路由,使其成为核心功能的一部分,可通过 `/visualizer` 访问。
- **统计逻辑重构**: 将LLM使用统计的计算逻辑从API路由层 `statistic_router.py` 中剥离,迁移到 `src/chat/utils/statistic.py` 中,实现了逻辑的解耦和复用。API路由现在直接调用重构后的统计任务。
- **依赖清理与添加**: 添加了 `jinja2` 作为模板渲染的依赖,并清除了与独立可视化工具相关的旧依赖。

此次重构简化了项目的维护和部署,将原本分散的功能统一管理,提升了代码的内聚性和可维护性。
This commit is contained in:
minecraft1024a
2025-11-07 21:12:11 +08:00
parent 33897bec53
commit 077628930b
23 changed files with 376 additions and 2543 deletions

View File

@@ -0,0 +1,361 @@
"""
记忆图可视化 - API 路由模块
提供 Web API 用于可视化记忆图数据
"""
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import orjson
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
# Project root resolved three levels up from this file — assumes this module
# lives at <root>/<pkg>/<subpkg>/<file>.py; TODO confirm the nesting depth.
project_root = Path(__file__).parent.parent.parent
# Directory holding persisted memory-graph JSON snapshots.
data_dir = project_root / "data" / "memory_graph"
# Module-level cache of the last loaded graph payload (None until first load).
graph_data_cache = None
# Path of the data file the cache was built from (None until first load).
current_data_file = None
# FastAPI router registered by the main application (mounted under /visualizer).
router = APIRouter()
# Jinja2 template engine rooted at the ./templates directory next to this module.
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
def find_available_data_files() -> List[Path]:
    """Locate every memory-graph data file that can be loaded.

    Searches, in order: well-known filenames directly under ``data_dir``,
    timestamped variants of those names, the ``data_dir/backups`` tree, and
    the sibling ``backup`` directory next to ``data_dir``.

    Returns:
        De-duplicated candidate paths, newest first (by mtime).
    """
    files: List[Path] = []
    seen: set = set()

    def _add(path: Path) -> None:
        # Bug-free but O(n^2) in the original (`path not in files` scanned a
        # list); a seen-set keeps first-seen order with O(1) membership.
        if path not in seen:
            seen.add(path)
            files.append(path)

    if not data_dir.exists():
        return files
    for filename in ("graph_store.json", "memory_graph.json", "graph_data.json"):
        file_path = data_dir / filename
        if file_path.exists():
            _add(file_path)
    for pattern in ("graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"):
        for backup_file in data_dir.glob(pattern):
            _add(backup_file)
    backups_dir = data_dir / "backups"
    if backups_dir.exists():
        for backup_file in backups_dir.glob("**/*.json"):
            _add(backup_file)
    backup_dir = data_dir.parent / "backup"
    if backup_dir.exists():
        for pattern in ("**/graph_*.json", "**/memory_*.json"):
            for backup_file in backup_dir.glob(pattern):
                _add(backup_file)
    return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
    """Load graph data from disk, with module-level caching.

    Args:
        file_path: Explicit file to load; ``None`` reuses the current/cached
            file, or auto-discovers the newest available one.

    Returns:
        A vis.js-ready payload with ``nodes``/``edges``/``stats`` keys, or an
        error payload carrying the same keys when no usable file is found.

    Raises:
        HTTPException: 500 when reading or parsing the file fails.
    """
    global graph_data_cache, current_data_file
    # Switching to a different file invalidates the cache.
    if file_path and file_path != current_data_file:
        graph_data_cache = None
        current_data_file = file_path
    if graph_data_cache:
        return graph_data_cache
    try:
        graph_file = current_data_file
        if not graph_file:
            # No file selected yet: pick the most recently modified candidate.
            available_files = find_available_data_files()
            if not available_files:
                return {"error": "未找到数据文件", "nodes": [], "edges": [], "stats": {}}
            graph_file = available_files[0]
            current_data_file = graph_file
        if not graph_file.exists():
            return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
        with open(graph_file, "r", encoding="utf-8") as f:
            data = orjson.loads(f.read())
        nodes = data.get("nodes", [])
        edges = data.get("edges", [])
        metadata = data.get("metadata", {})
        # Re-shape nodes for the front-end: keep original fields and add
        # label/group/title. Assumes each node dict carries "id", "content"
        # and "node_type" keys — TODO confirm against the on-disk schema.
        nodes_dict = {
            node["id"]: {
                **node,
                "label": node.get("content", ""),
                "group": node.get("node_type", ""),
                "title": f"{node.get('node_type', '')}: {node.get('content', '')}",
            }
            for node in nodes
            if node.get("id")
        }
        # Edges may use either "source"/"target" or "source_id"/"target_id".
        edges_list = [
            {
                **edge,
                "from": edge.get("source", edge.get("source_id")),
                "to": edge.get("target", edge.get("target_id")),
                "label": edge.get("relation", ""),
                "arrows": "to",
            }
            for edge in edges
        ]
        stats = metadata.get("statistics", {})
        total_memories = stats.get("total_memories", 0)
        graph_data_cache = {
            "nodes": list(nodes_dict.values()),
            "edges": edges_list,
            "memories": [],  # file snapshots carry no per-memory records
            "stats": {
                "total_nodes": len(nodes_dict),
                "total_edges": len(edges_list),
                "total_memories": total_memories,
            },
            "current_file": str(graph_file),
            "file_size": graph_file.stat().st_size,
            "file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
        }
        return graph_data_cache
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}")
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""主页面"""
return templates.TemplateResponse("visualizer.html", {"request": request})
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
    """Build a front-end-ready graph payload from a live MemoryManager.

    Returns nodes (de-duplicated by id), edges (tagged with the owning
    memory id), per-memory summaries and the manager's own statistics.
    """
    store = memory_manager.graph_store
    if not store:
        return {"nodes": [], "edges": [], "memories": [], "stats": {}}
    memories = store.get_all_memories()
    # Per-memory summaries for the sidebar/search UI.
    memory_info = [
        {
            "id": m.id,
            "type": m.memory_type.value,
            "importance": m.importance,
            "text": m.to_text(),
        }
        for m in memories
    ]
    # Nodes are de-duplicated on id, keeping the first occurrence.
    node_index: Dict[Any, Dict[str, Any]] = {}
    for m in memories:
        for node in m.nodes:
            node_index.setdefault(
                node.id,
                {
                    "id": node.id,
                    "label": node.content,
                    "type": node.node_type.value,
                    "group": node.node_type.name,
                    "title": f"{node.node_type.value}: {node.content}",
                },
            )
    # Every edge is emitted (no de-dup), annotated with its memory of origin.
    edge_payload = [
        {
            "id": e.id,
            "from": e.source_id,
            "to": e.target_id,
            "label": e.relation,
            "arrows": "to",
            "memory_id": m.id,
        }
        for m in memories
        for e in m.edges
    ]
    manager_stats = memory_manager.get_statistics()
    return {
        "nodes": list(node_index.values()),
        "edges": edge_payload,
        "memories": memory_info,
        "stats": {
            "total_nodes": manager_stats.get("total_nodes", 0),
            "total_edges": manager_stats.get("total_edges", 0),
            "total_memories": manager_stats.get("total_memories", 0),
        },
        "current_file": "memory_manager (实时数据)",
    }
@router.get("/api/graph/full")
async def get_full_graph():
"""获取完整记忆图数据"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
data = {}
if memory_manager and memory_manager._initialized:
data = _format_graph_data_from_manager(memory_manager)
else:
# 如果内存管理器不可用,则从文件加载
data = load_graph_data_from_file()
return JSONResponse(content={"success": True, "data": data})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/files")
async def list_files_api():
"""列出所有可用的数据文件"""
try:
files = find_available_data_files()
file_list = []
for f in files:
stat = f.stat()
file_list.append(
{
"path": str(f),
"name": f.name,
"size": stat.st_size,
"size_kb": round(stat.st_size / 1024, 2),
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
"modified_readable": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
"is_current": str(f) == str(current_data_file) if current_data_file else False,
}
)
return JSONResponse(
content={
"success": True,
"files": file_list,
"count": len(file_list),
"current_file": str(current_data_file) if current_data_file else None,
}
)
except Exception as e:
# 增加日志记录
# logger.error(f"列出数据文件失败: {e}", exc_info=True)
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.post("/select_file")
async def select_file(request: Request):
"""选择要加载的数据文件"""
global graph_data_cache, current_data_file
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
raise HTTPException(status_code=400, detail="未提供文件路径")
file_to_load = Path(file_path)
if not file_to_load.exists():
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
graph_data_cache = None
current_data_file = file_to_load
graph_data = load_graph_data_from_file(file_to_load)
return JSONResponse(
content={
"success": True,
"message": f"已切换到文件: {file_to_load.name}",
"stats": graph_data.get("stats", {}),
}
)
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/reload")
async def reload_data():
"""重新加载数据"""
global graph_data_cache
graph_data_cache = None
data = load_graph_data_from_file()
return JSONResponse(content={"success": True, "message": "数据已重新加载", "stats": data.get("stats", {})})
@router.get("/api/search")
async def search_memories(q: str, limit: int = 50):
"""搜索记忆"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
results = []
if memory_manager and memory_manager._initialized and memory_manager.graph_store:
# 从 memory_manager 搜索
all_memories = memory_manager.graph_store.get_all_memories()
for memory in all_memories:
if q.lower() in memory.to_text().lower():
results.append(
{
"id": memory.id,
"type": memory.memory_type.value,
"importance": memory.importance,
"text": memory.to_text(),
}
)
else:
# 从文件加载的数据中搜索 (降级方案)
data = load_graph_data_from_file()
for memory in data.get("memories", []):
if q.lower() in memory.get("text", "").lower():
results.append(memory)
return JSONResponse(
content={
"success": True,
"data": {
"results": results[:limit],
"count": len(results),
},
}
)
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/stats")
async def get_statistics():
"""获取统计信息"""
try:
data = load_graph_data_from_file()
node_types = {}
memory_types = {}
for node in data["nodes"]:
node_type = node.get("type", "Unknown")
node_types[node_type] = node_types.get(node_type, 0) + 1
for memory in data.get("memories", []):
mem_type = memory.get("type", "Unknown")
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
stats = data.get("stats", {})
stats["node_types"] = node_types
stats["memory_types"] = memory_types
return JSONResponse(content={"success": True, "data": stats})
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)

View File

@@ -4,10 +4,10 @@ from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query
from src.common.database.compatibility import db_get
from src.common.database.core.models import LLMUsage
from src.chat.utils.statistic import (
StatisticOutputTask,
)
from src.common.logger import get_logger
from src.config.config import model_config
logger = get_logger("LLM统计API")
@@ -37,108 +37,6 @@ COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model"
COST_BY_MODULE = "costs_by_module"
async def _collect_stats_in_period(start_time: datetime, end_time: datetime) -> dict[str, Any]:
    """Collect LLM usage statistics within [start_time, end_time).

    Pulls raw ``LLMUsage`` rows from the database and accumulates request
    counts, token counts and estimated costs, each broken down by request
    type, user, model and module.

    Returns:
        The accumulated stats mapping, or ``{}`` when no records exist.
    """
    records = await db_get(
        model_class=LLMUsage,
        filters={"timestamp": {"$gte": start_time, "$lt": end_time}},
    )
    if not records:
        return {}
    # Map from model_identifier (the name stored in the DB) to the
    # human-readable name used in the config file.
    model_identifier_to_name_map = {model.model_identifier: model.name for model in model_config.models}
    stats: dict[str, Any] = {
        TOTAL_REQ_CNT: 0,
        TOTAL_COST: 0.0,
        REQ_CNT_BY_TYPE: defaultdict(int),
        REQ_CNT_BY_USER: defaultdict(int),
        REQ_CNT_BY_MODEL: defaultdict(int),
        REQ_CNT_BY_MODULE: defaultdict(int),
        IN_TOK_BY_TYPE: defaultdict(int),
        IN_TOK_BY_USER: defaultdict(int),
        IN_TOK_BY_MODEL: defaultdict(int),
        IN_TOK_BY_MODULE: defaultdict(int),
        OUT_TOK_BY_TYPE: defaultdict(int),
        OUT_TOK_BY_USER: defaultdict(int),
        OUT_TOK_BY_MODEL: defaultdict(int),
        OUT_TOK_BY_MODULE: defaultdict(int),
        TOTAL_TOK_BY_TYPE: defaultdict(int),
        TOTAL_TOK_BY_USER: defaultdict(int),
        TOTAL_TOK_BY_MODEL: defaultdict(int),
        TOTAL_TOK_BY_MODULE: defaultdict(int),
        COST_BY_TYPE: defaultdict(float),
        COST_BY_USER: defaultdict(float),
        COST_BY_MODEL: defaultdict(float),
        COST_BY_MODULE: defaultdict(float),
    }
    for record in records:
        # Silently skip rows that are not plain dicts.
        if not isinstance(record, dict):
            continue
        stats[TOTAL_REQ_CNT] += 1
        request_type = record.get("request_type") or "unknown"
        user_id = record.get("user_id") or "unknown"
        # The database stores the real model name (model_identifier).
        real_model_name = record.get("model_name") or "unknown"
        # Module is the prefix of a dotted request type, e.g. "chat.reply" -> "chat".
        module_name = request_type.split(".")[0] if "." in request_type else request_type
        # Try to resolve the real model name back to the config-file model name.
        config_model_name = model_identifier_to_name_map.get(real_model_name, real_model_name)
        prompt_tokens = record.get("prompt_tokens") or 0
        completion_tokens = record.get("completion_tokens") or 0
        total_tokens = prompt_tokens + completion_tokens
        cost = 0.0
        try:
            # Use the config-file model name to look up pricing info;
            # price_in/price_out are presumably per million tokens — TODO confirm.
            model_info = model_config.get_model_info(config_model_name)
            if model_info:
                input_cost = (prompt_tokens / 1000000) * model_info.price_in
                output_cost = (completion_tokens / 1000000) * model_info.price_out
                cost = round(input_cost + output_cost, 6)
        except KeyError as e:
            # Unknown model: keep cost at 0.0 and log the miss.
            logger.info(str(e))
            logger.warning(f"模型 '{config_model_name}' (真实名称: '{real_model_name}') 在配置中未找到,成本计算将使用默认值 0.0")
        stats[TOTAL_COST] += cost
        # Aggregate by request type.
        stats[REQ_CNT_BY_TYPE][request_type] += 1
        stats[IN_TOK_BY_TYPE][request_type] += prompt_tokens
        stats[OUT_TOK_BY_TYPE][request_type] += completion_tokens
        stats[TOTAL_TOK_BY_TYPE][request_type] += total_tokens
        stats[COST_BY_TYPE][request_type] += cost
        # Aggregate by user.
        stats[REQ_CNT_BY_USER][user_id] += 1
        stats[IN_TOK_BY_USER][user_id] += prompt_tokens
        stats[OUT_TOK_BY_USER][user_id] += completion_tokens
        stats[TOTAL_TOK_BY_USER][user_id] += total_tokens
        stats[COST_BY_USER][user_id] += cost
        # Aggregate by model (using the config-file name).
        stats[REQ_CNT_BY_MODEL][config_model_name] += 1
        stats[IN_TOK_BY_MODEL][config_model_name] += prompt_tokens
        stats[OUT_TOK_BY_MODEL][config_model_name] += completion_tokens
        stats[TOTAL_TOK_BY_MODEL][config_model_name] += total_tokens
        stats[COST_BY_MODEL][config_model_name] += cost
        # Aggregate by module.
        stats[REQ_CNT_BY_MODULE][module_name] += 1
        stats[IN_TOK_BY_MODULE][module_name] += prompt_tokens
        stats[OUT_TOK_BY_MODULE][module_name] += completion_tokens
        stats[TOTAL_TOK_BY_MODULE][module_name] += total_tokens
        stats[COST_BY_MODULE][module_name] += cost
    return stats
@router.get("/llm/stats")
async def get_llm_stats(
period_type: Literal[
@@ -179,7 +77,8 @@ async def get_llm_stats(
if start_time is None:
raise HTTPException(status_code=400, detail="无法确定查询的起始时间")
period_stats = await _collect_stats_in_period(start_time, end_time)
stats_data = await StatisticOutputTask._collect_model_request_for_period([("custom", start_time)])
period_stats = stats_data.get("custom", {})
if not period_stats:
return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}}

File diff suppressed because it is too large Load Diff