优化缓存条目大小估算,添加向量存储标记,清理待处理消息逻辑
This commit is contained in:
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.memory_utils import estimate_size_smart
|
from src.common.memory_utils import estimate_cache_item_size
|
||||||
|
|
||||||
logger = get_logger("cache_manager")
|
logger = get_logger("cache_manager")
|
||||||
|
|
||||||
@@ -237,7 +237,7 @@ class LRUCache(Generic[T]):
|
|||||||
使用深度递归估算,比 sys.getsizeof() 更准确
|
使用深度递归估算,比 sys.getsizeof() 更准确
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return estimate_size_smart(value)
|
return estimate_cache_item_size(value)
|
||||||
except (TypeError, AttributeError):
|
except (TypeError, AttributeError):
|
||||||
# 无法获取大小,返回默认值
|
# 无法获取大小,返回默认值
|
||||||
return 1024
|
return 1024
|
||||||
@@ -345,7 +345,7 @@ class MultiLevelCache:
|
|||||||
"""
|
"""
|
||||||
# 估算数据大小(如果未提供)
|
# 估算数据大小(如果未提供)
|
||||||
if size is None:
|
if size is None:
|
||||||
size = estimate_size_smart(value)
|
size = estimate_cache_item_size(value)
|
||||||
|
|
||||||
# 检查单个条目大小是否超过限制
|
# 检查单个条目大小是否超过限制
|
||||||
if size > self.max_item_size_bytes:
|
if size > self.max_item_size_bytes:
|
||||||
|
|||||||
@@ -169,6 +169,30 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
|
|||||||
return size
|
return size
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_cache_item_size(obj: Any) -> int:
    """Estimate the size of a cache entry in bytes.

    Three independent estimators are consulted (the smart recursive
    estimate, the accurate deep size, and the pickle size) and the largest
    result is returned. Taking the maximum is deliberately conservative so
    that heavily nested objects are not under-estimated.

    Estimation is best-effort: an estimator that raises simply contributes
    nothing. If none of them yields a positive value, the shallow
    ``sys.getsizeof`` of the object is returned so the result is not 0
    for ordinary objects.

    Args:
        obj: The value about to be stored in the cache.

    Returns:
        The largest available size estimate, in bytes.
    """
    estimators = (
        lambda: estimate_size_smart(obj, max_depth=10, sample_large=False),
        lambda: get_accurate_size(obj),
        lambda: get_pickle_size(obj),
    )

    best = 0
    for estimate in estimators:
        try:
            best = max(best, estimate())
        except Exception:
            # Deliberate best-effort: a failing estimator (e.g. an object
            # that cannot be traversed or pickled) must not break caching.
            continue

    # Fall back to the shallow size so callers do not receive 0.
    return best or sys.getsizeof(obj)
|
||||||
|
|
||||||
|
|
||||||
def format_size(size_bytes: int) -> str:
|
def format_size(size_bytes: int) -> str:
|
||||||
"""
|
"""
|
||||||
格式化字节数为人类可读的格式
|
格式化字节数为人类可读的格式
|
||||||
|
|||||||
@@ -379,6 +379,7 @@ class MemoryBuilder:
|
|||||||
node_type=NodeType(node_data["node_type"]),
|
node_type=NodeType(node_data["node_type"]),
|
||||||
embedding=None, # 图存储不包含 embedding,需要从向量数据库获取
|
embedding=None, # 图存储不包含 embedding,需要从向量数据库获取
|
||||||
metadata=node_data.get("metadata", {}),
|
metadata=node_data.get("metadata", {}),
|
||||||
|
has_vector=node_data.get("has_vector", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -424,6 +425,7 @@ class MemoryBuilder:
|
|||||||
node_type=NodeType(node_data["node_type"]),
|
node_type=NodeType(node_data["node_type"]),
|
||||||
embedding=None, # 图存储不包含 embedding,需要从向量数据库获取
|
embedding=None, # 图存储不包含 embedding,需要从向量数据库获取
|
||||||
metadata=node_data.get("metadata", {}),
|
metadata=node_data.get("metadata", {}),
|
||||||
|
has_vector=node_data.get("has_vector", False),
|
||||||
)
|
)
|
||||||
# 添加当前记忆ID到元数据
|
# 添加当前记忆ID到元数据
|
||||||
return existing_node
|
return existing_node
|
||||||
@@ -474,6 +476,7 @@ class MemoryBuilder:
|
|||||||
node_type=NodeType(node_data["node_type"]),
|
node_type=NodeType(node_data["node_type"]),
|
||||||
embedding=None, # 图存储不包含 embedding,需要从向量数据库获取
|
embedding=None, # 图存储不包含 embedding,需要从向量数据库获取
|
||||||
metadata=node_data.get("metadata", {}),
|
metadata=node_data.get("metadata", {}),
|
||||||
|
has_vector=node_data.get("has_vector", False),
|
||||||
)
|
)
|
||||||
return existing_node
|
return existing_node
|
||||||
|
|
||||||
|
|||||||
@@ -922,6 +922,9 @@ class LongTermMemoryManager:
|
|||||||
embedding=embedding
|
embedding=embedding
|
||||||
)
|
)
|
||||||
await self.memory_manager.vector_store.add_node(node)
|
await self.memory_manager.vector_store.add_node(node)
|
||||||
|
node.mark_vector_stored()
|
||||||
|
if self.memory_manager.graph_store.graph.has_node(node_id):
|
||||||
|
self.memory_manager.graph_store.graph.nodes[node_id]["has_vector"] = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"生成节点 embedding 失败: {e}")
|
logger.warning(f"生成节点 embedding 失败: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -359,9 +359,13 @@ class MemoryManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# 从向量存储删除节点
|
# 从向量存储删除节点
|
||||||
for node in memory.nodes:
|
if self.vector_store:
|
||||||
if node.embedding is not None:
|
for node in memory.nodes:
|
||||||
await self.vector_store.delete_node(node.id)
|
if getattr(node, "has_vector", False):
|
||||||
|
await self.vector_store.delete_node(node.id)
|
||||||
|
node.has_vector = False
|
||||||
|
if self.graph_store.graph.has_node(node.id):
|
||||||
|
self.graph_store.graph.nodes[node.id]["has_vector"] = False
|
||||||
|
|
||||||
# 从图存储删除记忆
|
# 从图存储删除记忆
|
||||||
self.graph_store.remove_memory(memory_id)
|
self.graph_store.remove_memory(memory_id)
|
||||||
@@ -900,13 +904,17 @@ class MemoryManager:
|
|||||||
|
|
||||||
# 1. 从向量存储删除节点的嵌入向量
|
# 1. 从向量存储删除节点的嵌入向量
|
||||||
deleted_vectors = 0
|
deleted_vectors = 0
|
||||||
for node in memory.nodes:
|
if self.vector_store:
|
||||||
if node.embedding is not None:
|
for node in memory.nodes:
|
||||||
try:
|
if getattr(node, "has_vector", False):
|
||||||
await self.vector_store.delete_node(node.id)
|
try:
|
||||||
deleted_vectors += 1
|
await self.vector_store.delete_node(node.id)
|
||||||
except Exception as e:
|
deleted_vectors += 1
|
||||||
logger.warning(f"删除节点向量失败 {node.id}: {e}")
|
node.has_vector = False
|
||||||
|
if self.graph_store.graph.has_node(node.id):
|
||||||
|
self.graph_store.graph.nodes[node.id]["has_vector"] = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"删除节点向量失败 {node.id}: {e}")
|
||||||
|
|
||||||
# 2. 从图存储删除记忆
|
# 2. 从图存储删除记忆
|
||||||
success = self.graph_store.remove_memory(memory_id, cleanup_orphans=False)
|
success = self.graph_store.remove_memory(memory_id, cleanup_orphans=False)
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class MemoryNode:
|
|||||||
node_type: NodeType # 节点类型
|
node_type: NodeType # 节点类型
|
||||||
embedding: np.ndarray | None = None # 语义向量(仅主题/客体需要)
|
embedding: np.ndarray | None = None # 语义向量(仅主题/客体需要)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
||||||
|
has_vector: bool = False # 是否已写入向量存储
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -137,6 +138,7 @@ class MemoryNode:
|
|||||||
"node_type": self.node_type.value,
|
"node_type": self.node_type.value,
|
||||||
"metadata": self.metadata,
|
"metadata": self.metadata,
|
||||||
"created_at": self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"has_vector": self.has_vector,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -150,12 +152,18 @@ class MemoryNode:
|
|||||||
embedding=None, # 向量数据需要从向量数据库中单独加载
|
embedding=None, # 向量数据需要从向量数据库中单独加载
|
||||||
metadata=data.get("metadata", {}),
|
metadata=data.get("metadata", {}),
|
||||||
created_at=datetime.fromisoformat(data["created_at"]),
|
created_at=datetime.fromisoformat(data["created_at"]),
|
||||||
|
has_vector=data.get("has_vector", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
def has_embedding(self) -> bool:
|
def has_embedding(self) -> bool:
|
||||||
"""是否有语义向量"""
|
"""是否持有可用的语义向量数据"""
|
||||||
return self.embedding is not None
|
return self.embedding is not None
|
||||||
|
|
||||||
|
def mark_vector_stored(self) -> None:
    """Record that this node has been written to the vector store.

    Sets ``has_vector`` to True and clears the in-memory ``embedding``.
    """
    # Drop the in-memory embedding, then raise the persisted flag.
    self.embedding = None
    self.has_vector = True
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"Node({self.node_type.value}: {self.content})"
|
return f"Node({self.node_type.value}: {self.content})"
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -40,6 +41,9 @@ class PerceptualMemoryManager:
|
|||||||
activation_threshold: int = 3,
|
activation_threshold: int = 3,
|
||||||
recall_top_k: int = 5,
|
recall_top_k: int = 5,
|
||||||
recall_similarity_threshold: float = 0.55,
|
recall_similarity_threshold: float = 0.55,
|
||||||
|
pending_message_ttl: int = 600,
|
||||||
|
max_pending_per_stream: int = 50,
|
||||||
|
max_pending_messages: int = 2000,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化感知记忆层管理器
|
初始化感知记忆层管理器
|
||||||
@@ -51,6 +55,9 @@ class PerceptualMemoryManager:
|
|||||||
activation_threshold: 激活阈值(召回次数)
|
activation_threshold: 激活阈值(召回次数)
|
||||||
recall_top_k: 召回时返回的最大块数
|
recall_top_k: 召回时返回的最大块数
|
||||||
recall_similarity_threshold: 召回的相似度阈值
|
recall_similarity_threshold: 召回的相似度阈值
|
||||||
|
pending_message_ttl: 待组块消息最大保留时间(秒)
|
||||||
|
max_pending_per_stream: 单个流允许的待组块消息上限
|
||||||
|
max_pending_messages: 全部流的待组块消息总上限
|
||||||
"""
|
"""
|
||||||
self.data_dir = data_dir or Path("data/memory_graph")
|
self.data_dir = data_dir or Path("data/memory_graph")
|
||||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -61,6 +68,9 @@ class PerceptualMemoryManager:
|
|||||||
self.activation_threshold = activation_threshold
|
self.activation_threshold = activation_threshold
|
||||||
self.recall_top_k = recall_top_k
|
self.recall_top_k = recall_top_k
|
||||||
self.recall_similarity_threshold = recall_similarity_threshold
|
self.recall_similarity_threshold = recall_similarity_threshold
|
||||||
|
self.pending_message_ttl = max(0, pending_message_ttl)
|
||||||
|
self.max_pending_per_stream = max(0, max_pending_per_stream)
|
||||||
|
self.max_pending_messages = max(0, max_pending_messages)
|
||||||
|
|
||||||
# 核心数据
|
# 核心数据
|
||||||
self.perceptual_memory: PerceptualMemory | None = None
|
self.perceptual_memory: PerceptualMemory | None = None
|
||||||
@@ -104,6 +114,8 @@ class PerceptualMemoryManager:
|
|||||||
max_blocks=self.max_blocks,
|
max_blocks=self.max_blocks,
|
||||||
block_size=self.block_size,
|
block_size=self.block_size,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self._cleanup_pending_messages()
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -138,18 +150,28 @@ class PerceptualMemoryManager:
|
|||||||
await self.initialize()
|
await self.initialize()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 添加到待处理消息队列
|
if not hasattr(self.perceptual_memory, "pending_messages"):
|
||||||
self.perceptual_memory.pending_messages.append(message)
|
self.perceptual_memory.pending_messages = []
|
||||||
|
|
||||||
|
self._cleanup_pending_messages()
|
||||||
|
|
||||||
stream_id = message.get("stream_id", "unknown")
|
stream_id = message.get("stream_id", "unknown")
|
||||||
|
self._normalize_message_timestamp(message)
|
||||||
|
self.perceptual_memory.pending_messages.append(message)
|
||||||
|
self._enforce_pending_limits(stream_id)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"消息已添加到待处理队列 (stream={stream_id[:8]}, "
|
f"消息已添加到待处理队列 (stream={stream_id[:8]}, "
|
||||||
f"总数={len(self.perceptual_memory.pending_messages)})"
|
f"总数={len(self.perceptual_memory.pending_messages)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 按 stream_id 检查是否达到创建块的条件
|
# 按 stream_id 检查是否达到创建块的条件
|
||||||
stream_messages = [msg for msg in self.perceptual_memory.pending_messages if msg.get("stream_id") == stream_id]
|
stream_messages = [
|
||||||
|
msg
|
||||||
|
for msg in self.perceptual_memory.pending_messages
|
||||||
|
if msg.get("stream_id") == stream_id
|
||||||
|
]
|
||||||
|
|
||||||
if len(stream_messages) >= self.block_size:
|
if len(stream_messages) >= self.block_size:
|
||||||
new_block = await self._create_memory_block(stream_id)
|
new_block = await self._create_memory_block(stream_id)
|
||||||
return new_block
|
return new_block
|
||||||
@@ -171,6 +193,7 @@ class PerceptualMemoryManager:
|
|||||||
新创建的记忆块,失败返回 None
|
新创建的记忆块,失败返回 None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
self._cleanup_pending_messages()
|
||||||
# 只取出指定 stream_id 的 block_size 条消息
|
# 只取出指定 stream_id 的 block_size 条消息
|
||||||
stream_messages = [msg for msg in self.perceptual_memory.pending_messages if msg.get("stream_id") == stream_id]
|
stream_messages = [msg for msg in self.perceptual_memory.pending_messages if msg.get("stream_id") == stream_id]
|
||||||
|
|
||||||
@@ -227,6 +250,82 @@ class PerceptualMemoryManager:
|
|||||||
logger.error(f"创建记忆块失败: {e}", exc_info=True)
|
logger.error(f"创建记忆块失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _normalize_message_timestamp(self, message: dict[str, Any]) -> float:
    """Ensure ``message`` carries a usable ``timestamp`` and return it.

    The ``timestamp`` key is tried first, then ``time``. The first value
    that converts to a finite float is used; if neither does, the current
    wall-clock time is used instead. The chosen value is written back to
    ``message["timestamp"]``.

    Args:
        message: Message dict; mutated in place.

    Returns:
        The float now stored in ``message["timestamp"]``.
    """
    import math  # function-scope import keeps this helper self-contained

    for key in ("timestamp", "time"):
        try:
            candidate = float(message.get(key))
        except (TypeError, ValueError, OverflowError):
            # Missing key, None, non-numeric text, or an int too large
            # for a float: try the next key.
            continue
        # NaN/inf parse successfully but make age arithmetic meaningless.
        if math.isfinite(candidate):
            message["timestamp"] = candidate
            return candidate

    now = time.time()
    message["timestamp"] = now
    return now
|
||||||
|
|
||||||
|
def _cleanup_pending_messages(self) -> None:
    """Drop expired and over-limit pending messages to bound memory use.

    Two passes run over ``perceptual_memory.pending_messages``, both
    mutating the list in place:

    1. If ``pending_message_ttl`` is positive, every message gets a float
       ``timestamp`` (taken from ``timestamp`` or ``time``, defaulting to
       the current time when unparseable) and messages older than the TTL
       are removed.
    2. If ``max_pending_messages`` is positive and the list is still
       longer than that, the oldest entries (front of the list) are
       discarded.
    """
    memory = self.perceptual_memory
    if not memory:
        return
    if not getattr(memory, "pending_messages", None):
        # Attribute missing or queue empty: nothing to clean.
        return

    queue = memory.pending_messages
    started_at = time.time()
    dropped = 0

    if self.pending_message_ttl > 0:
        max_age = float(self.pending_message_ttl)
        survivors: list[dict[str, Any]] = []
        for entry in queue:
            raw = entry.get("timestamp") or entry.get("time")
            try:
                stamp = float(raw)
            except (TypeError, ValueError):
                stamp = time.time()
            # Persist the normalized value so later passes see a float.
            entry["timestamp"] = stamp
            if started_at - stamp <= max_age:
                survivors.append(entry)
            else:
                dropped += 1

        if dropped:
            # Slice assignment keeps the same list object.
            queue[:] = survivors

    # Global cap: discard the oldest messages first (FIFO).
    cap = self.max_pending_messages
    if cap > 0 and len(queue) > cap:
        excess = len(queue) - cap
        del queue[:excess]
        dropped += excess

    if dropped:
        logger.debug(f"清理待组块消息 {dropped} 条 (剩余 {len(queue)})")
|
||||||
|
|
||||||
|
def _enforce_pending_limits(self, stream_id: str) -> None:
    """Keep one stream's pending messages within ``max_pending_per_stream``.

    When the stream holds more pending messages than allowed, its earliest
    entries in the queue are removed in place and a warning is logged.
    Messages from other streams are left untouched.

    Args:
        stream_id: Stream whose pending messages are checked.
    """
    memory = self.perceptual_memory
    if not memory:
        return
    if not getattr(memory, "pending_messages", None):
        return
    if self.max_pending_per_stream <= 0:
        # A non-positive limit disables the per-stream cap.
        return

    queue = memory.pending_messages
    positions = [
        pos
        for pos, entry in enumerate(queue)
        if entry.get("stream_id") == stream_id
    ]

    excess = len(positions) - self.max_pending_per_stream
    if excess <= 0:
        return

    # Remove the first `excess` entries of this stream, keeping list identity.
    doomed = set(positions[:excess])
    queue[:] = [entry for pos, entry in enumerate(queue) if pos not in doomed]

    logger.warning(
        "stream %s 待组块消息过多,丢弃 %d 条旧消息 (保留 %d 条)",
        stream_id,
        excess,
        self.max_pending_per_stream,
    )
|
||||||
|
|
||||||
def _combine_messages(self, messages: list[dict[str, Any]]) -> str:
|
def _combine_messages(self, messages: list[dict[str, Any]]) -> str:
|
||||||
"""
|
"""
|
||||||
合并多条消息为单一文本
|
合并多条消息为单一文本
|
||||||
@@ -508,6 +607,8 @@ class PerceptualMemoryManager:
|
|||||||
if not self.perceptual_memory:
|
if not self.perceptual_memory:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._cleanup_pending_messages()
|
||||||
|
|
||||||
# 保存到 JSON 文件
|
# 保存到 JSON 文件
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class GraphStore:
|
|||||||
node_type=node.node_type.value,
|
node_type=node.node_type.value,
|
||||||
created_at=node.created_at.isoformat(),
|
created_at=node.created_at.isoformat(),
|
||||||
metadata=node.metadata,
|
metadata=node.metadata,
|
||||||
|
has_vector=node.has_vector,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 更新节点到记忆的映射
|
# 更新节点到记忆的映射
|
||||||
@@ -120,6 +121,7 @@ class GraphStore:
|
|||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
created_at=datetime.now().isoformat(),
|
created_at=datetime.now().isoformat(),
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
|
has_vector=(metadata or {}).get("has_vector", False),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 如果节点已存在,更新内容(可选)
|
# 如果节点已存在,更新内容(可选)
|
||||||
@@ -144,7 +146,8 @@ class GraphStore:
|
|||||||
id=node_id,
|
id=node_id,
|
||||||
content=content,
|
content=content,
|
||||||
node_type=node_type_enum,
|
node_type=node_type_enum,
|
||||||
metadata=metadata or {}
|
metadata=metadata or {},
|
||||||
|
has_vector=(metadata or {}).get("has_vector", False)
|
||||||
)
|
)
|
||||||
memory.nodes.append(new_node)
|
memory.nodes.append(new_node)
|
||||||
|
|
||||||
|
|||||||
@@ -1211,6 +1211,9 @@ class MemoryTools:
|
|||||||
for node in memory.nodes:
|
for node in memory.nodes:
|
||||||
if node.embedding is not None:
|
if node.embedding is not None:
|
||||||
await self.vector_store.add_node(node)
|
await self.vector_store.add_node(node)
|
||||||
|
node.mark_vector_stored()
|
||||||
|
if self.graph_store.graph.has_node(node.id):
|
||||||
|
self.graph_store.graph.nodes[node.id]["has_vector"] = True
|
||||||
|
|
||||||
async def _find_memory_by_description(self, description: str) -> Memory | None:
|
async def _find_memory_by_description(self, description: str) -> Memory | None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user