style: unify code style and adopt modern type annotations

This commit performs a comprehensive code-style cleanup and modernization pass across the entire codebase, mainly including:

- Removed trailing whitespace throughout the files.
- Updated type hints to the modern syntax introduced by PEP 585 and PEP 604 (e.g., `list` instead of `List`, `X | None` instead of `Optional[X]`); see the sketch after this list.
- Cleaned up unused import statements in several modules.
- Removed redundant f-strings that contain no interpolated variables.
- Adjusted the `__all__` export order in some `__init__.py` files for consistency.
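
A minimal before-and-after sketch of these style rules, using hypothetical names rather than code from this commit (PEP 585 builtin generics need Python 3.9+; PEP 604 `X | Y` unions need Python 3.10+):

```python
# Before: typing-module generics, Optional, and a redundant f-string
from typing import Dict, List, Optional

def load_records(path: Optional[str] = None) -> List[Dict[str, int]]:
    label = f"records"  # f-string without interpolation
    return []

# After: builtin generics (PEP 585) and a union type (PEP 604); the
# typing import disappears and the f-string becomes a plain literal
def load_records(path: str | None = None) -> list[dict[str, int]]:
    label = "records"
    return []
```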

These changes are intended to improve readability and maintainability and to align the code with modern Python best practices; no core logic has been modified.
Author: minecraft1024a
Committed: 2025-11-12 12:49:40 +08:00
Parent: daf8ea7e6a
Commit: 0e1e9935b2
33 changed files with 227 additions and 229 deletions

View File

@@ -4,13 +4,13 @@
 Provides a Web API for visualizing memory-graph data
 """
-from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from collections import defaultdict
+from typing import Any

 import orjson
-from fastapi import APIRouter, HTTPException, Request, Query
+from fastapi import APIRouter, HTTPException, Query, Request
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.templating import Jinja2Templates
@@ -29,7 +29,7 @@ router = APIRouter()
 templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))


-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
     """Find all available memory-graph data files"""
     files = []
     if not data_dir.exists():
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
     return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)


-def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
+def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
     """Load graph data from disk"""
     global graph_data_cache, current_data_file
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
     if not graph_file.exists():
         return {"error": f"File not found: {graph_file}", "nodes": [], "edges": [], "stats": {}}

-    with open(graph_file, "r", encoding="utf-8") as f:
+    with open(graph_file, encoding="utf-8") as f:
         data = orjson.loads(f.read())

     nodes = data.get("nodes", [])
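
The `open()` change in this hunk works because `"r"` (text mode) is already the default mode; a minimal standalone equivalent, with a hypothetical file name (`orjson.loads` accepts either `str` or `bytes`):

```python
import orjson

# These two calls are equivalent: "r" is open()'s default mode.
with open("graph.json", encoding="utf-8") as f:  # was: open("graph.json", "r", encoding="utf-8")
    data = orjson.loads(f.read())                # orjson.loads accepts str or bytes
```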
@@ -150,7 +150,7 @@ async def index(request: Request):
     return templates.TemplateResponse("visualizer.html", {"request": request})


-def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
+def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
     """Extract and format graph data from the MemoryManager"""
     if not memory_manager.graph_store:
         return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -188,7 +188,7 @@ def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
             "arrows": "to",
             "memory_id": memory.id,
         }
-
+
     edges_list = list(edges_dict.values())

     stats = memory_manager.get_statistics()
@@ -261,7 +261,7 @@ async def get_paginated_graph(
     page: int = Query(1, ge=1, description="page number"),
     page_size: int = Query(500, ge=100, le=2000, description="nodes per page"),
     min_importance: float = Query(0.0, ge=0.0, le=1.0, description="minimum importance threshold"),
-    node_types: Optional[str] = Query(None, description="node-type filter, comma-separated"),
+    node_types: str | None = Query(None, description="node-type filter, comma-separated"),
 ):
     """Fetch graph data in pages, with importance filtering"""
     try:
@@ -301,13 +301,13 @@ async def get_paginated_graph(
         total_pages = (total_nodes + page_size - 1) // page_size
         start_idx = (page - 1) * page_size
         end_idx = min(start_idx + page_size, total_nodes)

         paginated_nodes = nodes_with_importance[start_idx:end_idx]
         node_ids = set(n["id"] for n in paginated_nodes)

         # Keep only edges that connect paginated nodes
         paginated_edges = [
-            e for e in edges
+            e for e in edges
             if e.get("from") in node_ids and e.get("to") in node_ids
         ]
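
The `total_pages` expression in this hunk is the usual integer ceiling division; a quick standalone check with illustrative numbers:

```python
total_nodes, page_size, page = 1042, 500, 2

total_pages = (total_nodes + page_size - 1) // page_size  # ceil(1042 / 500) == 3
start_idx = (page - 1) * page_size                        # page 2 starts at index 500
end_idx = min(start_idx + page_size, total_nodes)         # clamped to the node count

assert (total_pages, start_idx, end_idx) == (3, 500, 1000)
```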
@@ -383,7 +383,7 @@ async def get_clustered_graph(
         return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)


-def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
+def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
     """Simple graph clustering: cluster by type and connectivity"""
     # Build the adjacency table
     adjacency = defaultdict(set)
@@ -412,21 +412,21 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
         for node in type_nodes:
             importance = len(adjacency[node["id"]])
             node_importance.append((node, importance))
-
+
         node_importance.sort(key=lambda x: x[1], reverse=True)
-
+
         # Keep the top-N important nodes
         keep_count = min(len(type_nodes), max_nodes // len(type_groups))
         for node, importance in node_importance[:keep_count]:
             clustered_nodes.append(node)
             node_mapping[node["id"]] = node["id"]
-
+
         # Aggregate the remaining nodes into one super-node
         if len(node_importance) > keep_count:
            clustered_node_ids = [n["id"] for n, _ in node_importance[keep_count:]]
            cluster_id = f"cluster_{node_type}_{len(clustered_nodes)}"
            cluster_label = f"{node_type} cluster ({len(clustered_node_ids)} nodes)"
-
+
            clustered_nodes.append({
                "id": cluster_id,
                "label": cluster_label,
@@ -436,7 +436,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
                "cluster_size": len(clustered_node_ids),
                "clustered_nodes": clustered_node_ids[:10],  # keep only the first 10 for display
            })
-
+
            for node_id in clustered_node_ids:
                node_mapping[node_id] = cluster_id
@@ -445,7 +445,7 @@ def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cl
     for edge in edges:
         from_id = node_mapping.get(edge["from"])
         to_id = node_mapping.get(edge["to"])
-
+
         if from_id and to_id and from_id != to_id:
             edge_key = tuple(sorted([from_id, to_id]))
             if edge_key not in edge_set:
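
For readers skimming the `_cluster_graph_data` hunks above: per node type, the function keeps the best-connected nodes and collapses the rest into a single super-node, then re-maps edges onto the surviving ids. A self-contained sketch of that idea, with simplified names and fields that are assumptions rather than the module's actual API:

```python
from collections import defaultdict

def cluster_by_type(nodes: list[dict], edges: list[dict], max_nodes: int) -> tuple[list[dict], list[dict]]:
    """Keep the highest-degree nodes per type; collapse the rest into one super-node per type."""
    adjacency = defaultdict(set)
    for e in edges:
        adjacency[e["from"]].add(e["to"])
        adjacency[e["to"]].add(e["from"])

    by_type = defaultdict(list)
    for n in nodes:
        by_type[n.get("type", "unknown")].append(n)

    kept, mapping = [], {}
    budget = max(1, max_nodes // max(1, len(by_type)))  # node budget per type
    for node_type, group in by_type.items():
        group.sort(key=lambda n: len(adjacency[n["id"]]), reverse=True)
        for n in group[:budget]:      # keep the most-connected nodes as-is
            kept.append(n)
            mapping[n["id"]] = n["id"]
        rest = group[budget:]
        if rest:                      # collapse the remainder into a super-node
            cluster_id = f"cluster_{node_type}"
            kept.append({"id": cluster_id, "type": node_type, "cluster_size": len(rest)})
            for n in rest:
                mapping[n["id"]] = cluster_id

    # Re-map edges onto surviving ids, dropping self-loops and duplicates
    seen, new_edges = set(), []
    for e in edges:
        f, t = mapping.get(e["from"]), mapping.get(e["to"])
        if f and t and f != t and (key := tuple(sorted((f, t)))) not in seen:
            seen.add(key)
            new_edges.append({"from": f, "to": t})
    return kept, new_edges

nodes = [{"id": f"n{i}", "type": "topic" if i % 2 else "memory"} for i in range(6)]
edges = [{"from": "n0", "to": "n1"}, {"from": "n1", "to": "n2"}, {"from": "n2", "to": "n3"}]
kept, remapped = cluster_by_type(nodes, edges, max_nodes=4)
print(len(kept), len(remapped))  # 6 3: four kept nodes, two super-nodes, three edges
```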

View File

@@ -1,6 +1,5 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Any, Literal
+from typing import Literal

 from fastapi import APIRouter, HTTPException, Query