style: unify code style and modernize type annotations

A comprehensive code-style cleanup and modernization pass over the entire codebase, mainly covering:

- Removed trailing whitespace across all files.
- Updated type hints to the modern syntax introduced by PEP 585 and PEP 604 (e.g., `list` in place of `typing.List`, and `X | None` in place of `Optional[X]`).
- Cleaned up unused import statements in several modules.
- Removed redundant f-strings that contain no interpolated variables.
- Reordered the `__all__` exports in some `__init__.py` files for consistency.

These changes are intended to improve readability and maintainability and to bring the code in line with modern Python best practices; no core logic was modified.
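To make the bullet points above concrete, here is a minimal before/after sketch of the three mechanical rewrites. The names (`load`, `Widget`, `__meta__`) are hypothetical, chosen only for illustration:

```python
from typing import Dict, List, Optional  # only the "before" style needs these

class Widget: ...  # a stand-in export so __all__ resolves

# Before: typing generics, Optional, an f-string with no placeholders, unsorted __all__
__all__ = ["__meta__", "Widget"]

def load(path: Optional[str] = None) -> Dict[str, List[int]]:
    print(f"loading...")  # no interpolation, so the f-prefix is redundant (Ruff F541)
    return {}

# After: builtin generics (PEP 585), `|` unions (PEP 604), a plain string,
# and __all__ sorted in ASCII order (uppercase names sort before dunders)
__all__ = ["Widget", "__meta__"]

def load(path: str | None = None) -> dict[str, list[int]]:
    print("loading...")
    return {}
```

One caveat worth noting: the bare `X | None` form in annotations requires Python 3.10+ (or `from __future__ import annotations` on 3.9), so this style of cleanup implicitly assumes the project targets those versions.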
minecraft1024a
2025-11-12 12:49:40 +08:00
parent daf8ea7e6a
commit 0e1e9935b2
33 changed files with 227 additions and 229 deletions

View File

@@ -19,14 +19,13 @@
 import asyncio
 import sys
 from pathlib import Path
-from typing import List
 # 添加项目根目录到路径
 sys.path.insert(0, str(Path(__file__).parent.parent))
 async def generate_missing_embeddings(
-    target_node_types: List[str] = None,
+    target_node_types: list[str] = None,
     batch_size: int = 50,
 ):
     """
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
         target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
     print(f"\n{'='*80}")
-    print(f"🔧 为节点生成嵌入向量")
+    print("🔧 为节点生成嵌入向量")
     print(f"{'='*80}\n")
     print(f"目标节点类型: {', '.join(target_node_types)}")
     print(f"批处理大小: {batch_size}\n")
     # 1. 初始化记忆管理器
-    print(f"🔧 正在初始化记忆管理器...")
+    print("🔧 正在初始化记忆管理器...")
     await initialize_memory_manager()
     manager = get_memory_manager()
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
         print("❌ 记忆管理器初始化失败")
         return
-    print(f"✅ 记忆管理器已初始化\n")
+    print("✅ 记忆管理器已初始化\n")
     # 2. 获取已索引的节点ID
-    print(f"🔍 检查现有向量索引...")
+    print("🔍 检查现有向量索引...")
     existing_node_ids = set()
     try:
         vector_count = manager.vector_store.collection.count()
@@ -82,10 +81,10 @@ async def generate_missing_embeddings(
         print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
     except Exception as e:
         logger.warning(f"获取已索引节点ID失败: {e}")
-        print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
+        print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
     # 3. 收集需要生成嵌入的节点
-    print(f"🔍 扫描需要生成嵌入的节点...")
+    print("🔍 扫描需要生成嵌入的节点...")
     all_memories = manager.graph_store.get_all_memories()
     nodes_to_process = []
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
            })
            type_stats[node.node_type.value]["need_emb"] += 1
-    print(f"\n📊 扫描结果:")
+    print("\n📊 扫描结果:")
     for node_type in target_node_types:
         stats = type_stats[node_type]
         already_ok = stats["already_indexed"]
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
     print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
     if len(nodes_to_process) == 0:
-        print(f"✅ 所有节点已有嵌入向量,无需生成")
+        print("✅ 所有节点已有嵌入向量,无需生成")
         return
     # 3. 批量生成嵌入
-    print(f"🚀 开始生成嵌入向量...\n")
+    print("🚀 开始生成嵌入向量...\n")
     total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
     success_count = 0
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
        print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
     # 4. 保存图数据(更新节点的 embedding 字段)
-    print(f"💾 保存图数据...")
+    print("💾 保存图数据...")
     try:
         await manager.persistence.save_graph_store(manager.graph_store)
-        print(f"✅ 图数据已保存\n")
+        print("✅ 图数据已保存\n")
     except Exception as e:
-        logger.error(f"保存图数据失败", exc_info=True)
+        logger.error("保存图数据失败", exc_info=True)
         print(f"❌ 保存失败: {e}\n")
     # 5. 验证结果
-    print(f"🔍 验证向量索引...")
+    print("🔍 验证向量索引...")
     final_vector_count = manager.vector_store.collection.count()
     stats = manager.graph_store.get_statistics()
     total_nodes = stats["total_nodes"]
     print(f"\n{'='*80}")
-    print(f"📊 生成完成")
+    print("📊 生成完成")
     print(f"{'='*80}")
     print(f"处理节点数: {len(nodes_to_process)}")
     print(f"成功生成: {success_count}")
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
        print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
     # 6. 测试搜索
-    print(f"🧪 测试搜索功能...")
+    print("🧪 测试搜索功能...")
     test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
     for query in test_queries:

View File

@@ -4,13 +4,13 @@
 提供 Web API 用于可视化记忆图数据
 """
+from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional
-from collections import defaultdict
+from typing import Any
 import orjson
-from fastapi import APIRouter, HTTPException, Request, Query
+from fastapi import APIRouter, HTTPException, Query, Request
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.templating import Jinja2Templates
@@ -29,7 +29,7 @@ router = APIRouter()
 templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
     """查找所有可用的记忆图数据文件"""
     files = []
     if not data_dir.exists():
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
     return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
-def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
+def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
     """从磁盘加载图数据"""
     global graph_data_cache, current_data_file
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
     if not graph_file.exists():
         return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
-    with open(graph_file, "r", encoding="utf-8") as f:
+    with open(graph_file, encoding="utf-8") as f:
         data = orjson.loads(f.read())
     nodes = data.get("nodes", [])
@@ -150,7 +150,7 @@ async def index(request: Request):
     return templates.TemplateResponse("visualizer.html", {"request": request})
-def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
+def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
     """从 MemoryManager 提取并格式化图数据"""
     if not memory_manager.graph_store:
         return {"nodes": [], "edges": [], "memories": [], "stats": {}}
@@ -261,7 +261,7 @@ async def get_paginated_graph(
     page: int = Query(1, ge=1, description="页码"),
     page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
     min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
-    node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
+    node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
 ):
     """分页获取图数据,支持重要性过滤"""
     try:
@@ -383,7 +383,7 @@ async def get_clustered_graph(
         return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
-def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
+def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
     """简单的图聚类算法:按类型和连接度聚类"""
     # 构建邻接表
     adjacency = defaultdict(set)

View File

@@ -1,6 +1,5 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Any, Literal
+from typing import Literal
 from fastapi import APIRouter, HTTPException, Query

View File

@@ -481,7 +481,7 @@ class MessageManager:
        try:
            chat_manager = get_chat_manager()
            chat_stream = await chat_manager.get_stream(stream_id)
-            if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
+            if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
                chat_stream.context_manager.context.is_chatter_processing = is_processing
                logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
        except Exception as e:
@@ -517,7 +517,7 @@ class MessageManager:
        try:
            chat_manager = get_chat_manager()
            chat_stream = await chat_manager.get_stream(stream_id)
-            if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
+            if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
                return chat_stream.context_manager.context.is_chatter_processing
        except Exception:
            pass

View File

@@ -1177,10 +1177,10 @@ class DefaultReplyer:
        if unread_messages:
            # 使用最后一条未读消息作为参考
            last_msg = unread_messages[-1]
-            platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
-            user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
-            user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
-            user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
+            platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
+            user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
+            user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
+            user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
            processed_plain_text = last_msg.processed_plain_text or ""
        else:
            # 没有未读消息,使用默认值

View File

@@ -5,12 +5,12 @@
 插件可以通过实现这些接口来扩展安全功能。
 """
-from .interfaces import SecurityCheckResult, SecurityChecker
+from .interfaces import SecurityChecker, SecurityCheckResult
 from .manager import SecurityManager, get_security_manager
 __all__ = [
-    "SecurityChecker",
     "SecurityCheckResult",
+    "SecurityChecker",
     "SecurityManager",
     "get_security_manager",
 ]

View File

@@ -10,7 +10,7 @@ from typing import Any
 from src.common.logger import get_logger
-from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
+from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
 logger = get_logger("security.manager")

View File

@@ -9,11 +9,12 @@
 """
 import asyncio
+import builtins
 import time
 from collections import OrderedDict
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
+from typing import Any, Generic, TypeVar
 from src.common.logger import get_logger
 from src.common.memory_utils import estimate_size_smart
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
        self._lock = asyncio.Lock()
        self._stats = CacheStats()
-    async def get(self, key: str) -> Optional[T]:
+    async def get(self, key: str) -> T | None:
        """获取缓存值
        Args:
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
        self,
        key: str,
        value: T,
-        size: Optional[int] = None,
-        ttl: Optional[float] = None,
+        size: int | None = None,
+        ttl: float | None = None,
    ) -> None:
        """设置缓存值
@@ -287,8 +288,8 @@ class MultiLevelCache:
    async def get(
        self,
        key: str,
-        loader: Optional[Callable[[], Any]] = None,
-    ) -> Optional[Any]:
+        loader: Callable[[], Any] | None = None,
+    ) -> Any | None:
        """从缓存获取数据
        查询顺序L1 -> L2 -> loader
@@ -329,8 +330,8 @@ class MultiLevelCache:
        self,
        key: str,
        value: Any,
-        size: Optional[int] = None,
-        ttl: Optional[float] = None,
+        size: int | None = None,
+        ttl: float | None = None,
    ) -> None:
        """设置缓存值
@@ -390,7 +391,7 @@ class MultiLevelCache:
        await self.l2_cache.clear()
        logger.info("所有缓存已清空")
-    async def get_stats(self) -> Dict[str, Any]:
+    async def get_stats(self) -> dict[str, Any]:
        """获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
        # 🔧 修复:并行获取统计信息,避免锁嵌套
        l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
@@ -492,7 +493,7 @@ class MultiLevelCache:
            logger.error(f"{cache_name}统计获取异常: {e}")
            return CacheStats()
-    async def _get_cache_keys_safe(self, cache) -> Set[str]:
+    async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
        """安全获取缓存键集合(带超时)"""
        try:
            # 快速获取键集合,使用超时避免死锁
@@ -507,12 +508,12 @@ class MultiLevelCache:
            logger.error(f"缓存键获取异常: {e}")
            return set()
-    async def _extract_keys_with_lock(self, cache) -> Set[str]:
+    async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
        """在锁保护下提取键集合"""
        async with cache._lock:
            return set(cache._cache.keys())
-    async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
+    async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
        """安全计算内存使用(带超时)"""
        if not keys:
            return 0
@@ -529,7 +530,7 @@ class MultiLevelCache:
            logger.error(f"内存计算异常: {e}")
            return 0
-    async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
+    async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
        """在锁保护下计算内存使用"""
        total_size = 0
        async with cache._lock:
@@ -749,7 +750,7 @@
 # 全局缓存实例
-_global_cache: Optional[MultiLevelCache] = None
+_global_cache: MultiLevelCache | None = None
 _cache_lock = asyncio.Lock()

View File

@@ -3,7 +3,6 @@ import socket
 from fastapi import APIRouter, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
-from rich.traceback import install
 from uvicorn import Config
 from uvicorn import Server as UvicornServer

View File

@@ -1095,7 +1095,7 @@ class MemoryManager:
        # 2. 清理孤立边(指向已删除节点的边)
        edges_to_remove = []
-        for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
+        for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
            # 检查边的源节点和目标节点是否还存在于node_to_memories中
            if source not in self.graph_store.node_to_memories or \
               target not in self.graph_store.node_to_memories:
@@ -2298,7 +2298,7 @@ class MemoryManager:
                # 使用 asyncio.wait_for 来支持取消
                await asyncio.wait_for(
                    asyncio.sleep(initial_delay),
-                    timeout=float('inf')  # 允许随时取消
+                    timeout=float("inf")  # 允许随时取消
                )
                # 检查是否仍然需要运行

View File

@@ -554,8 +554,8 @@ class MemoryTools:
            )
            if len(initial_memory_ids) == 0:
                logger.warning(
-                    f"⚠️ 向量搜索未找到任何记忆!"
-                    f"可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
+                    "⚠️ 向量搜索未找到任何记忆!"
+                    "可能原因1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
                )
                # 输出相似节点的详细信息用于调试
                if similar_nodes:
@@ -659,12 +659,12 @@ class MemoryTools:
            activation_score = memory.activation
            # 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
-            memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
+            memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
            # 检测记忆的主要节点类型
            node_types_count = {}
            for node in memory.nodes:
-                nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
+                nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
                node_types_count[nt] = node_types_count.get(nt, 0) + 1
            dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
@@ -1000,6 +1000,7 @@ class MemoryTools:
            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
            import re
+            import orjson
            # 清理Markdown代码块

View File

@@ -97,7 +97,7 @@ async def expand_memories_with_semantic_filter(
            source_node_id = edge.source_id
            # 🆕 根据边类型设置权重优先扩展REFERENCE、ATTRIBUTE相关的边
-            edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
+            edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
            if edge_type_str == "REFERENCE":
                edge_weight = 1.3  # REFERENCE边权重最高引用关系
            elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:

View File

@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
 # 导入插件主类
 from .plugin import AntiInjectionPlugin
-__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
+__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]

View File

@@ -8,8 +8,8 @@ import time
 from src.chat.security.interfaces import (
     SecurityAction,
-    SecurityCheckResult,
     SecurityChecker,
+    SecurityCheckResult,
     SecurityLevel,
 )
 from src.common.logger import get_logger

View File

@@ -4,7 +4,7 @@
 处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
 """
-from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
+from src.chat.security.interfaces import SecurityCheckResult
 from src.common.logger import get_logger
 from .counter_attack import CounterAttackGenerator

View File

@@ -6,10 +6,10 @@
 import asyncio
 import base64
 import datetime
-import filetype
 from collections.abc import Callable
 import aiohttp
+import filetype
 from maim_message import UserInfo
 from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -17,7 +17,6 @@ import uuid
 import weakref
 from collections import defaultdict
 from collections.abc import Awaitable, Callable
-from contextlib import suppress
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum