refactor(storage): 优化向量记忆存储的批处理和错误处理机制

- 改进_memory_to_vector_format方法,增强元数据序列化和错误处理
- 重构store_memories方法,实现真正的批处理存储
- 添加详细的日志记录,提升系统可观测性
- 修复memory_id获取和缓存问题
- 增强向量数据库操作的容错能力
- 调整日志级别,优化调试信息输出
This commit is contained in:
Windpicker-owo
2025-10-02 08:23:19 +08:00
parent 96093306e1
commit 4e4aa9fbcf
2 changed files with 144 additions and 76 deletions

View File

@@ -73,9 +73,15 @@ class VectorMemoryStorage:
def __init__(self, config: Optional[VectorStorageConfig] = None): def __init__(self, config: Optional[VectorStorageConfig] = None):
self.config = config or VectorStorageConfig() self.config = config or VectorStorageConfig()
# 从配置中获取批处理大小和集合名称
self.batch_size = self.config.batch_size
self.collection_name = self.config.memory_collection
self.vector_db_service = vector_db_service
# 内存缓存 # 内存缓存
self.memory_cache: Dict[str, MemoryChunk] = {} self.memory_cache: Dict[str, MemoryChunk] = {}
self.cache_timestamps: Dict[str, float] = {} self.cache_timestamps: Dict[str, float] = {}
self._cache = self.memory_cache # 别名,兼容旧代码
# 遗忘引擎 # 遗忘引擎
self.forgetting_engine: Optional[MemoryForgettingEngine] = None self.forgetting_engine: Optional[MemoryForgettingEngine] = None
@@ -180,29 +186,59 @@ class VectorMemoryStorage:
except Exception as e: except Exception as e:
logger.error(f"自动清理失败: {e}") logger.error(f"自动清理失败: {e}")
def _memory_to_vector_format(self, memory: MemoryChunk) -> Tuple[Dict[str, Any], str]: def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
"""将MemoryChunk转换为Vector DB格式""" """将MemoryChunk转换为向量存储格式"""
# 选择用于向量化的文本 try:
content = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or "" # 获取memory_id
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
# 构建元数据全部从memory.metadata获取 # 生成向量表示的文本
meta = getattr(memory, 'metadata', None) display_text = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or str(memory.content)
metadata = { if not display_text.strip():
"user_id": getattr(meta, 'user_id', None), logger.warning(f"记忆 {memory_id} 缺少有效的显示文本")
"chat_id": getattr(meta, 'chat_id', 'unknown'), display_text = f"{memory.memory_type.value}: {', '.join(memory.subjects)}"
"memory_type": memory.memory_type.value,
"keywords": orjson.dumps(getattr(memory, 'keywords', [])).decode("utf-8"),
"importance": getattr(meta, 'importance', None),
"timestamp": getattr(meta, 'created_at', None),
"access_count": getattr(meta, 'access_count', 0),
"last_access_time": getattr(meta, 'last_accessed', 0),
"confidence": getattr(meta, 'confidence', None),
"source": "vector_storage_v2",
# 存储完整的记忆数据
"memory_data": orjson.dumps(memory.to_dict()).decode("utf-8")
}
return metadata, content # 构建元数据 - 修复枚举值和列表序列化
metadata = {
"memory_id": memory_id,
"user_id": memory.metadata.user_id or "unknown",
"memory_type": memory.memory_type.value,
"importance": memory.metadata.importance.name, # 使用 .name 而不是枚举对象
"confidence": memory.metadata.confidence.name, # 使用 .name 而不是枚举对象
"created_at": memory.metadata.created_at,
"last_accessed": memory.metadata.last_accessed or memory.metadata.created_at,
"access_count": memory.metadata.access_count,
"subjects": orjson.dumps(memory.subjects).decode("utf-8"), # 列表转JSON字符串
"keywords": orjson.dumps(memory.keywords).decode("utf-8"), # 列表转JSON字符串
"tags": orjson.dumps(memory.tags).decode("utf-8"), # 列表转JSON字符串
"categories": orjson.dumps(memory.categories).decode("utf-8"), # 列表转JSON字符串
"relevance_score": memory.metadata.relevance_score
}
# 添加可选字段
if memory.metadata.source_context:
metadata["source_context"] = str(memory.metadata.source_context)
if memory.content.predicate:
metadata["predicate"] = memory.content.predicate
if memory.content.object:
if isinstance(memory.content.object, (dict, list)):
metadata["object"] = orjson.dumps(memory.content.object).decode()
else:
metadata["object"] = str(memory.content.object)
return {
"id": memory_id,
"embedding": None, # 将由vector_db_service生成
"metadata": metadata,
"document": display_text
}
except Exception as e:
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
raise
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]: def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
"""将Vector DB结果转换为MemoryChunk""" """将Vector DB结果转换为MemoryChunk"""
@@ -262,70 +298,85 @@ class VectorMemoryStorage:
self.memory_cache.pop(oldest_id, None) self.memory_cache.pop(oldest_id, None)
self.cache_timestamps.pop(oldest_id, None) self.cache_timestamps.pop(oldest_id, None)
self.memory_cache[memory.memory_id] = memory memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
self.cache_timestamps[memory.memory_id] = time.time() if memory_id:
self.memory_cache[memory_id] = memory
self.cache_timestamps[memory_id] = time.time()
async def store_memories(self, memories: List[MemoryChunk]) -> int: async def store_memories(self, memories: List[MemoryChunk]) -> int:
"""批量存储记忆""" """批量存储记忆"""
if not memories: if not memories:
return 0 return 0
try: start_time = datetime.now()
# 准备批量数据 success_count = 0
embeddings = []
documents = []
metadatas = []
ids = []
try:
# 转换为向量格式
vector_data_list = []
for memory in memories: for memory in memories:
try: try:
# 转换格式 vector_data = self._memory_to_vector_format(memory)
metadata, content = self._memory_to_vector_format(memory) vector_data_list.append(vector_data)
if not content.strip():
logger.warning(f"记忆 {memory.memory_id} 内容为空,跳过")
continue
# 生成向量
embedding = await get_embedding(content)
if not embedding:
logger.warning(f"生成向量失败,跳过记忆: {memory.memory_id}")
continue
embeddings.append(embedding)
documents.append(content)
metadatas.append(metadata)
ids.append(memory.memory_id)
# 添加到缓存
self._add_to_cache(memory)
except Exception as e: except Exception as e:
logger.error(f"处理记忆 {memory.memory_id} 失败: {e}") memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
logger.error(f"处理记忆 {memory_id} 失败: {e}")
continue continue
# 批量插入Vector DB if not vector_data_list:
if embeddings: logger.warning("没有有效的记忆数据可存储")
vector_db_service.add( return 0
collection_name=self.config.memory_collection,
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids
)
stored_count = len(embeddings) # 批量存储到向量数据库
self.stats["total_stores"] += stored_count for i in range(0, len(vector_data_list), self.batch_size):
self.stats["total_memories"] += stored_count batch = vector_data_list[i:i + self.batch_size]
logger.info(f"成功存储 {stored_count}/{len(memories)} 条记忆") try:
return stored_count # 生成embeddings
embeddings = []
for item in batch:
try:
embedding = await get_embedding(item["document"])
embeddings.append(embedding)
except Exception as e:
logger.error(f"生成embedding失败: {e}")
# 使用零向量作为后备
embeddings.append([0.0] * 768) # 默认维度
return 0 # vector_db_service.add 需要embeddings参数
self.vector_db_service.add(
collection_name=self.collection_name,
embeddings=embeddings,
ids=[item["id"] for item in batch],
documents=[item["document"] for item in batch],
metadatas=[item["metadata"] for item in batch]
)
success = True
if success:
# 更新缓存
for item in batch:
memory_id = item["id"]
# 从原始 memories 列表中找到对应的 MemoryChunk
memory = next((m for m in memories if (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)) == memory_id), None)
if memory:
self._cache[memory_id] = memory
success_count += 1
else:
logger.warning(f"批次存储失败,跳过 {len(batch)} 条记忆")
except Exception as e:
logger.error(f"批量存储失败: {e}", exc_info=True)
continue
duration = (datetime.now() - start_time).total_seconds()
logger.info(f"成功存储 {success_count}/{len(memories)} 条记忆,耗时 {duration:.2f}")
return success_count
except Exception as e: except Exception as e:
logger.error(f"批量存储记忆失败: {e}") logger.error(f"批量存储记忆失败: {e}", exc_info=True)
return 0 return success_count
async def store_memory(self, memory: MemoryChunk) -> bool: async def store_memory(self, memory: MemoryChunk) -> bool:
"""存储单条记忆""" """存储单条记忆"""
@@ -371,6 +422,7 @@ class VectorMemoryStorage:
metadatas = results.get("metadatas", [[]])[0] metadatas = results.get("metadatas", [[]])[0]
ids = results.get("ids", [[]])[0] ids = results.get("ids", [[]])[0]
logger.info(f"向量检索返回原始结果documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}")
for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)): for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)):
# 计算相似度 # 计算相似度
distance = distances[i] if i < len(distances) else 1.0 distance = distances[i] if i < len(distances) else 1.0
@@ -390,12 +442,19 @@ class VectorMemoryStorage:
if memory: if memory:
similar_memories.append((memory, similarity)) similar_memories.append((memory, similarity))
# 记录单条结果的关键日志id相似度简短文本
try:
short_text = (str(memory.content)[:120]) if hasattr(memory, 'content') else (doc[:120] if isinstance(doc, str) else '')
except Exception:
short_text = ''
logger.info(f"检索结果 - id={memory_id}, similarity={similarity:.4f}, summary={short_text}")
# 按相似度排序 # 按相似度排序
similar_memories.sort(key=lambda x: x[1], reverse=True) similar_memories.sort(key=lambda x: x[1], reverse=True)
self.stats["total_searches"] += 1 self.stats["total_searches"] += 1
logger.debug(f"搜索相似记忆: 查询='{query_text[:30]}...', 结果数={len(similar_memories)}") logger.info(f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}")
logger.debug(f"搜索相似记忆 详细结果数={len(similar_memories)}")
return similar_memories return similar_memories
@@ -451,6 +510,7 @@ class VectorMemoryStorage:
metadatas = results.get("metadatas", [{}] * len(documents)) metadatas = results.get("metadatas", [{}] * len(documents))
ids = results.get("ids", []) ids = results.get("ids", [])
logger.info(f"按过滤条件获取返回: docs={len(documents)}, ids={len(ids)}")
for i, (doc, metadata) in enumerate(zip(documents, metadatas)): for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
memory_id = ids[i] if i < len(ids) else None memory_id = ids[i] if i < len(ids) else None
@@ -459,6 +519,7 @@ class VectorMemoryStorage:
memory = self._get_from_cache(memory_id) memory = self._get_from_cache(memory_id)
if memory: if memory:
memories.append(memory) memories.append(memory)
logger.debug(f"过滤获取命中缓存: id={memory_id}")
continue continue
# 从Vector结果重建 # 从Vector结果重建
@@ -467,6 +528,7 @@ class VectorMemoryStorage:
memories.append(memory) memories.append(memory)
if memory_id: if memory_id:
self._add_to_cache(memory) self._add_to_cache(memory)
logger.debug(f"过滤获取结果: id={memory_id}, meta_keys={list(metadata.keys())}")
return memories return memories
@@ -477,14 +539,20 @@ class VectorMemoryStorage:
async def update_memory(self, memory: MemoryChunk) -> bool: async def update_memory(self, memory: MemoryChunk) -> bool:
"""更新记忆""" """更新记忆"""
try: try:
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
if not memory_id:
logger.error("无法更新记忆缺少memory_id")
return False
# 先删除旧记忆 # 先删除旧记忆
await self.delete_memory(memory.memory_id) await self.delete_memory(memory_id)
# 重新存储更新后的记忆 # 重新存储更新后的记忆
return await self.store_memory(memory) return await self.store_memory(memory)
except Exception as e: except Exception as e:
logger.error(f"更新记忆 {memory.memory_id} 失败: {e}") memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
logger.error(f"更新记忆 {memory_id} 失败: {e}")
return False return False
async def delete_memory(self, memory_id: str) -> bool: async def delete_memory(self, memory_id: str) -> bool:
@@ -658,7 +726,7 @@ class VectorMemoryStorageAdapter:
query_text, limit, filters=filters query_text, limit, filters=filters
) )
# 转换为原格式:(memory_id, similarity) # 转换为原格式:(memory_id, similarity)
return [(memory.memory_id, similarity) for memory, similarity in results] return [(getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown'), similarity) for memory, similarity in results]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self.storage.get_storage_stats() return self.storage.get_storage_stats()

View File

@@ -127,7 +127,7 @@ class ChatterActionPlanner:
} }
) )
logger.debug( logger.info(
f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}" f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
) )