diff --git a/src/chat/memory_system/vector_instant_memory.py b/src/chat/memory_system/vector_instant_memory.py index 360d52bec..9c7824d9a 100644 --- a/src/chat/memory_system/vector_instant_memory.py +++ b/src/chat/memory_system/vector_instant_memory.py @@ -87,22 +87,38 @@ class VectorInstantMemoryV2: """清理过期的聊天记录""" try: expire_time = time.time() - (self.retention_hours * 3600) - - # 使用 where 条件来删除过期记录 - # 注意: ChromaDB 的 where 过滤器目前对 timestamp 的 $lt 操作支持可能有限 - # 一个更可靠的方法是 get() -> filter -> delete() - # 但为了简化,我们先尝试直接 delete - - # TODO: 确认 ChromaDB 对 $lt 在 metadata 上的支持。如果不支持,需要实现 get-filter-delete 模式。 - vector_db_service.delete( + + # 采用 get -> filter -> delete 模式,避免复杂的 where 查询 + # 1. 获取当前 chat_id 的所有文档 + results = vector_db_service.get( collection_name=self.collection_name, - where={ - "chat_id": self.chat_id, - "timestamp": {"$lt": expire_time} - } + where={"chat_id": self.chat_id}, + include=["metadatas"] ) - logger.info(f"已为 chat_id '{self.chat_id}' 触发过期记录清理") - + + if not results or not results.get('ids'): + logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理") + return + + # 2. 在内存中过滤出过期的文档 + expired_ids = [] + metadatas = results.get('metadatas', []) + ids = results.get('ids', []) + + for i, metadata in enumerate(metadatas): + if metadata and metadata.get('timestamp', float('inf')) < expire_time: + expired_ids.append(ids[i]) + + # 3. 如果有过期文档,根据 ID 进行删除 + if expired_ids: + vector_db_service.delete( + collection_name=self.collection_name, + ids=expired_ids + ) + logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录") + else: + logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录") + except Exception as e: logger.error(f"清理过期记录失败: {e}") diff --git a/src/common/vector_db/base.py b/src/common/vector_db/base.py index 56494ab73..e94b74cba 100644 --- a/src/common/vector_db/base.py +++ b/src/common/vector_db/base.py @@ -93,6 +93,34 @@ class VectorDBBase(ABC): """ pass + @abstractmethod + def get( + self, + collection_name: str, + ids: Optional[List[str]] = None, + where: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + where_document: Optional[Dict[str, Any]] = None, + include: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + 根据条件从集合中获取数据。 + + Args: + collection_name (str): 目标集合的名称。 + ids (Optional[List[str]], optional): 要获取的条目的 ID 列表。Defaults to None. + where (Optional[Dict[str, Any]], optional): 基于元数据的过滤条件。Defaults to None. + limit (Optional[int], optional): 返回结果的数量限制。Defaults to None. + offset (Optional[int], optional): 返回结果的偏移量。Defaults to None. + where_document (Optional[Dict[str, Any]], optional): 基于文档内容的过滤条件。Defaults to None. + include (Optional[List[str]], optional): 指定返回的数据字段 (e.g., ["metadatas", "documents"])。Defaults to None. + + Returns: + Dict[str, Any]: 获取到的数据。 + """ + pass + @abstractmethod def count(self, collection_name: str) -> int: """ diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index a43c2a6d8..8e9313b3b 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -102,6 +102,32 @@ class ChromaDBImpl(VectorDBBase): logger.error(f"查询集合 '{collection_name}' 失败: {e}") return {} + def get( + self, + collection_name: str, + ids: Optional[List[str]] = None, + where: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + where_document: Optional[Dict[str, Any]] = None, + include: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """根据条件从集合中获取数据""" + collection = self.get_or_create_collection(collection_name) + if collection: + try: + return collection.get( + ids=ids, + where=where, + limit=limit, + offset=offset, + where_document=where_document, + include=include, + ) + except Exception as e: + logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}") + return {} + def delete( self, collection_name: str,