feat(embedding): 优化embedding缓存管理,避免无关标签的缓存影响内存使用
This commit is contained in:
@@ -331,7 +331,12 @@ class BotInterestManager:
|
|||||||
# 尝试从文件加载缓存
|
# 尝试从文件加载缓存
|
||||||
file_cache = await self._load_embedding_cache_from_file(interests.personality_id)
|
file_cache = await self._load_embedding_cache_from_file(interests.personality_id)
|
||||||
if file_cache:
|
if file_cache:
|
||||||
self.embedding_cache.update(file_cache)
|
allowed_keys = {tag.tag_name for tag in interests.interest_tags}
|
||||||
|
filtered_cache = {key: value for key, value in file_cache.items() if key in allowed_keys}
|
||||||
|
dropped_cache = len(file_cache) - len(filtered_cache)
|
||||||
|
if dropped_cache > 0:
|
||||||
|
logger.debug(f"🧹 跳过 {dropped_cache} 个与当前兴趣标签无关的缓存embedding")
|
||||||
|
self.embedding_cache.update(filtered_cache)
|
||||||
|
|
||||||
memory_cached_count = 0
|
memory_cached_count = 0
|
||||||
file_cached_count = 0
|
file_cached_count = 0
|
||||||
@@ -371,13 +376,16 @@ class BotInterestManager:
|
|||||||
|
|
||||||
interests.last_updated = datetime.now()
|
interests.last_updated = datetime.now()
|
||||||
|
|
||||||
async def _get_embedding(self, text: str) -> list[float]:
|
async def _get_embedding(self, text: str, cache: bool = True) -> list[float]:
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量
|
||||||
|
|
||||||
|
cache=False 用于消息内容,避免在 embedding_cache 中长期保留大文本导致内存膨胀。
|
||||||
|
"""
|
||||||
if not hasattr(self, "embedding_request"):
|
if not hasattr(self, "embedding_request"):
|
||||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||||
|
|
||||||
# 检查缓存
|
# 检查缓存
|
||||||
if text in self.embedding_cache:
|
if cache and text in self.embedding_cache:
|
||||||
return self.embedding_cache[text]
|
return self.embedding_cache[text]
|
||||||
|
|
||||||
# 使用LLMRequest获取embedding
|
# 使用LLMRequest获取embedding
|
||||||
@@ -389,10 +397,12 @@ class BotInterestManager:
|
|||||||
if isinstance(embedding[0], list):
|
if isinstance(embedding[0], list):
|
||||||
# If it's a list of lists, take the first one (though get_embedding(str) should return list[float])
|
# If it's a list of lists, take the first one (though get_embedding(str) should return list[float])
|
||||||
embedding = embedding[0]
|
embedding = embedding[0]
|
||||||
|
|
||||||
# Now we can safely cast to list[float] as we've handled the nested list case
|
# Now we can safely cast to list[float] as we've handled the nested list case
|
||||||
embedding_float = cast(list[float], embedding)
|
embedding_float = cast(list[float], embedding)
|
||||||
self.embedding_cache[text] = embedding_float
|
|
||||||
|
if cache:
|
||||||
|
self.embedding_cache[text] = embedding_float
|
||||||
|
|
||||||
current_dim = len(embedding_float)
|
current_dim = len(embedding_float)
|
||||||
if self._detected_embedding_dimension is None:
|
if self._detected_embedding_dimension is None:
|
||||||
@@ -424,7 +434,7 @@ class BotInterestManager:
|
|||||||
combined_text = message_text
|
combined_text = message_text
|
||||||
|
|
||||||
# 生成embedding
|
# 生成embedding
|
||||||
embedding = await self._get_embedding(combined_text)
|
embedding = await self._get_embedding(combined_text, cache=False)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
async def generate_embeddings_for_texts(
|
async def generate_embeddings_for_texts(
|
||||||
@@ -531,7 +541,8 @@ class BotInterestManager:
|
|||||||
# 生成消息的embedding
|
# 生成消息的embedding
|
||||||
logger.debug("正在生成消息 embedding...")
|
logger.debug("正在生成消息 embedding...")
|
||||||
if not message_embedding:
|
if not message_embedding:
|
||||||
message_embedding = await self._get_embedding(message_text)
|
# 消息文本embedding不入全局缓存,避免缓存随着对话历史无限增长
|
||||||
|
message_embedding = await self._get_embedding(message_text, cache=False)
|
||||||
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
|
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
|
||||||
|
|
||||||
# 计算与每个兴趣标签的相似度(使用扩展标签)
|
# 计算与每个兴趣标签的相似度(使用扩展标签)
|
||||||
@@ -1104,12 +1115,17 @@ class BotInterestManager:
|
|||||||
if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||||
current_embedding_model = self.embedding_config.model_list[0]
|
current_embedding_model = self.embedding_config.model_list[0]
|
||||||
|
|
||||||
|
tag_embeddings = self.embedding_cache
|
||||||
|
if self.current_interests:
|
||||||
|
allowed_keys = {tag.tag_name for tag in self.current_interests.interest_tags}
|
||||||
|
tag_embeddings = {key: value for key, value in self.embedding_cache.items() if key in allowed_keys}
|
||||||
|
|
||||||
cache_data = {
|
cache_data = {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"personality_id": personality_id,
|
"personality_id": personality_id,
|
||||||
"embedding_model": current_embedding_model,
|
"embedding_model": current_embedding_model,
|
||||||
"last_updated": datetime.now().isoformat(),
|
"last_updated": datetime.now().isoformat(),
|
||||||
"embeddings": self.embedding_cache,
|
"embeddings": tag_embeddings,
|
||||||
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding
|
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user