refactor(storage): 优化向量记忆存储的批处理和错误处理机制

- 改进_memory_to_vector_format方法,增强元数据序列化和错误处理
- 重构store_memories方法,实现真正的批处理存储
- 添加详细的日志记录,提升系统可观测性
- 修复memory_id获取和缓存问题
- 增强向量数据库操作的容错能力
- 调整日志级别,优化调试信息输出
This commit is contained in:
Windpicker-owo
2025-10-02 08:23:19 +08:00
parent 89f007fa33
commit c4aa34bc0c
2 changed files with 144 additions and 76 deletions

View File

@@ -73,9 +73,15 @@ class VectorMemoryStorage:
def __init__(self, config: Optional[VectorStorageConfig] = None):
    """Initialize vector memory storage with the given (or default) config."""
    self.config = config or VectorStorageConfig()
    # Batch size and target collection name come from the storage config.
    self.batch_size = self.config.batch_size
    self.collection_name = self.config.memory_collection
    self.vector_db_service = vector_db_service
    # In-process cache of MemoryChunk objects keyed by memory_id,
    # with per-entry timestamps for eviction.
    self.memory_cache: Dict[str, MemoryChunk] = {}
    self.cache_timestamps: Dict[str, float] = {}
    self._cache = self.memory_cache  # alias kept for backward compatibility with older code
    # Forgetting engine; None here — presumably attached later (TODO confirm)
    self.forgetting_engine: Optional[MemoryForgettingEngine] = None
@@ -180,29 +186,59 @@ class VectorMemoryStorage:
except Exception as e:
logger.error(f"自动清理失败: {e}")
def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
    """Convert a MemoryChunk into the vector-storage record format.

    Returns:
        A dict with keys ``id``, ``embedding`` (None — filled in later by the
        vector DB service), ``metadata`` and ``document``.

    Raises:
        Exception: re-raised after logging if any field conversion fails.
    """
    try:
        # memory_id may live on the metadata object or on the chunk itself.
        memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)

        # Text used for embedding: prefer display text, then raw text content,
        # then the stringified content object.
        display_text = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or str(memory.content)
        if not display_text.strip():
            logger.warning(f"记忆 {memory_id} 缺少有效的显示文本")
            display_text = f"{memory.memory_type.value}: {', '.join(memory.subjects)}"

        # Build metadata. Enum fields are stored via .name and list fields are
        # JSON-serialized, since the vector DB metadata only takes scalars.
        metadata = {
            "memory_id": memory_id,
            "user_id": memory.metadata.user_id or "unknown",
            "memory_type": memory.memory_type.value,
            "importance": memory.metadata.importance.name,  # .name, not the enum object
            "confidence": memory.metadata.confidence.name,  # .name, not the enum object
            "created_at": memory.metadata.created_at,
            "last_accessed": memory.metadata.last_accessed or memory.metadata.created_at,
            "access_count": memory.metadata.access_count,
            "subjects": orjson.dumps(memory.subjects).decode("utf-8"),
            "keywords": orjson.dumps(memory.keywords).decode("utf-8"),
            "tags": orjson.dumps(memory.tags).decode("utf-8"),
            "categories": orjson.dumps(memory.categories).decode("utf-8"),
            "relevance_score": memory.metadata.relevance_score,
        }

        # Optional fields are only added when present, keeping metadata compact.
        if memory.metadata.source_context:
            metadata["source_context"] = str(memory.metadata.source_context)
        if memory.content.predicate:
            metadata["predicate"] = memory.content.predicate
        if memory.content.object:
            if isinstance(memory.content.object, (dict, list)):
                metadata["object"] = orjson.dumps(memory.content.object).decode()
            else:
                metadata["object"] = str(memory.content.object)

        return {
            "id": memory_id,
            "embedding": None,  # generated later by vector_db_service
            "metadata": metadata,
            "document": display_text,
        }
    except Exception as e:
        memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
        logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
        raise
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
"""将Vector DB结果转换为MemoryChunk"""
@@ -262,70 +298,85 @@ class VectorMemoryStorage:
self.memory_cache.pop(oldest_id, None)
self.cache_timestamps.pop(oldest_id, None)
self.memory_cache[memory.memory_id] = memory
self.cache_timestamps[memory.memory_id] = time.time()
memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
if memory_id:
self.memory_cache[memory_id] = memory
self.cache_timestamps[memory_id] = time.time()
async def store_memories(self, memories: List[MemoryChunk]) -> int:
    """Batch-store memories into the vector database.

    Chunks that fail format conversion or embedding generation are skipped
    rather than aborting the whole batch.

    Args:
        memories: chunks to persist.

    Returns:
        Number of memories successfully stored (may be less than len(memories)).
    """
    if not memories:
        return 0

    start_time = datetime.now()
    success_count = 0
    try:
        # Convert to vector-storage format, skipping chunks that fail.
        vector_data_list = []
        for memory in memories:
            try:
                vector_data_list.append(self._memory_to_vector_format(memory))
            except Exception as e:
                memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
                logger.error(f"处理记忆 {memory_id} 失败: {e}")
                continue

        if not vector_data_list:
            logger.warning("没有有效的记忆数据可存储")
            return 0

        # Build an id -> chunk map once instead of a linear scan per item.
        memories_by_id = {
            (getattr(m.metadata, 'memory_id', None) or getattr(m, 'memory_id', None)): m
            for m in memories
        }

        # Store to the vector database in batches of self.batch_size.
        for i in range(0, len(vector_data_list), self.batch_size):
            batch = vector_data_list[i:i + self.batch_size]
            try:
                # Generate embeddings; drop items whose embedding fails instead
                # of inserting a zero vector that would pollute similarity search.
                valid_items = []
                embeddings = []
                for item in batch:
                    try:
                        embedding = await get_embedding(item["document"])
                    except Exception as e:
                        logger.error(f"生成embedding失败: {e}")
                        continue
                    if not embedding:
                        continue
                    valid_items.append(item)
                    embeddings.append(embedding)

                if not valid_items:
                    logger.warning(f"批次存储失败,跳过 {len(batch)} 条记忆")
                    continue

                # vector_db_service.add requires the embeddings parameter.
                self.vector_db_service.add(
                    collection_name=self.collection_name,
                    embeddings=embeddings,
                    ids=[item["id"] for item in valid_items],
                    documents=[item["document"] for item in valid_items],
                    metadatas=[item["metadata"] for item in valid_items]
                )

                # Refresh the cache for every stored item.
                for item in valid_items:
                    memory = memories_by_id.get(item["id"])
                    if memory is not None:
                        self._cache[item["id"]] = memory
                    success_count += 1
            except Exception as e:
                logger.error(f"批量存储失败: {e}", exc_info=True)
                continue

        duration = (datetime.now() - start_time).total_seconds()
        logger.info(f"成功存储 {success_count}/{len(memories)} 条记忆,耗时 {duration:.2f}秒")
        return success_count
    except Exception as e:
        logger.error(f"批量存储记忆失败: {e}", exc_info=True)
        return success_count
async def store_memory(self, memory: MemoryChunk) -> bool:
"""存储单条记忆"""
@@ -371,6 +422,7 @@ class VectorMemoryStorage:
metadatas = results.get("metadatas", [[]])[0]
ids = results.get("ids", [[]])[0]
logger.info(f"向量检索返回原始结果documents={len(documents)}, ids={len(ids)}, metadatas={len(metadatas)}")
for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)):
# 计算相似度
distance = distances[i] if i < len(distances) else 1.0
@@ -390,12 +442,19 @@ class VectorMemoryStorage:
if memory:
similar_memories.append((memory, similarity))
# 记录单条结果的关键日志id相似度简短文本
try:
short_text = (str(memory.content)[:120]) if hasattr(memory, 'content') else (doc[:120] if isinstance(doc, str) else '')
except Exception:
short_text = ''
logger.info(f"检索结果 - id={memory_id}, similarity={similarity:.4f}, summary={short_text}")
# 按相似度排序
similar_memories.sort(key=lambda x: x[1], reverse=True)
self.stats["total_searches"] += 1
logger.debug(f"搜索相似记忆: 查询='{query_text[:30]}...', 结果数={len(similar_memories)}")
logger.info(f"搜索相似记忆: query='{query_text[:60]}...', limit={limit}, threshold={threshold}, filters={where_conditions}, 返回数={len(similar_memories)}")
logger.debug(f"搜索相似记忆 详细结果数={len(similar_memories)}")
return similar_memories
@@ -451,6 +510,7 @@ class VectorMemoryStorage:
metadatas = results.get("metadatas", [{}] * len(documents))
ids = results.get("ids", [])
logger.info(f"按过滤条件获取返回: docs={len(documents)}, ids={len(ids)}")
for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
memory_id = ids[i] if i < len(ids) else None
@@ -459,6 +519,7 @@ class VectorMemoryStorage:
memory = self._get_from_cache(memory_id)
if memory:
memories.append(memory)
logger.debug(f"过滤获取命中缓存: id={memory_id}")
continue
# 从Vector结果重建
@@ -467,6 +528,7 @@ class VectorMemoryStorage:
memories.append(memory)
if memory_id:
self._add_to_cache(memory)
logger.debug(f"过滤获取结果: id={memory_id}, meta_keys={list(metadata.keys())}")
return memories
@@ -477,14 +539,20 @@ class VectorMemoryStorage:
async def update_memory(self, memory: MemoryChunk) -> bool:
    """Update a memory by deleting the old record and re-storing the chunk.

    Returns:
        True on success; False when the chunk has no memory_id or storage fails.
    """
    try:
        memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', None)
        if not memory_id:
            logger.error("无法更新记忆缺少memory_id")
            return False
        # Remove the previous record, then store the updated chunk.
        await self.delete_memory(memory_id)
        return await self.store_memory(memory)
    except Exception as e:
        memory_id = getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown')
        logger.error(f"更新记忆 {memory_id} 失败: {e}")
        return False
async def delete_memory(self, memory_id: str) -> bool:
@@ -658,7 +726,7 @@ class VectorMemoryStorageAdapter:
query_text, limit, filters=filters
)
# 转换为原格式:(memory_id, similarity)
return [(memory.memory_id, similarity) for memory, similarity in results]
return [(getattr(memory.metadata, 'memory_id', None) or getattr(memory, 'memory_id', 'unknown'), similarity) for memory, similarity in results]
def get_stats(self) -> Dict[str, Any]:
    """Return storage statistics from the underlying storage backend."""
    stats = self.storage.get_storage_stats()
    return stats

View File

@@ -127,7 +127,7 @@ class ChatterActionPlanner:
}
)
logger.debug(
logger.info(
f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
)