feat(short_term_manager): 优化短期记忆管理器,增加哈希索引和相似度缓存,提升查找和计算性能
This commit is contained in:
@@ -14,6 +14,7 @@ import uuid
|
||||
import json_repair
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -64,6 +65,10 @@ class ShortTermMemoryManager:
|
||||
# 核心数据
|
||||
self.memories: list[ShortTermMemory] = []
|
||||
self.embedding_generator: EmbeddingGenerator | None = None
|
||||
|
||||
# 优化:快速查找索引
|
||||
self._memory_id_index: dict[str, ShortTermMemory] = {} # ID 快速查找
|
||||
self._similarity_cache: dict[str, dict[str, float]] = {} # 相似度缓存 {query_id: {target_id: sim}}
|
||||
|
||||
# 状态
|
||||
self._initialized = False
|
||||
@@ -366,6 +371,7 @@ class ShortTermMemoryManager:
|
||||
if decision.operation == ShortTermOperation.CREATE_NEW:
|
||||
# 创建新记忆
|
||||
self.memories.append(new_memory)
|
||||
self._memory_id_index[new_memory.id] = new_memory # 更新索引
|
||||
logger.debug(f"创建新短期记忆: {new_memory.id}")
|
||||
return new_memory
|
||||
|
||||
@@ -375,6 +381,7 @@ class ShortTermMemoryManager:
|
||||
if not target:
|
||||
logger.warning(f"目标记忆不存在,改为创建新记忆: {decision.target_memory_id}")
|
||||
self.memories.append(new_memory)
|
||||
self._memory_id_index[new_memory.id] = new_memory
|
||||
return new_memory
|
||||
|
||||
# 更新内容
|
||||
@@ -388,6 +395,9 @@ class ShortTermMemoryManager:
|
||||
# 重新生成向量
|
||||
target.embedding = await self._generate_embedding(target.content)
|
||||
target.update_access()
|
||||
|
||||
# 清除此记忆的缓存
|
||||
self._similarity_cache.pop(target.id, None)
|
||||
|
||||
logger.debug(f"合并记忆到: {target.id}")
|
||||
return target
|
||||
@@ -398,6 +408,7 @@ class ShortTermMemoryManager:
|
||||
if not target:
|
||||
logger.warning(f"目标记忆不存在,改为创建新记忆: {decision.target_memory_id}")
|
||||
self.memories.append(new_memory)
|
||||
self._memory_id_index[new_memory.id] = new_memory
|
||||
return new_memory
|
||||
|
||||
# 更新内容
|
||||
@@ -411,6 +422,9 @@ class ShortTermMemoryManager:
|
||||
|
||||
target.source_block_ids.extend(new_memory.source_block_ids)
|
||||
target.update_access()
|
||||
|
||||
# 清除此记忆的缓存
|
||||
self._similarity_cache.pop(target.id, None)
|
||||
|
||||
logger.debug(f"更新记忆: {target.id}")
|
||||
return target
|
||||
@@ -423,12 +437,14 @@ class ShortTermMemoryManager:
|
||||
elif decision.operation == ShortTermOperation.KEEP_SEPARATE:
|
||||
# 保持独立
|
||||
self.memories.append(new_memory)
|
||||
self._memory_id_index[new_memory.id] = new_memory # 更新索引
|
||||
logger.debug(f"保持独立记忆: {new_memory.id}")
|
||||
return new_memory
|
||||
|
||||
else:
|
||||
logger.warning(f"未知操作类型: {decision.operation},默认创建新记忆")
|
||||
self.memories.append(new_memory)
|
||||
self._memory_id_index[new_memory.id] = new_memory
|
||||
return new_memory
|
||||
|
||||
except Exception as e:
|
||||
@@ -439,7 +455,7 @@ class ShortTermMemoryManager:
|
||||
self, memory: ShortTermMemory, top_k: int = 5
|
||||
) -> list[tuple[ShortTermMemory, float]]:
|
||||
"""
|
||||
查找与给定记忆相似的现有记忆
|
||||
查找与给定记忆相似的现有记忆(优化版:并发计算 + 缓存)
|
||||
|
||||
Args:
|
||||
memory: 目标记忆
|
||||
@@ -452,13 +468,35 @@ class ShortTermMemoryManager:
|
||||
return []
|
||||
|
||||
try:
|
||||
scored = []
|
||||
# 检查缓存
|
||||
if memory.id in self._similarity_cache:
|
||||
cached = self._similarity_cache[memory.id]
|
||||
scored = [(self._memory_id_index[mid], sim)
|
||||
for mid, sim in cached.items()
|
||||
if mid in self._memory_id_index]
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
return scored[:top_k]
|
||||
|
||||
# 并发计算所有相似度
|
||||
tasks = []
|
||||
for existing_mem in self.memories:
|
||||
if existing_mem.embedding is None:
|
||||
continue
|
||||
tasks.append(cosine_similarity_async(memory.embedding, existing_mem.embedding))
|
||||
|
||||
similarity = await cosine_similarity_async(memory.embedding, existing_mem.embedding)
|
||||
if not tasks:
|
||||
return []
|
||||
|
||||
similarities = await asyncio.gather(*tasks)
|
||||
|
||||
# 构建结果并缓存
|
||||
scored = []
|
||||
cache_entry = {}
|
||||
for existing_mem, similarity in zip([m for m in self.memories if m.embedding is not None], similarities):
|
||||
scored.append((existing_mem, similarity))
|
||||
cache_entry[existing_mem.id] = similarity
|
||||
|
||||
self._similarity_cache[memory.id] = cache_entry
|
||||
|
||||
# 按相似度降序排序
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
@@ -470,15 +508,12 @@ class ShortTermMemoryManager:
|
||||
return []
|
||||
|
||||
def _find_memory_by_id(self, memory_id: str | None) -> ShortTermMemory | None:
|
||||
"""根据ID查找记忆"""
|
||||
"""根据ID查找记忆(优化版:O(1) 哈希表查找)"""
|
||||
if not memory_id:
|
||||
return None
|
||||
|
||||
for mem in self.memories:
|
||||
if mem.id == memory_id:
|
||||
return mem
|
||||
|
||||
return None
|
||||
|
||||
# 使用索引进行 O(1) 查找
|
||||
return self._memory_id_index.get(memory_id)
|
||||
|
||||
async def _generate_embedding(self, text: str) -> np.ndarray | None:
|
||||
"""生成文本向量"""
|
||||
@@ -542,7 +577,7 @@ class ShortTermMemoryManager:
|
||||
self, query_text: str, top_k: int = 5, similarity_threshold: float = 0.5
|
||||
) -> list[ShortTermMemory]:
|
||||
"""
|
||||
检索相关的短期记忆
|
||||
检索相关的短期记忆(优化版:并发计算相似度)
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
@@ -561,13 +596,23 @@ class ShortTermMemoryManager:
|
||||
if query_embedding is None or len(query_embedding) == 0:
|
||||
return []
|
||||
|
||||
# 计算相似度
|
||||
scored = []
|
||||
# 并发计算所有相似度
|
||||
tasks = []
|
||||
valid_memories = []
|
||||
for memory in self.memories:
|
||||
if memory.embedding is None:
|
||||
continue
|
||||
valid_memories.append(memory)
|
||||
tasks.append(cosine_similarity_async(query_embedding, memory.embedding))
|
||||
|
||||
similarity = await cosine_similarity_async(query_embedding, memory.embedding)
|
||||
if not tasks:
|
||||
return []
|
||||
|
||||
similarities = await asyncio.gather(*tasks)
|
||||
|
||||
# 构建结果
|
||||
scored = []
|
||||
for memory, similarity in zip(valid_memories, similarities):
|
||||
if similarity >= similarity_threshold:
|
||||
scored.append((memory, similarity))
|
||||
|
||||
@@ -575,7 +620,7 @@ class ShortTermMemoryManager:
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
results = [mem for mem, _ in scored[:top_k]]
|
||||
|
||||
# 更新访问记录
|
||||
# 批量更新访问记录
|
||||
for mem in results:
|
||||
mem.update_access()
|
||||
|
||||
@@ -588,19 +633,21 @@ class ShortTermMemoryManager:
|
||||
|
||||
def get_memories_for_transfer(self) -> list[ShortTermMemory]:
|
||||
"""
|
||||
获取需要转移到长期记忆的记忆
|
||||
获取需要转移到长期记忆的记忆(优化版:单次遍历)
|
||||
|
||||
逻辑:
|
||||
1. 优先选择重要性 >= 阈值的记忆
|
||||
2. 如果剩余记忆数量仍超过 max_memories,直接清理最早的低重要性记忆直到低于上限
|
||||
"""
|
||||
# 1. 正常筛选:重要性达标的记忆
|
||||
candidates = [mem for mem in self.memories if mem.importance >= self.transfer_importance_threshold]
|
||||
candidate_ids = {mem.id for mem in candidates}
|
||||
# 单次遍历:同时分类高重要性和低重要性记忆
|
||||
candidates = []
|
||||
low_importance_memories = []
|
||||
|
||||
# 2. 检查低重要性记忆是否积压
|
||||
# 剩余的都是低重要性记忆
|
||||
low_importance_memories = [mem for mem in self.memories if mem.id not in candidate_ids]
|
||||
for mem in self.memories:
|
||||
if mem.importance >= self.transfer_importance_threshold:
|
||||
candidates.append(mem)
|
||||
else:
|
||||
low_importance_memories.append(mem)
|
||||
|
||||
# 如果低重要性记忆数量超过了上限(说明积压严重)
|
||||
# 我们需要清理掉一部分,而不是转移它们
|
||||
@@ -614,9 +661,12 @@ class ShortTermMemoryManager:
|
||||
low_importance_memories.sort(key=lambda x: x.created_at)
|
||||
to_remove = low_importance_memories[:num_to_remove]
|
||||
|
||||
for mem in to_remove:
|
||||
if mem in self.memories:
|
||||
self.memories.remove(mem)
|
||||
# 批量删除并更新索引
|
||||
remove_ids = {mem.id for mem in to_remove}
|
||||
self.memories = [mem for mem in self.memories if mem.id not in remove_ids]
|
||||
for mem_id in remove_ids:
|
||||
del self._memory_id_index[mem_id]
|
||||
self._similarity_cache.pop(mem_id, None)
|
||||
|
||||
logger.info(
|
||||
f"短期记忆清理: 移除了 {len(to_remove)} 条低重要性记忆 "
|
||||
@@ -636,7 +686,14 @@ class ShortTermMemoryManager:
|
||||
memory_ids: 已转移的记忆ID列表
|
||||
"""
|
||||
try:
|
||||
self.memories = [mem for mem in self.memories if mem.id not in memory_ids]
|
||||
remove_ids = set(memory_ids)
|
||||
self.memories = [mem for mem in self.memories if mem.id not in remove_ids]
|
||||
|
||||
# 更新索引
|
||||
for mem_id in remove_ids:
|
||||
self._memory_id_index.pop(mem_id, None)
|
||||
self._similarity_cache.pop(mem_id, None)
|
||||
|
||||
logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆")
|
||||
|
||||
# 异步保存
|
||||
@@ -696,7 +753,11 @@ class ShortTermMemoryManager:
|
||||
data = orjson.loads(load_path.read_bytes())
|
||||
self.memories = [ShortTermMemory.from_dict(m) for m in data.get("memories", [])]
|
||||
|
||||
# 重新生成向量
|
||||
# 重建索引
|
||||
for mem in self.memories:
|
||||
self._memory_id_index[mem.id] = mem
|
||||
|
||||
# 批量重新生成向量
|
||||
await self._reload_embeddings()
|
||||
|
||||
logger.info(f"短期记忆已从 {load_path} 加载 ({len(self.memories)} 条)")
|
||||
@@ -705,7 +766,7 @@ class ShortTermMemoryManager:
|
||||
logger.error(f"加载短期记忆失败: {e}")
|
||||
|
||||
async def _reload_embeddings(self) -> None:
|
||||
"""重新生成记忆的向量"""
|
||||
"""重新生成记忆的向量(优化版:并发处理)"""
|
||||
logger.info("重新生成短期记忆向量...")
|
||||
|
||||
memories_to_process = []
|
||||
@@ -722,6 +783,7 @@ class ShortTermMemoryManager:
|
||||
|
||||
logger.info(f"开始批量生成 {len(memories_to_process)} 条短期记忆的向量...")
|
||||
|
||||
# 使用 gather 并发生成向量
|
||||
embeddings = await self._generate_embeddings_batch(texts_to_process)
|
||||
|
||||
success_count = 0
|
||||
|
||||
Reference in New Issue
Block a user