style: 统一代码风格并采用现代化类型注解
对整个代码库进行了一次全面的代码风格清理和现代化改造,主要包括: - 移除了所有文件中多余的行尾空格。 - 将类型提示更新为 PEP 585 和 PEP 604 引入的现代语法(例如,使用 `list` 代替 `List`,使用 `|` 代替 `Optional`)。 - 清理了多个模块中未被使用的导入语句。 - 移除了不含插值变量的冗余 f-string。 - 调整了部分 `__init__.py` 文件中的 `__all__` 导出顺序,以保持一致性。 这些改动旨在提升代码的可读性和可维护性,使其与现代 Python 最佳实践保持一致,但未修改任何核心逻辑。
This commit is contained in:
committed by
Windpicker-owo
parent
5fa004503c
commit
f44ece0b29
@@ -19,14 +19,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# 添加项目根目录到路径
|
# 添加项目根目录到路径
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
|
||||||
async def generate_missing_embeddings(
|
async def generate_missing_embeddings(
|
||||||
target_node_types: List[str] = None,
|
target_node_types: list[str] = None,
|
||||||
batch_size: int = 50,
|
batch_size: int = 50,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
|
|||||||
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
||||||
|
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'='*80}")
|
||||||
print(f"🔧 为节点生成嵌入向量")
|
print("🔧 为节点生成嵌入向量")
|
||||||
print(f"{'='*80}\n")
|
print(f"{'='*80}\n")
|
||||||
print(f"目标节点类型: {', '.join(target_node_types)}")
|
print(f"目标节点类型: {', '.join(target_node_types)}")
|
||||||
print(f"批处理大小: {batch_size}\n")
|
print(f"批处理大小: {batch_size}\n")
|
||||||
|
|
||||||
# 1. 初始化记忆管理器
|
# 1. 初始化记忆管理器
|
||||||
print(f"🔧 正在初始化记忆管理器...")
|
print("🔧 正在初始化记忆管理器...")
|
||||||
await initialize_memory_manager()
|
await initialize_memory_manager()
|
||||||
manager = get_memory_manager()
|
manager = get_memory_manager()
|
||||||
|
|
||||||
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
|
|||||||
print("❌ 记忆管理器初始化失败")
|
print("❌ 记忆管理器初始化失败")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"✅ 记忆管理器已初始化\n")
|
print("✅ 记忆管理器已初始化\n")
|
||||||
|
|
||||||
# 2. 获取已索引的节点ID
|
# 2. 获取已索引的节点ID
|
||||||
print(f"🔍 检查现有向量索引...")
|
print("🔍 检查现有向量索引...")
|
||||||
existing_node_ids = set()
|
existing_node_ids = set()
|
||||||
try:
|
try:
|
||||||
vector_count = manager.vector_store.collection.count()
|
vector_count = manager.vector_store.collection.count()
|
||||||
@@ -82,10 +81,10 @@ async def generate_missing_embeddings(
|
|||||||
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取已索引节点ID失败: {e}")
|
logger.warning(f"获取已索引节点ID失败: {e}")
|
||||||
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||||
|
|
||||||
# 3. 收集需要生成嵌入的节点
|
# 3. 收集需要生成嵌入的节点
|
||||||
print(f"🔍 扫描需要生成嵌入的节点...")
|
print("🔍 扫描需要生成嵌入的节点...")
|
||||||
all_memories = manager.graph_store.get_all_memories()
|
all_memories = manager.graph_store.get_all_memories()
|
||||||
|
|
||||||
nodes_to_process = []
|
nodes_to_process = []
|
||||||
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
|
|||||||
})
|
})
|
||||||
type_stats[node.node_type.value]["need_emb"] += 1
|
type_stats[node.node_type.value]["need_emb"] += 1
|
||||||
|
|
||||||
print(f"\n📊 扫描结果:")
|
print("\n📊 扫描结果:")
|
||||||
for node_type in target_node_types:
|
for node_type in target_node_types:
|
||||||
stats = type_stats[node_type]
|
stats = type_stats[node_type]
|
||||||
already_ok = stats["already_indexed"]
|
already_ok = stats["already_indexed"]
|
||||||
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
|
|||||||
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
||||||
|
|
||||||
if len(nodes_to_process) == 0:
|
if len(nodes_to_process) == 0:
|
||||||
print(f"✅ 所有节点已有嵌入向量,无需生成")
|
print("✅ 所有节点已有嵌入向量,无需生成")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 3. 批量生成嵌入
|
# 3. 批量生成嵌入
|
||||||
print(f"🚀 开始生成嵌入向量...\n")
|
print("🚀 开始生成嵌入向量...\n")
|
||||||
|
|
||||||
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
||||||
success_count = 0
|
success_count = 0
|
||||||
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
|
|||||||
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
||||||
|
|
||||||
# 4. 保存图数据(更新节点的 embedding 字段)
|
# 4. 保存图数据(更新节点的 embedding 字段)
|
||||||
print(f"💾 保存图数据...")
|
print("💾 保存图数据...")
|
||||||
try:
|
try:
|
||||||
await manager.persistence.save_graph_store(manager.graph_store)
|
await manager.persistence.save_graph_store(manager.graph_store)
|
||||||
print(f"✅ 图数据已保存\n")
|
print("✅ 图数据已保存\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存图数据失败", exc_info=True)
|
logger.error("保存图数据失败", exc_info=True)
|
||||||
print(f"❌ 保存失败: {e}\n")
|
print(f"❌ 保存失败: {e}\n")
|
||||||
|
|
||||||
# 5. 验证结果
|
# 5. 验证结果
|
||||||
print(f"🔍 验证向量索引...")
|
print("🔍 验证向量索引...")
|
||||||
final_vector_count = manager.vector_store.collection.count()
|
final_vector_count = manager.vector_store.collection.count()
|
||||||
stats = manager.graph_store.get_statistics()
|
stats = manager.graph_store.get_statistics()
|
||||||
total_nodes = stats["total_nodes"]
|
total_nodes = stats["total_nodes"]
|
||||||
|
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'='*80}")
|
||||||
print(f"📊 生成完成")
|
print("📊 生成完成")
|
||||||
print(f"{'='*80}")
|
print(f"{'='*80}")
|
||||||
print(f"处理节点数: {len(nodes_to_process)}")
|
print(f"处理节点数: {len(nodes_to_process)}")
|
||||||
print(f"成功生成: {success_count}")
|
print(f"成功生成: {success_count}")
|
||||||
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
|
|||||||
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
||||||
|
|
||||||
# 6. 测试搜索
|
# 6. 测试搜索
|
||||||
print(f"🧪 测试搜索功能...")
|
print("🧪 测试搜索功能...")
|
||||||
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
||||||
|
|
||||||
for query in test_queries:
|
for query in test_queries:
|
||||||
|
|||||||
@@ -4,13 +4,13 @@
|
|||||||
提供 Web API 用于可视化记忆图数据
|
提供 Web API 用于可视化记忆图数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from fastapi import APIRouter, HTTPException, Request, Query
|
from fastapi import APIRouter, HTTPException, Query, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ router = APIRouter()
|
|||||||
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
|
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
|
||||||
|
|
||||||
|
|
||||||
def find_available_data_files() -> List[Path]:
|
def find_available_data_files() -> list[Path]:
|
||||||
"""查找所有可用的记忆图数据文件"""
|
"""查找所有可用的记忆图数据文件"""
|
||||||
files = []
|
files = []
|
||||||
if not data_dir.exists():
|
if not data_dir.exists():
|
||||||
@@ -62,7 +62,7 @@ def find_available_data_files() -> List[Path]:
|
|||||||
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
|
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
|
||||||
"""从磁盘加载图数据"""
|
"""从磁盘加载图数据"""
|
||||||
global graph_data_cache, current_data_file
|
global graph_data_cache, current_data_file
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any
|
|||||||
if not graph_file.exists():
|
if not graph_file.exists():
|
||||||
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
|
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
|
||||||
|
|
||||||
with open(graph_file, "r", encoding="utf-8") as f:
|
with open(graph_file, encoding="utf-8") as f:
|
||||||
data = orjson.loads(f.read())
|
data = orjson.loads(f.read())
|
||||||
|
|
||||||
nodes = data.get("nodes", [])
|
nodes = data.get("nodes", [])
|
||||||
@@ -150,7 +150,7 @@ async def index(request: Request):
|
|||||||
return templates.TemplateResponse("visualizer.html", {"request": request})
|
return templates.TemplateResponse("visualizer.html", {"request": request})
|
||||||
|
|
||||||
|
|
||||||
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
|
def _format_graph_data_from_manager(memory_manager) -> dict[str, Any]:
|
||||||
"""从 MemoryManager 提取并格式化图数据"""
|
"""从 MemoryManager 提取并格式化图数据"""
|
||||||
if not memory_manager.graph_store:
|
if not memory_manager.graph_store:
|
||||||
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
|
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
|
||||||
@@ -261,7 +261,7 @@ async def get_paginated_graph(
|
|||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
|
page_size: int = Query(500, ge=100, le=2000, description="每页节点数"),
|
||||||
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
|
min_importance: float = Query(0.0, ge=0.0, le=1.0, description="最小重要性阈值"),
|
||||||
node_types: Optional[str] = Query(None, description="节点类型过滤,逗号分隔"),
|
node_types: str | None = Query(None, description="节点类型过滤,逗号分隔"),
|
||||||
):
|
):
|
||||||
"""分页获取图数据,支持重要性过滤"""
|
"""分页获取图数据,支持重要性过滤"""
|
||||||
try:
|
try:
|
||||||
@@ -383,7 +383,7 @@ async def get_clustered_graph(
|
|||||||
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
def _cluster_graph_data(nodes: List[Dict], edges: List[Dict], max_nodes: int, cluster_threshold: int) -> Dict:
|
def _cluster_graph_data(nodes: list[dict], edges: list[dict], max_nodes: int, cluster_threshold: int) -> dict:
|
||||||
"""简单的图聚类算法:按类型和连接度聚类"""
|
"""简单的图聚类算法:按类型和连接度聚类"""
|
||||||
# 构建邻接表
|
# 构建邻接表
|
||||||
adjacency = defaultdict(set)
|
adjacency = defaultdict(set)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
|||||||
@@ -481,7 +481,7 @@ class MessageManager:
|
|||||||
try:
|
try:
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = await chat_manager.get_stream(stream_id)
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||||
chat_stream.context_manager.context.is_chatter_processing = is_processing
|
chat_stream.context_manager.context.is_chatter_processing = is_processing
|
||||||
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
|
logger.debug(f"设置StreamContext处理状态: stream={stream_id}, processing={is_processing}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -517,7 +517,7 @@ class MessageManager:
|
|||||||
try:
|
try:
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = await chat_manager.get_stream(stream_id)
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
if chat_stream and hasattr(chat_stream.context_manager.context, 'is_chatter_processing'):
|
if chat_stream and hasattr(chat_stream.context_manager.context, "is_chatter_processing"):
|
||||||
return chat_stream.context_manager.context.is_chatter_processing
|
return chat_stream.context_manager.context.is_chatter_processing
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1172,10 +1172,10 @@ class DefaultReplyer:
|
|||||||
if unread_messages:
|
if unread_messages:
|
||||||
# 使用最后一条未读消息作为参考
|
# 使用最后一条未读消息作为参考
|
||||||
last_msg = unread_messages[-1]
|
last_msg = unread_messages[-1]
|
||||||
platform = last_msg.chat_info.platform if hasattr(last_msg, 'chat_info') else chat_stream.platform
|
platform = last_msg.chat_info.platform if hasattr(last_msg, "chat_info") else chat_stream.platform
|
||||||
user_id = last_msg.user_info.user_id if hasattr(last_msg, 'user_info') else ""
|
user_id = last_msg.user_info.user_id if hasattr(last_msg, "user_info") else ""
|
||||||
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, 'user_info') else ""
|
user_nickname = last_msg.user_info.user_nickname if hasattr(last_msg, "user_info") else ""
|
||||||
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, 'user_info') else ""
|
user_cardname = last_msg.user_info.user_cardname if hasattr(last_msg, "user_info") else ""
|
||||||
processed_plain_text = last_msg.processed_plain_text or ""
|
processed_plain_text = last_msg.processed_plain_text or ""
|
||||||
else:
|
else:
|
||||||
# 没有未读消息,使用默认值
|
# 没有未读消息,使用默认值
|
||||||
|
|||||||
@@ -5,12 +5,12 @@
|
|||||||
插件可以通过实现这些接口来扩展安全功能。
|
插件可以通过实现这些接口来扩展安全功能。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .interfaces import SecurityCheckResult, SecurityChecker
|
from .interfaces import SecurityChecker, SecurityCheckResult
|
||||||
from .manager import SecurityManager, get_security_manager
|
from .manager import SecurityManager, get_security_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SecurityChecker",
|
|
||||||
"SecurityCheckResult",
|
"SecurityCheckResult",
|
||||||
|
"SecurityChecker",
|
||||||
"SecurityManager",
|
"SecurityManager",
|
||||||
"get_security_manager",
|
"get_security_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .interfaces import SecurityAction, SecurityCheckResult, SecurityChecker, SecurityLevel
|
from .interfaces import SecurityAction, SecurityChecker, SecurityCheckResult, SecurityLevel
|
||||||
|
|
||||||
logger = get_logger("security.manager")
|
logger = get_logger("security.manager")
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,12 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import builtins
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.memory_utils import estimate_size_smart
|
from src.common.memory_utils import estimate_size_smart
|
||||||
@@ -96,7 +97,7 @@ class LRUCache(Generic[T]):
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._stats = CacheStats()
|
self._stats = CacheStats()
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[T]:
|
async def get(self, key: str) -> T | None:
|
||||||
"""获取缓存值
|
"""获取缓存值
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -137,8 +138,8 @@ class LRUCache(Generic[T]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: T,
|
value: T,
|
||||||
size: Optional[int] = None,
|
size: int | None = None,
|
||||||
ttl: Optional[float] = None,
|
ttl: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""设置缓存值
|
"""设置缓存值
|
||||||
|
|
||||||
@@ -287,8 +288,8 @@ class MultiLevelCache:
|
|||||||
async def get(
|
async def get(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
loader: Optional[Callable[[], Any]] = None,
|
loader: Callable[[], Any] | None = None,
|
||||||
) -> Optional[Any]:
|
) -> Any | None:
|
||||||
"""从缓存获取数据
|
"""从缓存获取数据
|
||||||
|
|
||||||
查询顺序:L1 -> L2 -> loader
|
查询顺序:L1 -> L2 -> loader
|
||||||
@@ -329,8 +330,8 @@ class MultiLevelCache:
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: Any,
|
||||||
size: Optional[int] = None,
|
size: int | None = None,
|
||||||
ttl: Optional[float] = None,
|
ttl: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""设置缓存值
|
"""设置缓存值
|
||||||
|
|
||||||
@@ -390,7 +391,7 @@ class MultiLevelCache:
|
|||||||
await self.l2_cache.clear()
|
await self.l2_cache.clear()
|
||||||
logger.info("所有缓存已清空")
|
logger.info("所有缓存已清空")
|
||||||
|
|
||||||
async def get_stats(self) -> Dict[str, Any]:
|
async def get_stats(self) -> dict[str, Any]:
|
||||||
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
|
"""获取所有缓存层的统计信息(修复版:避免锁嵌套,使用超时)"""
|
||||||
# 🔧 修复:并行获取统计信息,避免锁嵌套
|
# 🔧 修复:并行获取统计信息,避免锁嵌套
|
||||||
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
|
l1_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l1_cache, "L1"))
|
||||||
@@ -492,7 +493,7 @@ class MultiLevelCache:
|
|||||||
logger.error(f"{cache_name}统计获取异常: {e}")
|
logger.error(f"{cache_name}统计获取异常: {e}")
|
||||||
return CacheStats()
|
return CacheStats()
|
||||||
|
|
||||||
async def _get_cache_keys_safe(self, cache) -> Set[str]:
|
async def _get_cache_keys_safe(self, cache) -> builtins.set[str]:
|
||||||
"""安全获取缓存键集合(带超时)"""
|
"""安全获取缓存键集合(带超时)"""
|
||||||
try:
|
try:
|
||||||
# 快速获取键集合,使用超时避免死锁
|
# 快速获取键集合,使用超时避免死锁
|
||||||
@@ -507,12 +508,12 @@ class MultiLevelCache:
|
|||||||
logger.error(f"缓存键获取异常: {e}")
|
logger.error(f"缓存键获取异常: {e}")
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
async def _extract_keys_with_lock(self, cache) -> Set[str]:
|
async def _extract_keys_with_lock(self, cache) -> builtins.set[str]:
|
||||||
"""在锁保护下提取键集合"""
|
"""在锁保护下提取键集合"""
|
||||||
async with cache._lock:
|
async with cache._lock:
|
||||||
return set(cache._cache.keys())
|
return set(cache._cache.keys())
|
||||||
|
|
||||||
async def _calculate_memory_usage_safe(self, cache, keys: Set[str]) -> int:
|
async def _calculate_memory_usage_safe(self, cache, keys: builtins.set[str]) -> int:
|
||||||
"""安全计算内存使用(带超时)"""
|
"""安全计算内存使用(带超时)"""
|
||||||
if not keys:
|
if not keys:
|
||||||
return 0
|
return 0
|
||||||
@@ -529,7 +530,7 @@ class MultiLevelCache:
|
|||||||
logger.error(f"内存计算异常: {e}")
|
logger.error(f"内存计算异常: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def _calc_memory_with_lock(self, cache, keys: Set[str]) -> int:
|
async def _calc_memory_with_lock(self, cache, keys: builtins.set[str]) -> int:
|
||||||
"""在锁保护下计算内存使用"""
|
"""在锁保护下计算内存使用"""
|
||||||
total_size = 0
|
total_size = 0
|
||||||
async with cache._lock:
|
async with cache._lock:
|
||||||
@@ -749,7 +750,7 @@ class MultiLevelCache:
|
|||||||
|
|
||||||
|
|
||||||
# 全局缓存实例
|
# 全局缓存实例
|
||||||
_global_cache: Optional[MultiLevelCache] = None
|
_global_cache: MultiLevelCache | None = None
|
||||||
_cache_lock = asyncio.Lock()
|
_cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import socket
|
|||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from uvicorn import Config
|
from uvicorn import Config
|
||||||
from uvicorn import Server as UvicornServer
|
from uvicorn import Server as UvicornServer
|
||||||
|
|||||||
@@ -1098,7 +1098,7 @@ class MemoryManager:
|
|||||||
# 2. 清理孤立边(指向已删除节点的边)
|
# 2. 清理孤立边(指向已删除节点的边)
|
||||||
edges_to_remove = []
|
edges_to_remove = []
|
||||||
|
|
||||||
for source, target, edge_id in self.graph_store.graph.edges(data='edge_id'):
|
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
|
||||||
# 检查边的源节点和目标节点是否还存在于node_to_memories中
|
# 检查边的源节点和目标节点是否还存在于node_to_memories中
|
||||||
if source not in self.graph_store.node_to_memories or \
|
if source not in self.graph_store.node_to_memories or \
|
||||||
target not in self.graph_store.node_to_memories:
|
target not in self.graph_store.node_to_memories:
|
||||||
@@ -2301,7 +2301,7 @@ class MemoryManager:
|
|||||||
# 使用 asyncio.wait_for 来支持取消
|
# 使用 asyncio.wait_for 来支持取消
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
asyncio.sleep(initial_delay),
|
asyncio.sleep(initial_delay),
|
||||||
timeout=float('inf') # 允许随时取消
|
timeout=float("inf") # 允许随时取消
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否仍然需要运行
|
# 检查是否仍然需要运行
|
||||||
|
|||||||
@@ -559,8 +559,8 @@ class MemoryTools:
|
|||||||
)
|
)
|
||||||
if len(initial_memory_ids) == 0:
|
if len(initial_memory_ids) == 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"⚠️ 向量搜索未找到任何记忆!"
|
"⚠️ 向量搜索未找到任何记忆!"
|
||||||
f"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
"可能原因:1) 嵌入模型理解问题 2) 记忆节点未建立索引 3) 查询表达与存储内容差异过大"
|
||||||
)
|
)
|
||||||
# 输出相似节点的详细信息用于调试
|
# 输出相似节点的详细信息用于调试
|
||||||
if similar_nodes:
|
if similar_nodes:
|
||||||
@@ -738,12 +738,12 @@ class MemoryTools:
|
|||||||
activation_score = memory.activation
|
activation_score = memory.activation
|
||||||
|
|
||||||
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
|
# 🆕 动态权重计算:使用配置的基础权重 + 根据记忆类型微调
|
||||||
memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
|
memory_type = memory.memory_type.value if hasattr(memory.memory_type, "value") else str(memory.memory_type)
|
||||||
|
|
||||||
# 检测记忆的主要节点类型
|
# 检测记忆的主要节点类型
|
||||||
node_types_count = {}
|
node_types_count = {}
|
||||||
for node in memory.nodes:
|
for node in memory.nodes:
|
||||||
nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
|
nt = node.node_type.value if hasattr(node.node_type, "value") else str(node.node_type)
|
||||||
node_types_count[nt] = node_types_count.get(nt, 0) + 1
|
node_types_count[nt] = node_types_count.get(nt, 0) + 1
|
||||||
|
|
||||||
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
|
dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
|
||||||
@@ -1092,6 +1092,7 @@ class MemoryTools:
|
|||||||
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
|
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
# 清理Markdown代码块
|
# 清理Markdown代码块
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ async def expand_memories_with_semantic_filter(
|
|||||||
source_node_id = edge.source_id
|
source_node_id = edge.source_id
|
||||||
|
|
||||||
# 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边)
|
# 🆕 根据边类型设置权重(优先扩展REFERENCE、ATTRIBUTE相关的边)
|
||||||
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
|
edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, "value") else str(edge.edge_type)
|
||||||
if edge_type_str == "REFERENCE":
|
if edge_type_str == "REFERENCE":
|
||||||
edge_weight = 1.3 # REFERENCE边权重最高(引用关系)
|
edge_weight = 1.3 # REFERENCE边权重最高(引用关系)
|
||||||
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
|
elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
|
||||||
|
|||||||
@@ -31,4 +31,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
# 导入插件主类
|
# 导入插件主类
|
||||||
from .plugin import AntiInjectionPlugin
|
from .plugin import AntiInjectionPlugin
|
||||||
|
|
||||||
__all__ = ["__plugin_meta__", "AntiInjectionPlugin"]
|
__all__ = ["AntiInjectionPlugin", "__plugin_meta__"]
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import time
|
|||||||
|
|
||||||
from src.chat.security.interfaces import (
|
from src.chat.security.interfaces import (
|
||||||
SecurityAction,
|
SecurityAction,
|
||||||
SecurityCheckResult,
|
|
||||||
SecurityChecker,
|
SecurityChecker,
|
||||||
|
SecurityCheckResult,
|
||||||
SecurityLevel,
|
SecurityLevel,
|
||||||
)
|
)
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
|
处理检测结果,执行相应的动作(允许/监控/加盾/阻止/反击)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.chat.security.interfaces import SecurityAction, SecurityCheckResult
|
from src.chat.security.interfaces import SecurityCheckResult
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .counter_attack import CounterAttackGenerator
|
from .counter_attack import CounterAttackGenerator
|
||||||
|
|||||||
@@ -6,10 +6,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import filetype
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import filetype
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import uuid
|
|||||||
import weakref
|
import weakref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import suppress
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|||||||
Reference in New Issue
Block a user