refactor(storage): migrate memory system architecture to a modern Vector DB backend

Rework the memory storage layer around ChromaDB as the core vector database engine, improving scalability and query performance. VectorMemoryStorage replaces the previous UnifiedMemoryStorage implementation and provides more efficient vector indexing and retrieval.

Main changes:
- Architecture: rework the memory storage abstraction layer and remove the 577-line legacy storage implementation
- Configuration: add 41 Vector DB-specific configuration fields for fine-grained performance tuning
- Query optimization: extend the ChromaDB where-condition processor and add a fallback query path
- System integration: update memory system initialization to the new storage interface
- Type safety: fix an async call type mismatch (missing await)

BREAKING CHANGE: the memory storage API has changed significantly; the UnifiedMemoryStorage classes are deprecated
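Migration sketch for downstream callers, based only on the interfaces visible in this diff (it assumes the project's vector_db_service and embedding backend are already configured; this is not an excerpt from the repository):

import asyncio
from src.chat.memory_system.vector_memory_storage_v2 import (
    VectorMemoryStorage,
    VectorStorageConfig,
)

async def demo() -> None:
    # Old API (removed): storage = await initialize_unified_memory_storage(UnifiedStorageConfig(...)),
    # and search_similar_memories() returned (memory_id, score) pairs that needed a
    # follow-up get_memory_by_id() lookup.
    # New API: construct the storage directly; search returns the MemoryChunk objects themselves.
    storage = VectorMemoryStorage(VectorStorageConfig(similarity_threshold=0.8))
    results = await storage.search_similar_memories("query text", limit=5)
    for memory, score in results:  # List[Tuple[MemoryChunk, float]]
        memory.update_access()
    storage.stop()

asyncio.run(demo())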
@@ -22,12 +22,11 @@ from .memory_forgetting_engine import (
get_memory_forgetting_engine
)

# 统一存储系统
from .unified_memory_storage import (
UnifiedMemoryStorage,
UnifiedStorageConfig,
get_unified_memory_storage,
initialize_unified_memory_storage
# Vector DB存储系统
from .vector_memory_storage_v2 import (
VectorMemoryStorage,
VectorStorageConfig,
get_vector_memory_storage
)

# 记忆核心系统
@@ -79,11 +78,10 @@ __all__ = [
"ForgettingConfig",
"get_memory_forgetting_engine",

# 统一存储
"UnifiedMemoryStorage",
"UnifiedStorageConfig",
"get_unified_memory_storage",
"initialize_unified_memory_storage",
# Vector DB存储
"VectorMemoryStorage",
"VectorStorageConfig",
"get_vector_memory_storage",

# 记忆系统
"MemorySystem",

@@ -191,27 +191,27 @@ class MemorySystem:
self.memory_builder = MemoryBuilder(self.memory_extraction_model)
self.fusion_engine = MemoryFusionEngine(self.config.fusion_similarity_threshold)

# 初始化统一存储系统
from src.chat.memory_system.unified_memory_storage import initialize_unified_memory_storage, UnifiedStorageConfig
# 初始化Vector DB存储系统(替代旧的unified_memory_storage)
from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig

storage_config = UnifiedStorageConfig(
dimension=self.config.vector_dimension,
storage_config = VectorStorageConfig(
memory_collection="unified_memory_v2",
metadata_collection="memory_metadata_v2",
similarity_threshold=self.config.similarity_threshold,
storage_path=getattr(global_config.memory, 'unified_storage_path', 'data/unified_memory'),
cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 10000),
auto_save_interval=getattr(global_config.memory, 'unified_storage_auto_save_interval', 50),
enable_compression=getattr(global_config.memory, 'unified_storage_enable_compression', True),
search_limit=getattr(global_config.memory, 'unified_storage_search_limit', 20),
batch_size=getattr(global_config.memory, 'unified_storage_batch_size', 100),
enable_caching=getattr(global_config.memory, 'unified_storage_enable_caching', True),
cache_size_limit=getattr(global_config.memory, 'unified_storage_cache_limit', 1000),
auto_cleanup_interval=getattr(global_config.memory, 'unified_storage_auto_cleanup_interval', 3600),
enable_forgetting=getattr(global_config.memory, 'enable_memory_forgetting', True),
forgetting_check_interval=getattr(global_config.memory, 'forgetting_check_interval_hours', 24)
retention_hours=getattr(global_config.memory, 'memory_retention_hours', 720) # 30天
)

try:
self.unified_storage = await initialize_unified_memory_storage(storage_config)
if self.unified_storage is None:
raise RuntimeError("统一存储系统初始化返回None")
logger.info("✅ 统一存储系统初始化成功")
self.unified_storage = VectorMemoryStorage(storage_config)
logger.info("✅ Vector DB存储系统初始化成功")
except Exception as storage_error:
logger.error(f"❌ 统一存储系统初始化失败: {storage_error}", exc_info=True)
logger.error(f"❌ Vector DB存储系统初始化失败: {storage_error}", exc_info=True)
raise

# 初始化遗忘引擎
@@ -647,20 +647,18 @@ class MemorySystem:
except Exception as plan_exc:
logger.warning("查询规划失败,使用默认检索策略: %s", plan_exc, exc_info=True)

# 使用统一存储搜索
# 使用Vector DB存储搜索
search_results = await self.unified_storage.search_similar_memories(
query_text=raw_query,
limit=effective_limit,
filters=filters
)

# 转换为记忆对象
# 转换为记忆对象 - search_results 返回 List[Tuple[MemoryChunk, float]]
final_memories = []
for memory_id, similarity_score in search_results:
memory = self.unified_storage.get_memory_by_id(memory_id)
if memory:
memory.update_access()
final_memories.append(memory)
for memory, similarity_score in search_results:
memory.update_access()
final_memories.append(memory)

retrieval_time = time.time() - start_time

@@ -1,577 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一记忆存储系统
|
||||
简化后的记忆存储,整合向量存储和元数据索引
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 尝试导入FAISS
|
||||
try:
|
||||
import faiss
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
logger.warning("FAISS not available, using simple vector storage")
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedStorageConfig:
|
||||
"""统一存储配置"""
|
||||
# 向量存储配置
|
||||
dimension: int = 1024
|
||||
similarity_threshold: float = 0.8
|
||||
storage_path: str = "data/unified_memory"
|
||||
|
||||
# 性能配置
|
||||
cache_size_limit: int = 10000
|
||||
auto_save_interval: int = 50
|
||||
search_limit: int = 20
|
||||
enable_compression: bool = True
|
||||
|
||||
# 遗忘配置
|
||||
enable_forgetting: bool = True
|
||||
forgetting_check_interval: int = 24 # 小时
|
||||
|
||||
|
||||
class UnifiedMemoryStorage:
|
||||
"""统一记忆存储系统"""
|
||||
|
||||
def __init__(self, config: Optional[UnifiedStorageConfig] = None):
|
||||
self.config = config or UnifiedStorageConfig()
|
||||
|
||||
# 存储路径
|
||||
self.storage_path = Path(self.config.storage_path)
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 向量索引
|
||||
self.vector_index = None
|
||||
self.memory_id_to_index: Dict[str, int] = {}
|
||||
self.index_to_memory_id: Dict[int, str] = {}
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: Dict[str, np.ndarray] = {}
|
||||
|
||||
# 元数据索引(简化版)
|
||||
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> memory_id set
|
||||
self.type_index: Dict[str, Set[str]] = {} # type -> memory_id set
|
||||
self.user_index: Dict[str, Set[str]] = {} # user_id -> memory_id set
|
||||
|
||||
# 遗忘引擎
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
if self.config.enable_forgetting:
|
||||
self.forgetting_engine = MemoryForgettingEngine()
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_memories": 0,
|
||||
"total_vectors": 0,
|
||||
"cache_size": 0,
|
||||
"last_save_time": 0.0,
|
||||
"total_searches": 0,
|
||||
"total_stores": 0,
|
||||
"forgetting_stats": {}
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
self._lock = threading.RLock()
|
||||
self._operation_count = 0
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model: Optional[LLMRequest] = None
|
||||
|
||||
# 初始化
|
||||
self._initialize_storage()
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""初始化存储系统"""
|
||||
try:
|
||||
# 初始化向量索引
|
||||
if FAISS_AVAILABLE:
|
||||
self.vector_index = faiss.IndexFlatIP(self.config.dimension)
|
||||
logger.info(f"FAISS向量索引初始化完成,维度: {self.config.dimension}")
|
||||
else:
|
||||
# 简单向量存储
|
||||
self.vector_index = {}
|
||||
logger.info("使用简单向量存储(FAISS不可用)")
|
||||
|
||||
# 尝试加载现有数据
|
||||
self._load_storage()
|
||||
|
||||
logger.info(f"统一记忆存储初始化完成,当前记忆数: {len(self.memory_cache)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储系统初始化失败: {e}", exc_info=True)
|
||||
|
||||
def set_embedding_model(self, model: LLMRequest):
|
||||
"""设置嵌入模型"""
|
||||
self.embedding_model = model
|
||||
|
||||
async def _generate_embedding(self, text: str) -> Optional[np.ndarray]:
|
||||
"""生成文本的向量表示"""
|
||||
if not self.embedding_model:
|
||||
logger.warning("未设置嵌入模型,无法生成向量")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 使用嵌入模型生成向量
|
||||
embedding, _ = await self.embedding_model.get_embedding(text)
|
||||
|
||||
if embedding is None:
|
||||
logger.warning(f"嵌入模型返回空向量,文本: {text[:50]}...")
|
||||
return None
|
||||
|
||||
# 转换为numpy数组
|
||||
embedding_array = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 归一化向量
|
||||
norm = np.linalg.norm(embedding_array)
|
||||
if norm > 0:
|
||||
embedding_array = embedding_array / norm
|
||||
|
||||
return embedding_array
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成向量失败: {e}")
|
||||
return None
|
||||
|
||||
def _add_to_keyword_index(self, memory: MemoryChunk):
|
||||
"""添加到关键词索引"""
|
||||
for keyword in memory.keywords:
|
||||
if keyword not in self.keyword_index:
|
||||
self.keyword_index[keyword] = set()
|
||||
self.keyword_index[keyword].add(memory.memory_id)
|
||||
|
||||
def _add_to_type_index(self, memory: MemoryChunk):
|
||||
"""添加到类型索引"""
|
||||
memory_type = memory.memory_type.value
|
||||
if memory_type not in self.type_index:
|
||||
self.type_index[memory_type] = set()
|
||||
self.type_index[memory_type].add(memory.memory_id)
|
||||
|
||||
def _add_to_user_index(self, memory: MemoryChunk):
|
||||
"""添加到用户索引"""
|
||||
user_id = memory.user_id
|
||||
if user_id not in self.user_index:
|
||||
self.user_index[user_id] = set()
|
||||
self.user_index[user_id].add(memory.memory_id)
|
||||
|
||||
def _remove_from_indexes(self, memory: MemoryChunk):
|
||||
"""从所有索引中移除记忆"""
|
||||
memory_id = memory.memory_id
|
||||
|
||||
# 从关键词索引移除
|
||||
for keyword, memory_ids in self.keyword_index.items():
|
||||
memory_ids.discard(memory_id)
|
||||
if not memory_ids:
|
||||
del self.keyword_index[keyword]
|
||||
|
||||
# 从类型索引移除
|
||||
memory_type = memory.memory_type.value
|
||||
if memory_type in self.type_index:
|
||||
self.type_index[memory_type].discard(memory_id)
|
||||
if not self.type_index[memory_type]:
|
||||
del self.type_index[memory_type]
|
||||
|
||||
# 从用户索引移除
|
||||
if memory.user_id in self.user_index:
|
||||
self.user_index[memory.user_id].discard(memory_id)
|
||||
if not self.user_index[memory.user_id]:
|
||||
del self.user_index[memory.user_id]
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
"""存储记忆列表"""
|
||||
if not memories:
|
||||
return 0
|
||||
|
||||
stored_count = 0
|
||||
|
||||
with self._lock:
|
||||
for memory in memories:
|
||||
try:
|
||||
# 生成向量
|
||||
vector = None
|
||||
if memory.display and memory.display.strip():
|
||||
vector = await self._generate_embedding(memory.display)
|
||||
elif memory.text_content and memory.text_content.strip():
|
||||
vector = await self._generate_embedding(memory.text_content)
|
||||
|
||||
# 存储到缓存
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
if vector is not None:
|
||||
self.vector_cache[memory.memory_id] = vector
|
||||
|
||||
# 添加到向量索引
|
||||
if FAISS_AVAILABLE:
|
||||
index_id = self.vector_index.ntotal
|
||||
self.vector_index.add(vector.reshape(1, -1))
|
||||
self.memory_id_to_index[memory.memory_id] = index_id
|
||||
self.index_to_memory_id[index_id] = memory.memory_id
|
||||
else:
|
||||
# 简单存储
|
||||
self.vector_index[memory.memory_id] = vector
|
||||
|
||||
# 更新元数据索引
|
||||
self._add_to_keyword_index(memory)
|
||||
self._add_to_type_index(memory)
|
||||
self._add_to_user_index(memory)
|
||||
|
||||
stored_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储记忆 {memory.memory_id[:8]} 失败: {e}")
|
||||
continue
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_memories"] = len(self.memory_cache)
|
||||
self.stats["total_vectors"] = len(self.vector_cache)
|
||||
self.stats["total_stores"] += stored_count
|
||||
|
||||
# 自动保存
|
||||
self._operation_count += stored_count
|
||||
if self._operation_count >= self.config.auto_save_interval:
|
||||
await self._save_storage()
|
||||
self._operation_count = 0
|
||||
|
||||
logger.debug(f"成功存储 {stored_count}/{len(memories)} 条记忆")
|
||||
return stored_count
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
scope_id: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
if not query_text or not self.vector_cache:
|
||||
return []
|
||||
|
||||
# 生成查询向量
|
||||
query_vector = await self._generate_embedding(query_text)
|
||||
if query_vector is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
results = []
|
||||
|
||||
if FAISS_AVAILABLE and self.vector_index.ntotal > 0:
|
||||
# 使用FAISS搜索
|
||||
query_vector = query_vector.reshape(1, -1)
|
||||
scores, indices = self.vector_index.search(
|
||||
query_vector,
|
||||
min(limit, self.vector_index.ntotal)
|
||||
)
|
||||
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx >= 0 and score >= self.config.similarity_threshold:
|
||||
memory_id = self.index_to_memory_id.get(idx)
|
||||
if memory_id and memory_id in self.memory_cache:
|
||||
# 应用过滤器
|
||||
if self._apply_filters(self.memory_cache[memory_id], filters):
|
||||
results.append((memory_id, float(score)))
|
||||
|
||||
else:
|
||||
# 简单余弦相似度搜索
|
||||
for memory_id, vector in self.vector_cache.items():
|
||||
if memory_id not in self.memory_cache:
|
||||
continue
|
||||
|
||||
# 计算余弦相似度
|
||||
similarity = np.dot(query_vector, vector)
|
||||
if similarity >= self.config.similarity_threshold:
|
||||
# 应用过滤器
|
||||
if self._apply_filters(self.memory_cache[memory_id], filters):
|
||||
results.append((memory_id, float(similarity)))
|
||||
|
||||
# 排序并限制结果
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
self.stats["total_searches"] += 1
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索相似记忆失败: {e}")
|
||||
return []
|
||||
|
||||
def _apply_filters(self, memory: MemoryChunk, filters: Optional[Dict[str, Any]]) -> bool:
|
||||
"""应用搜索过滤器"""
|
||||
if not filters:
|
||||
return True
|
||||
|
||||
# 用户过滤器
|
||||
if "user_id" in filters and memory.user_id != filters["user_id"]:
|
||||
return False
|
||||
|
||||
# 类型过滤器
|
||||
if "memory_types" in filters and memory.memory_type.value not in filters["memory_types"]:
|
||||
return False
|
||||
|
||||
# 关键词过滤器
|
||||
if "keywords" in filters:
|
||||
memory_keywords = set(k.lower() for k in memory.keywords)
|
||||
filter_keywords = set(k.lower() for k in filters["keywords"])
|
||||
if not memory_keywords.intersection(filter_keywords):
|
||||
return False
|
||||
|
||||
# 重要性过滤器
|
||||
if "min_importance" in filters and memory.metadata.importance.value < filters["min_importance"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""根据ID获取记忆"""
|
||||
return self.memory_cache.get(memory_id)
|
||||
|
||||
def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 50) -> List[MemoryChunk]:
|
||||
"""根据过滤器获取记忆"""
|
||||
results = []
|
||||
|
||||
for memory in self.memory_cache.values():
|
||||
if self._apply_filters(memory, filters):
|
||||
results.append(memory)
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
async def forget_memories(self, memory_ids: List[str]) -> int:
|
||||
"""遗忘指定的记忆"""
|
||||
if not memory_ids:
|
||||
return 0
|
||||
|
||||
forgotten_count = 0
|
||||
|
||||
with self._lock:
|
||||
for memory_id in memory_ids:
|
||||
try:
|
||||
memory = self.memory_cache.get(memory_id)
|
||||
if not memory:
|
||||
continue
|
||||
|
||||
# 从向量索引移除
|
||||
if FAISS_AVAILABLE and memory_id in self.memory_id_to_index:
|
||||
# FAISS不支持直接删除,这里简化处理
|
||||
# 在实际使用中,可能需要重建索引
|
||||
logger.debug(f"FAISS索引删除 {memory_id} (需要重建索引)")
|
||||
elif memory_id in self.vector_index:
|
||||
del self.vector_index[memory_id]
|
||||
|
||||
# 从缓存移除
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.vector_cache.pop(memory_id, None)
|
||||
|
||||
# 从索引移除
|
||||
self._remove_from_indexes(memory)
|
||||
|
||||
forgotten_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"遗忘记忆 {memory_id[:8]} 失败: {e}")
|
||||
continue
|
||||
|
||||
# 更新统计
|
||||
self.stats["total_memories"] = len(self.memory_cache)
|
||||
self.stats["total_vectors"] = len(self.vector_cache)
|
||||
|
||||
logger.info(f"成功遗忘 {forgotten_count}/{len(memory_ids)} 条记忆")
|
||||
return forgotten_count
|
||||
|
||||
async def perform_forgetting_check(self) -> Dict[str, Any]:
|
||||
"""执行遗忘检查"""
|
||||
if not self.forgetting_engine:
|
||||
return {"error": "遗忘引擎未启用"}
|
||||
|
||||
try:
|
||||
# 执行遗忘检查
|
||||
result = await self.forgetting_engine.perform_forgetting_check(list(self.memory_cache.values()))
|
||||
|
||||
# 遗忘标记的记忆
|
||||
forgetting_ids = result["normal_forgetting"] + result["force_forgetting"]
|
||||
if forgetting_ids:
|
||||
forgotten_count = await self.forget_memories(forgetting_ids)
|
||||
result["forgotten_count"] = forgotten_count
|
||||
|
||||
# 更新统计
|
||||
self.stats["forgetting_stats"] = self.forgetting_engine.get_forgetting_stats()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行遗忘检查失败: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _load_storage(self):
|
||||
"""加载存储数据"""
|
||||
try:
|
||||
# 加载记忆缓存
|
||||
memory_file = self.storage_path / "memory_cache.json"
|
||||
if memory_file.exists():
|
||||
with open(memory_file, 'rb') as f:
|
||||
memory_data = orjson.loads(f.read())
|
||||
for memory_id, memory_dict in memory_data.items():
|
||||
self.memory_cache[memory_id] = MemoryChunk.from_dict(memory_dict)
|
||||
|
||||
# 加载向量缓存(如果启用压缩)
|
||||
if not self.config.enable_compression:
|
||||
vector_file = self.storage_path / "vectors.npz"
|
||||
if vector_file.exists():
|
||||
vectors = np.load(vector_file)
|
||||
self.vector_cache = {
|
||||
memory_id: vectors[memory_id]
|
||||
for memory_id in vectors.files
|
||||
if memory_id in self.memory_cache
|
||||
}
|
||||
|
||||
# 重建向量索引
|
||||
if FAISS_AVAILABLE and self.vector_cache:
|
||||
logger.info("重建FAISS向量索引...")
|
||||
vectors = []
|
||||
memory_ids = []
|
||||
|
||||
for memory_id, vector in self.vector_cache.items():
|
||||
vectors.append(vector)
|
||||
memory_ids.append(memory_id)
|
||||
|
||||
if vectors:
|
||||
vectors_array = np.vstack(vectors)
|
||||
self.vector_index.reset()
|
||||
self.vector_index.add(vectors_array)
|
||||
|
||||
# 重建映射
|
||||
for idx, memory_id in enumerate(memory_ids):
|
||||
self.memory_id_to_index[memory_id] = idx
|
||||
self.index_to_memory_id[idx] = memory_id
|
||||
|
||||
logger.info(f"存储数据加载完成,记忆数: {len(self.memory_cache)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"加载存储数据失败: {e}")
|
||||
|
||||
async def _save_storage(self):
|
||||
"""保存存储数据"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 保存记忆缓存
|
||||
memory_data = {
|
||||
memory_id: memory.to_dict()
|
||||
for memory_id, memory in self.memory_cache.items()
|
||||
}
|
||||
|
||||
memory_file = self.storage_path / "memory_cache.json"
|
||||
with open(memory_file, 'wb') as f:
|
||||
f.write(orjson.dumps(memory_data, option=orjson.OPT_INDENT_2))
|
||||
|
||||
# 保存向量缓存(如果启用压缩)
|
||||
if not self.config.enable_compression and self.vector_cache:
|
||||
vector_file = self.storage_path / "vectors.npz"
|
||||
np.savez_compressed(vector_file, **self.vector_cache)
|
||||
|
||||
save_time = time.time() - start_time
|
||||
self.stats["last_save_time"] = time.time()
|
||||
|
||||
logger.debug(f"存储数据保存完成,耗时: {save_time:.3f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存存储数据失败: {e}")
|
||||
|
||||
async def save_storage(self):
|
||||
"""手动保存存储数据"""
|
||||
await self._save_storage()
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats.update({
|
||||
"cache_size": len(self.memory_cache),
|
||||
"vector_count": len(self.vector_cache),
|
||||
"keyword_index_size": len(self.keyword_index),
|
||||
"type_index_size": len(self.type_index),
|
||||
"user_index_size": len(self.user_index),
|
||||
"config": {
|
||||
"dimension": self.config.dimension,
|
||||
"similarity_threshold": self.config.similarity_threshold,
|
||||
"enable_forgetting": self.config.enable_forgetting
|
||||
}
|
||||
})
|
||||
return stats
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理存储系统"""
|
||||
try:
|
||||
logger.info("开始清理统一记忆存储...")
|
||||
|
||||
# 保存数据
|
||||
await self._save_storage()
|
||||
|
||||
# 清空缓存
|
||||
self.memory_cache.clear()
|
||||
self.vector_cache.clear()
|
||||
self.keyword_index.clear()
|
||||
self.type_index.clear()
|
||||
self.user_index.clear()
|
||||
|
||||
# 重置索引
|
||||
if FAISS_AVAILABLE:
|
||||
self.vector_index.reset()
|
||||
|
||||
self.memory_id_to_index.clear()
|
||||
self.index_to_memory_id.clear()
|
||||
|
||||
logger.info("统一记忆存储清理完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理存储系统失败: {e}")
|
||||
|
||||
|
||||
# 创建全局存储实例
|
||||
unified_memory_storage: Optional[UnifiedMemoryStorage] = None
|
||||
|
||||
|
||||
def get_unified_memory_storage() -> Optional[UnifiedMemoryStorage]:
|
||||
"""获取统一存储实例"""
|
||||
return unified_memory_storage
|
||||
|
||||
|
||||
async def initialize_unified_memory_storage(config: Optional[UnifiedStorageConfig] = None) -> UnifiedMemoryStorage:
|
||||
"""初始化统一记忆存储"""
|
||||
global unified_memory_storage
|
||||
|
||||
if unified_memory_storage is None:
|
||||
unified_memory_storage = UnifiedMemoryStorage(config)
|
||||
|
||||
# 设置嵌入模型
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task:
|
||||
unified_memory_storage.set_embedding_model(
|
||||
LLMRequest(model_set=embedding_task, request_type="memory.embedding")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"设置嵌入模型失败: {e}")
|
||||
|
||||
return unified_memory_storage
|
||||
src/chat/memory_system/vector_memory_storage_v2.py (new file, 700 lines)
@@ -0,0 +1,700 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于Vector DB的统一记忆存储系统 V2
|
||||
使用ChromaDB作为底层存储,替代JSON存储方式
|
||||
|
||||
主要特性:
|
||||
- 统一的向量存储接口
|
||||
- 高效的语义检索
|
||||
- 元数据过滤支持
|
||||
- 批量操作优化
|
||||
- 自动清理过期记忆
|
||||
"""
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any, Union
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStorageConfig:
|
||||
"""Vector存储配置"""
|
||||
# 集合配置
|
||||
memory_collection: str = "unified_memory_v2"
|
||||
metadata_collection: str = "memory_metadata_v2"
|
||||
|
||||
# 检索配置
|
||||
similarity_threshold: float = 0.8
|
||||
search_limit: int = 20
|
||||
batch_size: int = 100
|
||||
|
||||
# 性能配置
|
||||
enable_caching: bool = True
|
||||
cache_size_limit: int = 1000
|
||||
auto_cleanup_interval: int = 3600 # 1小时
|
||||
|
||||
# 遗忘配置
|
||||
enable_forgetting: bool = True
|
||||
retention_hours: int = 24 * 30 # 30天
|
||||
|
||||
|
||||
class VectorMemoryStorage:
"""基于Vector DB的记忆存储系统"""

@property
def keyword_index(self) -> dict:
"""
动态构建关键词倒排索引(仅兼容旧接口,基于当前缓存)
返回: {keyword: [memory_id, ...]}
"""
index = {}
for memory in self.memory_cache.values():
for kw in getattr(memory, 'keywords', []):
if not kw:
continue
kw_norm = kw.strip().lower()
if kw_norm:
index.setdefault(kw_norm, []).append(getattr(memory.metadata, 'memory_id', None))
return index
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
self.config = config or VectorStorageConfig()
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.cache_timestamps: Dict[str, float] = {}
|
||||
|
||||
# 遗忘引擎
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
if self.config.enable_forgetting:
|
||||
self.forgetting_engine = MemoryForgettingEngine()
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_memories": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"total_searches": 0,
|
||||
"total_stores": 0,
|
||||
"last_cleanup_time": 0.0,
|
||||
"forgetting_stats": {}
|
||||
}
|
||||
|
||||
# 线程锁
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# 定时清理任务
|
||||
self._cleanup_task = None
|
||||
self._stop_cleanup = False
|
||||
|
||||
# 初始化系统
|
||||
self._initialize_storage()
|
||||
self._start_cleanup_task()
|
||||
|
||||
def _initialize_storage(self):
|
||||
"""初始化Vector DB存储"""
|
||||
try:
|
||||
# 创建记忆集合
|
||||
vector_db_service.get_or_create_collection(
|
||||
name=self.config.memory_collection,
|
||||
metadata={
|
||||
"description": "统一记忆存储V2",
|
||||
"hnsw:space": "cosine",
|
||||
"version": "2.0"
|
||||
}
|
||||
)
|
||||
|
||||
# 创建元数据集合(用于复杂查询)
|
||||
vector_db_service.get_or_create_collection(
|
||||
name=self.config.metadata_collection,
|
||||
metadata={
|
||||
"description": "记忆元数据索引",
|
||||
"hnsw:space": "cosine",
|
||||
"version": "2.0"
|
||||
}
|
||||
)
|
||||
|
||||
# 获取当前记忆总数
|
||||
self.stats["total_memories"] = vector_db_service.count(self.config.memory_collection)
|
||||
|
||||
logger.info(f"Vector记忆存储初始化完成,当前记忆数: {self.stats['total_memories']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Vector存储系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动定时清理任务"""
|
||||
if self.config.auto_cleanup_interval > 0:
|
||||
def cleanup_worker():
|
||||
while not self._stop_cleanup:
|
||||
try:
|
||||
time.sleep(self.config.auto_cleanup_interval)
|
||||
if not self._stop_cleanup:
|
||||
asyncio.create_task(self._perform_auto_cleanup())
|
||||
except Exception as e:
|
||||
logger.error(f"定时清理任务出错: {e}")
|
||||
|
||||
self._cleanup_task = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
self._cleanup_task.start()
|
||||
logger.info(f"定时清理任务已启动,间隔: {self.config.auto_cleanup_interval}秒")
|
||||
|
||||
async def _perform_auto_cleanup(self):
|
||||
"""执行自动清理"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# 清理过期缓存
|
||||
if self.config.enable_caching:
|
||||
expired_keys = [
|
||||
memory_id for memory_id, timestamp in self.cache_timestamps.items()
|
||||
if current_time - timestamp > 3600 # 1小时过期
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
self.memory_cache.pop(key, None)
|
||||
self.cache_timestamps.pop(key, None)
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存项")
|
||||
|
||||
# 执行遗忘检查
|
||||
if self.forgetting_engine:
|
||||
await self.perform_forgetting_check()
|
||||
|
||||
self.stats["last_cleanup_time"] = current_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"自动清理失败: {e}")
|
||||
|
||||
def _memory_to_vector_format(self, memory: MemoryChunk) -> Tuple[Dict[str, Any], str]:
|
||||
"""将MemoryChunk转换为Vector DB格式"""
|
||||
# 选择用于向量化的文本
|
||||
content = getattr(memory, 'display', None) or getattr(memory, 'text_content', None) or ""
|
||||
|
||||
# 构建元数据(全部从memory.metadata获取)
|
||||
meta = getattr(memory, 'metadata', None)
|
||||
metadata = {
|
||||
"user_id": getattr(meta, 'user_id', None),
|
||||
"chat_id": getattr(meta, 'chat_id', 'unknown'),
|
||||
"memory_type": memory.memory_type.value,
|
||||
"keywords": orjson.dumps(getattr(memory, 'keywords', [])).decode("utf-8"),
|
||||
"importance": getattr(meta, 'importance', None),
|
||||
"timestamp": getattr(meta, 'created_at', None),
|
||||
"access_count": getattr(meta, 'access_count', 0),
|
||||
"last_access_time": getattr(meta, 'last_accessed', 0),
|
||||
"confidence": getattr(meta, 'confidence', None),
|
||||
"source": "vector_storage_v2",
|
||||
# 存储完整的记忆数据
|
||||
"memory_data": orjson.dumps(memory.to_dict()).decode("utf-8")
|
||||
}
|
||||
|
||||
return metadata, content
|
||||
|
||||
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
|
||||
"""将Vector DB结果转换为MemoryChunk"""
|
||||
try:
|
||||
# 从元数据中恢复完整记忆
|
||||
if "memory_data" in metadata:
|
||||
memory_dict = orjson.loads(metadata["memory_data"])
|
||||
return MemoryChunk.from_dict(memory_dict)
|
||||
|
||||
# 兜底:从基础字段重建
|
||||
memory_dict = {
|
||||
"memory_id": metadata.get("memory_id", f"recovered_{int(time.time())}"),
|
||||
"user_id": metadata.get("user_id", "unknown"),
|
||||
"text_content": document,
|
||||
"display": document,
|
||||
"memory_type": metadata.get("memory_type", "general"),
|
||||
"keywords": orjson.loads(metadata.get("keywords", "[]")),
|
||||
"importance": metadata.get("importance", 0.5),
|
||||
"timestamp": metadata.get("timestamp", time.time()),
|
||||
"access_count": metadata.get("access_count", 0),
|
||||
"last_access_time": metadata.get("last_access_time", 0),
|
||||
"confidence": metadata.get("confidence", 0.8),
|
||||
"metadata": {}
|
||||
}
|
||||
|
||||
return MemoryChunk.from_dict(memory_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换Vector结果到MemoryChunk失败: {e}")
|
||||
return None
|
||||
|
||||
def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""从缓存获取记忆"""
|
||||
if not self.config.enable_caching:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if memory_id in self.memory_cache:
|
||||
self.cache_timestamps[memory_id] = time.time()
|
||||
self.stats["cache_hits"] += 1
|
||||
return self.memory_cache[memory_id]
|
||||
|
||||
self.stats["cache_misses"] += 1
|
||||
return None
|
||||
|
||||
def _add_to_cache(self, memory: MemoryChunk):
|
||||
"""添加记忆到缓存"""
|
||||
if not self.config.enable_caching:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
# 检查缓存大小限制
|
||||
if len(self.memory_cache) >= self.config.cache_size_limit:
|
||||
# 移除最老的缓存项
|
||||
oldest_id = min(self.cache_timestamps.keys(),
|
||||
key=lambda k: self.cache_timestamps[k])
|
||||
self.memory_cache.pop(oldest_id, None)
|
||||
self.cache_timestamps.pop(oldest_id, None)
|
||||
|
||||
self.memory_cache[memory.memory_id] = memory
|
||||
self.cache_timestamps[memory.memory_id] = time.time()
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
"""批量存储记忆"""
|
||||
if not memories:
|
||||
return 0
|
||||
|
||||
try:
|
||||
# 准备批量数据
|
||||
embeddings = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for memory in memories:
|
||||
try:
|
||||
# 转换格式
|
||||
metadata, content = self._memory_to_vector_format(memory)
|
||||
|
||||
if not content.strip():
|
||||
logger.warning(f"记忆 {memory.memory_id} 内容为空,跳过")
|
||||
continue
|
||||
|
||||
# 生成向量
|
||||
embedding = await get_embedding(content)
|
||||
if not embedding:
|
||||
logger.warning(f"生成向量失败,跳过记忆: {memory.memory_id}")
|
||||
continue
|
||||
|
||||
embeddings.append(embedding)
|
||||
documents.append(content)
|
||||
metadatas.append(metadata)
|
||||
ids.append(memory.memory_id)
|
||||
|
||||
# 添加到缓存
|
||||
self._add_to_cache(memory)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理记忆 {memory.memory_id} 失败: {e}")
|
||||
continue
|
||||
|
||||
# 批量插入Vector DB
|
||||
if embeddings:
|
||||
vector_db_service.add(
|
||||
collection_name=self.config.memory_collection,
|
||||
embeddings=embeddings,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids
|
||||
)
|
||||
|
||||
stored_count = len(embeddings)
|
||||
self.stats["total_stores"] += stored_count
|
||||
self.stats["total_memories"] += stored_count
|
||||
|
||||
logger.info(f"成功存储 {stored_count}/{len(memories)} 条记忆")
|
||||
return stored_count
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量存储记忆失败: {e}")
|
||||
return 0
|
||||
|
||||
async def store_memory(self, memory: MemoryChunk) -> bool:
|
||||
"""存储单条记忆"""
|
||||
result = await self.store_memories([memory])
|
||||
return result > 0
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[MemoryChunk, float]]:
|
||||
"""搜索相似记忆"""
|
||||
if not query_text.strip():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 生成查询向量
|
||||
query_embedding = await get_embedding(query_text)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
threshold = similarity_threshold or self.config.similarity_threshold
|
||||
|
||||
# 构建where条件
|
||||
where_conditions = filters or {}
|
||||
|
||||
# 查询Vector DB
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.config.memory_collection,
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=min(limit, self.config.search_limit),
|
||||
where=where_conditions if where_conditions else None
|
||||
)
|
||||
|
||||
# 处理结果
|
||||
similar_memories = []
|
||||
|
||||
if results.get("documents") and results["documents"][0]:
|
||||
documents = results["documents"][0]
|
||||
distances = results.get("distances", [[]])[0]
|
||||
metadatas = results.get("metadatas", [[]])[0]
|
||||
ids = results.get("ids", [[]])[0]
|
||||
|
||||
for i, (doc, metadata, memory_id) in enumerate(zip(documents, metadatas, ids)):
|
||||
# 计算相似度
|
||||
distance = distances[i] if i < len(distances) else 1.0
|
||||
similarity = 1 - distance # ChromaDB返回距离,转换为相似度
|
||||
|
||||
if similarity < threshold:
|
||||
continue
|
||||
|
||||
# 首先尝试从缓存获取
|
||||
memory = self._get_from_cache(memory_id)
|
||||
|
||||
if not memory:
|
||||
# 从Vector结果重建
|
||||
memory = self._vector_result_to_memory(doc, metadata)
|
||||
if memory:
|
||||
self._add_to_cache(memory)
|
||||
|
||||
if memory:
|
||||
similar_memories.append((memory, similarity))
|
||||
|
||||
# 按相似度排序
|
||||
similar_memories.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
self.stats["total_searches"] += 1
|
||||
logger.debug(f"搜索相似记忆: 查询='{query_text[:30]}...', 结果数={len(similar_memories)}")
|
||||
|
||||
return similar_memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索相似记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
"""根据ID获取记忆"""
|
||||
# 首先尝试从缓存获取
|
||||
memory = self._get_from_cache(memory_id)
|
||||
if memory:
|
||||
return memory
|
||||
|
||||
try:
|
||||
# 从Vector DB获取
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.config.memory_collection,
|
||||
ids=[memory_id]
|
||||
)
|
||||
|
||||
if results.get("documents") and results["documents"]:
|
||||
document = results["documents"][0]
|
||||
metadata = results["metadatas"][0] if results.get("metadatas") else {}
|
||||
|
||||
memory = self._vector_result_to_memory(document, metadata)
|
||||
if memory:
|
||||
self._add_to_cache(memory)
|
||||
|
||||
return memory
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取记忆 {memory_id} 失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def get_memories_by_filters(
|
||||
self,
|
||||
filters: Dict[str, Any],
|
||||
limit: int = 100
|
||||
) -> List[MemoryChunk]:
|
||||
"""根据过滤条件获取记忆"""
|
||||
try:
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.config.memory_collection,
|
||||
where=filters,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
memories = []
|
||||
if results.get("documents"):
|
||||
documents = results["documents"]
|
||||
metadatas = results.get("metadatas", [{}] * len(documents))
|
||||
ids = results.get("ids", [])
|
||||
|
||||
for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
|
||||
memory_id = ids[i] if i < len(ids) else None
|
||||
|
||||
# 首先尝试从缓存获取
|
||||
if memory_id:
|
||||
memory = self._get_from_cache(memory_id)
|
||||
if memory:
|
||||
memories.append(memory)
|
||||
continue
|
||||
|
||||
# 从Vector结果重建
|
||||
memory = self._vector_result_to_memory(doc, metadata)
|
||||
if memory:
|
||||
memories.append(memory)
|
||||
if memory_id:
|
||||
self._add_to_cache(memory)
|
||||
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"根据过滤条件获取记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def update_memory(self, memory: MemoryChunk) -> bool:
|
||||
"""更新记忆"""
|
||||
try:
|
||||
# 先删除旧记忆
|
||||
await self.delete_memory(memory.memory_id)
|
||||
|
||||
# 重新存储更新后的记忆
|
||||
return await self.store_memory(memory)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新记忆 {memory.memory_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def delete_memory(self, memory_id: str) -> bool:
|
||||
"""删除记忆"""
|
||||
try:
|
||||
# 从Vector DB删除
|
||||
vector_db_service.delete(
|
||||
collection_name=self.config.memory_collection,
|
||||
ids=[memory_id]
|
||||
)
|
||||
|
||||
# 从缓存删除
|
||||
with self._lock:
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.cache_timestamps.pop(memory_id, None)
|
||||
|
||||
self.stats["total_memories"] = max(0, self.stats["total_memories"] - 1)
|
||||
logger.debug(f"删除记忆: {memory_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除记忆 {memory_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
|
||||
"""根据过滤条件批量删除记忆"""
|
||||
try:
|
||||
# 先获取要删除的记忆ID
|
||||
results = vector_db_service.get(
|
||||
collection_name=self.config.memory_collection,
|
||||
where=filters,
|
||||
include=["metadatas"]
|
||||
)
|
||||
|
||||
if not results.get("ids"):
|
||||
return 0
|
||||
|
||||
memory_ids = results["ids"]
|
||||
|
||||
# 批量删除
|
||||
vector_db_service.delete(
|
||||
collection_name=self.config.memory_collection,
|
||||
where=filters
|
||||
)
|
||||
|
||||
# 从缓存删除
|
||||
with self._lock:
|
||||
for memory_id in memory_ids:
|
||||
self.memory_cache.pop(memory_id, None)
|
||||
self.cache_timestamps.pop(memory_id, None)
|
||||
|
||||
deleted_count = len(memory_ids)
|
||||
self.stats["total_memories"] = max(0, self.stats["total_memories"] - deleted_count)
|
||||
logger.info(f"批量删除记忆: {deleted_count} 条")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除记忆失败: {e}")
|
||||
return 0
|
||||
|
||||
async def perform_forgetting_check(self) -> Dict[str, Any]:
|
||||
"""执行遗忘检查"""
|
||||
if not self.forgetting_engine:
|
||||
return {"error": "遗忘引擎未启用"}
|
||||
|
||||
try:
|
||||
# 获取所有记忆进行遗忘检查
|
||||
# 注意:对于大型数据集,这里应该分批处理
|
||||
current_time = time.time()
|
||||
cutoff_time = current_time - (self.config.retention_hours * 3600)
|
||||
|
||||
# 先删除明显过期的记忆
|
||||
expired_filters = {"timestamp": {"$lt": cutoff_time}}
|
||||
expired_count = await self.delete_memories_by_filters(expired_filters)
|
||||
|
||||
# 对剩余记忆执行智能遗忘检查
|
||||
# 这里为了性能考虑,只检查一部分记忆
|
||||
sample_memories = await self.get_memories_by_filters({}, limit=500)
|
||||
|
||||
if sample_memories:
|
||||
result = await self.forgetting_engine.perform_forgetting_check(sample_memories)
|
||||
|
||||
# 遗忘标记的记忆
|
||||
forgetting_ids = result.get("normal_forgetting", []) + result.get("force_forgetting", [])
|
||||
forgotten_count = 0
|
||||
|
||||
for memory_id in forgetting_ids:
|
||||
if await self.delete_memory(memory_id):
|
||||
forgotten_count += 1
|
||||
|
||||
result["forgotten_count"] = forgotten_count
|
||||
result["expired_count"] = expired_count
|
||||
|
||||
# 更新统计
|
||||
self.stats["forgetting_stats"] = self.forgetting_engine.get_forgetting_stats()
|
||||
|
||||
logger.info(f"遗忘检查完成: 过期删除 {expired_count}, 智能遗忘 {forgotten_count}")
|
||||
return result
|
||||
|
||||
return {"expired_count": expired_count, "forgotten_count": 0}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行遗忘检查失败: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
try:
|
||||
current_total = vector_db_service.count(self.config.memory_collection)
|
||||
self.stats["total_memories"] = current_total
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
"cache_size": len(self.memory_cache),
|
||||
"collection_name": self.config.memory_collection,
|
||||
"storage_type": "vector_db_v2",
|
||||
"uptime": time.time() - self.stats.get("start_time", time.time())
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
"""停止存储系统"""
|
||||
self._stop_cleanup = True
|
||||
|
||||
if self._cleanup_task and self._cleanup_task.is_alive():
|
||||
logger.info("正在停止定时清理任务...")
|
||||
|
||||
# 清空缓存
|
||||
with self._lock:
|
||||
self.memory_cache.clear()
|
||||
self.cache_timestamps.clear()
|
||||
|
||||
logger.info("Vector记忆存储系统已停止")
|
||||
|
||||
|
||||
# 全局实例(可选)
|
||||
_global_vector_storage = None
|
||||
|
||||
|
||||
def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
|
||||
"""获取全局Vector记忆存储实例"""
|
||||
global _global_vector_storage
|
||||
|
||||
if _global_vector_storage is None:
|
||||
_global_vector_storage = VectorMemoryStorage(config)
|
||||
|
||||
return _global_vector_storage
|
||||
|
||||
|
||||
# 兼容性接口
|
||||
class VectorMemoryStorageAdapter:
|
||||
"""适配器类,提供与原UnifiedMemoryStorage兼容的接口"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
self.storage = VectorMemoryStorage(config)
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
return await self.storage.store_memories(memories)
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
scope_id: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
results = await self.storage.search_similar_memories(
|
||||
query_text, limit, filters=filters
|
||||
)
|
||||
# 转换为原格式:(memory_id, similarity)
|
||||
return [(memory.memory_id, similarity) for memory, similarity in results]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.storage.get_storage_stats()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 简单测试
|
||||
async def test_vector_storage():
|
||||
storage = VectorMemoryStorage()
|
||||
|
||||
# 创建测试记忆
|
||||
from src.chat.memory_system.memory_chunk import MemoryType
|
||||
test_memory = MemoryChunk(
|
||||
memory_id="test_001",
|
||||
user_id="test_user",
|
||||
text_content="今天天气很好,适合出门散步",
|
||||
memory_type=MemoryType.FACT,
|
||||
keywords=["天气", "散步"],
|
||||
importance=0.7
|
||||
)
|
||||
|
||||
# 存储记忆
|
||||
success = await storage.store_memory(test_memory)
|
||||
print(f"存储结果: {success}")
|
||||
|
||||
# 搜索记忆
|
||||
results = await storage.search_similar_memories("天气怎么样", limit=5)
|
||||
print(f"搜索结果: {len(results)} 条")
|
||||
|
||||
for memory, similarity in results:
|
||||
print(f" - {memory.text_content[:50]}... (相似度: {similarity:.3f})")
|
||||
|
||||
# 获取统计信息
|
||||
stats = storage.get_storage_stats()
|
||||
print(f"存储统计: {stats}")
|
||||
|
||||
storage.stop()
|
||||
|
||||
asyncio.run(test_vector_storage())
|
||||
@@ -98,13 +98,79 @@ class ChromaDBImpl(VectorDBBase):
"n_results": n_results,
**kwargs,
}

# 修复ChromaDB的where条件格式
if where:
query_params["where"] = where
processed_where = self._process_where_condition(where)
if processed_where:
query_params["where"] = processed_where

return collection.query(**query_params)
except Exception as e:
logger.error(f"查询集合 '{collection_name}' 失败: {e}")
# 如果查询失败,尝试不使用where条件重新查询
try:
fallback_params = {
"query_embeddings": query_embeddings,
"n_results": n_results,
}
logger.warning(f"使用回退查询模式(无where条件)")
return collection.query(**fallback_params)
except Exception as fallback_e:
logger.error(f"回退查询也失败: {fallback_e}")
return {}

def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
处理where条件,转换为ChromaDB支持的格式
ChromaDB支持的格式:
- 简单条件: {"field": "value"}
- 操作符条件: {"field": {"$op": "value"}}
- AND条件: {"$and": [condition1, condition2]}
- OR条件: {"$or": [condition1, condition2]}
"""
if not where:
return None

try:
# 如果只有一个字段,直接返回
if len(where) == 1:
key, value = next(iter(where.items()))

# 处理列表值(如memory_types)
if isinstance(value, list):
if len(value) == 1:
return {key: value[0]}
else:
# 多个值使用 $in 操作符
return {key: {"$in": value}}
else:
return {key: value}

# 多个字段使用 $and 操作符
conditions = []
for key, value in where.items():
if isinstance(value, list):
if len(value) == 1:
conditions.append({key: value[0]})
else:
conditions.append({key: {"$in": value}})
else:
conditions.append({key: value})

return {"$and": conditions}

except Exception as e:
logger.warning(f"处理where条件失败: {e}, 使用简化条件")
# 回退到只使用第一个条件
if where:
key, value = next(iter(where.items()))
if isinstance(value, list) and value:
return {key: value[0]}
elif not isinstance(value, list):
return {key: value}
return None

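For reference, two illustrative input/output pairs for _process_where_condition above, derived from the rules stated in its docstring (assumed field names, not output captured from a live run):

# {"memory_type": ["fact", "preference"]}
#   -> {"memory_type": {"$in": ["fact", "preference"]}}   (single field, multiple values use $in)
# {"user_id": "u1", "memory_type": ["fact"]}
#   -> {"$and": [{"user_id": "u1"}, {"memory_type": "fact"}]}   (multiple fields wrapped in $and, single-element lists unwrapped)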
def get(
self,
collection_name: str,
@@ -119,16 +185,33 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name)
if collection:
try:
# 处理where条件
processed_where = None
if where:
processed_where = self._process_where_condition(where)

return collection.get(
ids=ids,
where=where,
where=processed_where,
limit=limit,
offset=offset,
where_document=where_document,
include=include,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as e:
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
# 如果获取失败,尝试不使用where条件重新获取
try:
logger.warning(f"使用回退获取模式(无where条件)")
return collection.get(
ids=ids,
limit=limit,
offset=offset,
where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as fallback_e:
logger.error(f"回退获取也失败: {fallback_e}")
return {}

def delete(

@@ -474,6 +474,47 @@ class MemoryConfig(ValidatedConfigBase):
cache_ttl_seconds: int = Field(default=300, description="缓存生存时间(秒)")
max_cache_size: int = Field(default=1000, description="最大缓存大小")

# Vector DB记忆存储配置 (替代JSON存储)
enable_vector_memory_storage: bool = Field(default=True, description="启用Vector DB记忆存储")
enable_llm_instant_memory: bool = Field(default=True, description="启用基于LLM的瞬时记忆")
enable_vector_instant_memory: bool = Field(default=True, description="启用基于向量的瞬时记忆")

# Vector DB配置
vector_db_memory_collection: str = Field(default="unified_memory_v2", description="Vector DB集合名称")
vector_db_similarity_threshold: float = Field(default=0.8, description="Vector DB相似度阈值")
vector_db_search_limit: int = Field(default=20, description="Vector DB搜索限制")
vector_db_batch_size: int = Field(default=100, description="批处理大小")
vector_db_enable_caching: bool = Field(default=True, description="启用缓存")
vector_db_cache_size_limit: int = Field(default=1000, description="缓存大小限制")
vector_db_auto_cleanup_interval: int = Field(default=3600, description="自动清理间隔(秒)")
vector_db_retention_hours: int = Field(default=720, description="记忆保留时间(小时,默认30天)")

# 遗忘引擎配置
enable_memory_forgetting: bool = Field(default=True, description="启用智能遗忘机制")
forgetting_check_interval_hours: int = Field(default=24, description="遗忘检查间隔(小时)")
base_forgetting_days: float = Field(default=30.0, description="基础遗忘天数")
min_forgetting_days: float = Field(default=7.0, description="最小遗忘天数")
max_forgetting_days: float = Field(default=365.0, description="最大遗忘天数")

# 重要程度权重
critical_importance_bonus: float = Field(default=45.0, description="关键重要性额外天数")
high_importance_bonus: float = Field(default=30.0, description="高重要性额外天数")
normal_importance_bonus: float = Field(default=15.0, description="一般重要性额外天数")
low_importance_bonus: float = Field(default=0.0, description="低重要性额外天数")

# 置信度权重
verified_confidence_bonus: float = Field(default=30.0, description="已验证置信度额外天数")
high_confidence_bonus: float = Field(default=20.0, description="高置信度额外天数")
medium_confidence_bonus: float = Field(default=10.0, description="中等置信度额外天数")
low_confidence_bonus: float = Field(default=0.0, description="低置信度额外天数")

# 激活频率权重
activation_frequency_weight: float = Field(default=0.5, description="每次激活增加的天数权重")
max_frequency_bonus: float = Field(default=10.0, description="最大激活频率奖励天数")

# 休眠机制
dormant_threshold_days: int = Field(default=90, description="休眠状态判定天数")


class MoodConfig(ValidatedConfigBase):
"""情绪配置类"""

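The forgetting-related fields above read as an additive retention budget expressed in days. A rough sketch of how they presumably combine; the MemoryForgettingEngine itself is not part of this diff, so treat this as an assumed illustration of the knobs rather than the engine's actual formula:

def estimated_retention_days(cfg, importance_bonus: float, confidence_bonus: float, activations: int) -> float:
    # assumed combination: base days + importance/confidence bonuses + capped activation reward,
    # clamped to the configured minimum and maximum retention window
    frequency_bonus = min(activations * cfg.activation_frequency_weight, cfg.max_frequency_bonus)
    days = cfg.base_forgetting_days + importance_bonus + confidence_bonus + frequency_bonus
    return max(cfg.min_forgetting_days, min(cfg.max_forgetting_days, days))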
@@ -795,7 +795,7 @@ class LLMRequest:
max_tokens=max_tokens,
)

self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
content, reasoning, _ = self._prompt_processor.process_response(response.content or "", False)
reasoning = response.reasoning_content or reasoning

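Context for the one-line fix above: _record_usage is evidently a coroutine, so the old call only created a coroutine object that was never executed and the usage record was silently dropped. A minimal illustration with made-up names (not the project's real signatures):

import asyncio

class Recorder:
    async def _record_usage(self, tokens: int) -> None:
        print(f"recorded {tokens} tokens")  # stand-in for the real persistence call

async def main() -> None:
    r = Recorder()
    r._record_usage(42)        # bug: returns a coroutine that is never awaited, nothing is recorded
    await r._record_usage(42)  # fix: the coroutine actually runs

asyncio.run(main())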