fix(short_term_manager): optimize caching and similarity computation
@@ -210,8 +210,8 @@ perceptual_activation_threshold = 3 # activation threshold for transfer
 short_term_max_memories = 30 # capacity limit
 short_term_transfer_threshold = 0.6 # importance threshold for transfer
 short_term_overflow_strategy = "transfer_all" # overflow strategy (transfer_all/selective_cleanup)
-short_term_enable_force_cleanup = true # enable forced cleanup
-short_term_cleanup_keep_ratio = 0.9 # keep ratio for forced cleanup
+short_term_enable_force_cleanup = true # enable forced cleanup (deprecated)
+short_term_cleanup_keep_ratio = 0.9 # keep ratio for forced cleanup (deprecated)

 # Long-term memory
 long_term_batch_size = 10 # batch transfer size
@@ -79,6 +79,8 @@ class ShortTermMemoryManager:
         # Optimization: fast-lookup indexes
         self._memory_id_index: dict[str, ShortTermMemory] = {}  # fast lookup by ID
         self._similarity_cache: dict[str, dict[str, float]] = {}  # similarity cache {query_id: {target_id: sim}}
+        self._emb_matrix: np.ndarray | None = None
+        self._emb_matrix_mem_ids: list[str] | None = None

         # State
         self._initialized = False
@@ -384,6 +386,7 @@ class ShortTermMemoryManager:
         # Create a new memory
         self.memories.append(new_memory)
         self._memory_id_index[new_memory.id] = new_memory  # update the index
+        self._invalidate_matrix_cache()
         logger.debug(f"Created new short-term memory: {new_memory.id}")
         return new_memory

@@ -410,6 +413,7 @@ class ShortTermMemoryManager:

         # Clear this memory's cached similarities
         self._similarity_cache.pop(target.id, None)
+        self._invalidate_matrix_cache()

         logger.debug(f"Merged memory into: {target.id}")
         return target
@@ -437,6 +441,7 @@ class ShortTermMemoryManager:

         # Clear this memory's cached similarities
         self._similarity_cache.pop(target.id, None)
+        self._invalidate_matrix_cache()

         logger.debug(f"Updated memory: {target.id}")
         return target
@@ -450,6 +455,7 @@ class ShortTermMemoryManager:
         # Keep it as an independent memory
         self.memories.append(new_memory)
         self._memory_id_index[new_memory.id] = new_memory  # update the index
+        self._invalidate_matrix_cache()
         logger.debug(f"Kept independent memory: {new_memory.id}")
         return new_memory

@@ -489,24 +495,21 @@ class ShortTermMemoryManager:
             scored.sort(key=lambda x: x[1], reverse=True)
             return scored[:top_k]

-        # Compute all similarities concurrently
-        tasks = []
-        for existing_mem in self.memories:
-            if existing_mem.embedding is None:
-                continue
-            tasks.append(cosine_similarity_async(memory.embedding, existing_mem.embedding))
-
-        if not tasks:
+        valid_memories, matrix = await self._ensure_embeddings_matrix()
+        if not valid_memories or matrix is None:
             return []

-        similarities = await asyncio.gather(*tasks)
+        q = memory.embedding.astype(np.float32)
+        sims = await self._compute_cosine_similarities_vectorized(q, matrix)
+        if sims is None or len(sims) == 0:
+            return []

         # Build results and cache them
         scored = []
         cache_entry = {}
-        for existing_mem, similarity in zip([m for m in self.memories if m.embedding is not None], similarities):
-            scored.append((existing_mem, similarity))
-            cache_entry[existing_mem.id] = similarity
+        for existing_mem, similarity in zip(valid_memories, sims):
+            scored.append((existing_mem, float(similarity)))
+            cache_entry[existing_mem.id] = float(similarity)

         self._similarity_cache[memory.id] = cache_entry

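Note: the `_similarity_cache` introduced in the `__init__` hunk above is keyed as {query_id: {target_id: sim}}, and `_invalidate_matrix_cache()` wipes it whenever memories are created, merged, updated, or removed. A toy sketch of that read-through pattern follows; the IDs and the stand-in `compute` callable are made up for illustration and are not the manager's actual code path.

similarity_cache: dict[str, dict[str, float]] = {}

def cached_similarities(query_id: str, target_ids: list[str], compute) -> dict[str, float]:
    # Reuse scores for a repeated query; compute and store them on a miss.
    entry = similarity_cache.get(query_id)
    if entry is None:
        entry = {tid: compute(query_id, tid) for tid in target_ids}
        similarity_cache[query_id] = entry
    return entry

def invalidate() -> None:
    # Mirror _invalidate_matrix_cache(): any mutation drops all cached scores.
    similarity_cache.clear()

scores = cached_similarities("mem-q1", ["mem-a", "mem-b"], compute=lambda q, t: 0.5)
invalidate()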
@@ -535,7 +538,7 @@ class ShortTermMemoryManager:
                 return None

             embedding = await self.embedding_generator.generate(text)
-            return embedding
+            return self._normalize_embedding(embedding)

         except Exception as e:
             logger.error(f"Failed to generate embedding: {e}")
@@ -557,7 +560,7 @@ class ShortTermMemoryManager:
                 return [None] * len(texts)

             embeddings = await self.embedding_generator.generate_batch(texts)
-            return embeddings
+            return [self._normalize_embedding(e) for e in embeddings]

         except Exception as e:
             logger.error(f"Failed to generate embeddings in batch: {e}")
@@ -608,25 +611,18 @@ class ShortTermMemoryManager:
         if query_embedding is None or len(query_embedding) == 0:
             return []

-        # Compute all similarities concurrently
-        tasks = []
-        valid_memories = []
-        for memory in self.memories:
-            if memory.embedding is None:
-                continue
-            valid_memories.append(memory)
-            tasks.append(cosine_similarity_async(query_embedding, memory.embedding))
-
-        if not tasks:
+        valid_memories, matrix = await self._ensure_embeddings_matrix()
+        if not valid_memories or matrix is None:
             return []

-        similarities = await asyncio.gather(*tasks)
+        q = query_embedding.astype(np.float32)
+        sims = await self._compute_cosine_similarities_vectorized(q, matrix)

         # Build results
         scored = []
-        for memory, similarity in zip(valid_memories, similarities):
-            if similarity >= similarity_threshold:
-                scored.append((memory, similarity))
+        for memory, similarity in zip(valid_memories, sims):
+            if float(similarity) >= similarity_threshold:
+                scored.append((memory, float(similarity)))

         # Sort and take the top k
         scored.sort(key=lambda x: x[1], reverse=True)
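As an aside, once the similarities arrive as a single NumPy vector, the threshold filter and top-k selection above could equally be done array-wise. A rough standalone sketch under that assumption; `top_k_above_threshold` is a hypothetical helper, not part of the manager, and the returned indices would map back into `valid_memories`.

import numpy as np

def top_k_above_threshold(sims: np.ndarray, similarity_threshold: float, top_k: int) -> list[tuple[int, float]]:
    # Indices passing the threshold, ranked by similarity, best first, capped at top_k.
    keep = np.flatnonzero(sims >= similarity_threshold)
    order = keep[np.argsort(sims[keep])[::-1]][:top_k]
    return [(int(i), float(sims[i])) for i in order]

sims = np.array([0.91, 0.42, 0.77, 0.65], dtype=np.float32)
print(top_k_above_threshold(sims, similarity_threshold=0.6, top_k=2))  # [(0, 0.91...), (2, 0.77...)]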
@@ -695,6 +691,7 @@ class ShortTermMemoryManager:
         for mem_id in to_remove:
             self._memory_id_index.pop(mem_id, None)
             self._similarity_cache.pop(mem_id, None)
+        self._invalidate_matrix_cache()

         # Just save asynchronously; do not block the main flow
         asyncio.create_task(self._save_to_disk())
@@ -753,6 +750,7 @@ class ShortTermMemoryManager:

             # Save asynchronously
             asyncio.create_task(self._save_to_disk())
+            self._invalidate_matrix_cache()

         except Exception as e:
             logger.error(f"Failed to clear transferred memories: {e}")
@@ -815,6 +813,8 @@ class ShortTermMemoryManager:
             # Regenerate embeddings in batch
             await self._reload_embeddings()

+            self._invalidate_matrix_cache()
+
             logger.info(f"Short-term memories loaded from {load_path} ({len(self.memories)} entries)")

         except Exception as e:
@@ -848,6 +848,51 @@ class ShortTermMemoryManager:
                 success_count += 1

         logger.info(f"Embedding regeneration finished (succeeded: {success_count}/{len(memories_to_process)})")
+        self._invalidate_matrix_cache()
+
+    def _normalize_embedding(self, emb: np.ndarray | None) -> np.ndarray | None:
+        if emb is None:
+            return None
+        v = emb.astype(np.float32)
+        n = float(np.linalg.norm(v))
+        if n == 0.0:
+            return v
+        return v / n
+
+    def _invalidate_matrix_cache(self) -> None:
+        self._emb_matrix = None
+        self._emb_matrix_mem_ids = None
+        self._similarity_cache.clear()
+
+    async def _ensure_embeddings_matrix(self) -> tuple[list[ShortTermMemory], np.ndarray | None]:
+        if self._emb_matrix is not None and self._emb_matrix_mem_ids is not None:
+            mems = [self._memory_id_index[mid] for mid in self._emb_matrix_mem_ids if mid in self._memory_id_index]
+            return mems, self._emb_matrix
+
+        valid_memories = [m for m in self.memories if m.embedding is not None]
+        if not valid_memories:
+            self._emb_matrix = None
+            self._emb_matrix_mem_ids = None
+            return [], None
+
+        matrix = np.array([m.embedding for m in valid_memories], dtype=np.float32)
+        self._emb_matrix = matrix
+        self._emb_matrix_mem_ids = [m.id for m in valid_memories]
+        return valid_memories, matrix
+
+    async def _compute_cosine_similarities_vectorized(
+        self, query_embedding: np.ndarray, matrix: np.ndarray
+    ) -> np.ndarray | None:
+        try:
+            if query_embedding is None or len(query_embedding) == 0 or matrix is None or matrix.ndim != 2:
+                return None
+            return await asyncio.to_thread(self._compute_cosine_similarities_np, query_embedding, matrix)
+        except Exception as e:
+            logger.error(f"Vectorized similarity computation failed: {e}")
+            return None
+
+    def _compute_cosine_similarities_np(self, q: np.ndarray, matrix: np.ndarray) -> np.ndarray:
+        return matrix @ q

     async def shutdown(self) -> None:
         """Shut down the manager"""
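The vectorized path leans on the fact that `_normalize_embedding` L2-normalizes every embedding at generation time, so the `matrix @ q` product in `_compute_cosine_similarities_np` already yields cosine similarities. A minimal self-contained sketch of that equivalence, using random toy vectors rather than the manager's real embeddings:

import numpy as np

rng = np.random.default_rng(0)

def normalize(v: np.ndarray) -> np.ndarray:
    # L2-normalize, leaving zero vectors untouched (mirrors _normalize_embedding).
    v = v.astype(np.float32)
    n = float(np.linalg.norm(v))
    return v if n == 0.0 else v / n

# Toy data: five stored embeddings and one query, all normalized up front.
stored = np.stack([normalize(rng.standard_normal(8)) for _ in range(5)])  # shape (5, 8)
query = normalize(rng.standard_normal(8))                                 # shape (8,)

# One matrix-vector product gives every cosine similarity at once.
sims_fast = stored @ query

# Reference: explicit pairwise cosine similarity.
sims_ref = np.array([
    np.dot(m, query) / (np.linalg.norm(m) * np.linalg.norm(query))
    for m in stored
])

assert np.allclose(sims_fast, sims_ref, atol=1e-5)
print(sims_fast)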