From 8a8d2ed57417e82ae319c907c9b136f6d85daa0d Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 27 Aug 2025 21:40:03 +0800 Subject: [PATCH] =?UTF-8?q?refactor(memory):=20=E9=87=8D=E6=9E=84=E5=90=91?= =?UTF-8?q?=E9=87=8F=E8=AE=B0=E5=BF=86=E6=B8=85=E7=90=86=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E4=BB=A5=E6=8F=90=E9=AB=98=E7=A8=B3=E5=AE=9A=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 原有的清理逻辑直接使用 delete 和 where 条件(timestamp: {"$lt": ...})来删除过期记录。然而,ChromaDB 对元数据中复杂的查询操作符(如 $lt)的支持并不可靠。 为确保过期记录能被稳定地清除,本次提交将清理策略修改为更稳健的“获取-过滤-删除”模式: 1. 为向量数据库抽象层新增 `get` 方法,并为 ChromaDB 提供具体实现。 2. 在 `VectorInstantMemoryV2` 中,先获取指定聊天的所有记录。 3. 在应用代码中根据时间戳筛选出过期的记录ID。 4. 最后根据ID列表精确删除过期记录,确保了清理操作的准确性。 --- .../memory_system/vector_instant_memory.py | 44 +++++++++++++------ src/common/vector_db/base.py | 28 ++++++++++++ src/common/vector_db/chromadb_impl.py | 26 +++++++++++ 3 files changed, 84 insertions(+), 14 deletions(-) diff --git a/src/chat/memory_system/vector_instant_memory.py b/src/chat/memory_system/vector_instant_memory.py index 360d52bec..9c7824d9a 100644 --- a/src/chat/memory_system/vector_instant_memory.py +++ b/src/chat/memory_system/vector_instant_memory.py @@ -87,22 +87,38 @@ class VectorInstantMemoryV2: """清理过期的聊天记录""" try: expire_time = time.time() - (self.retention_hours * 3600) - - # 使用 where 条件来删除过期记录 - # 注意: ChromaDB 的 where 过滤器目前对 timestamp 的 $lt 操作支持可能有限 - # 一个更可靠的方法是 get() -> filter -> delete() - # 但为了简化,我们先尝试直接 delete - - # TODO: 确认 ChromaDB 对 $lt 在 metadata 上的支持。如果不支持,需要实现 get-filter-delete 模式。 - vector_db_service.delete( + + # 采用 get -> filter -> delete 模式,避免复杂的 where 查询 + # 1. 获取当前 chat_id 的所有文档 + results = vector_db_service.get( collection_name=self.collection_name, - where={ - "chat_id": self.chat_id, - "timestamp": {"$lt": expire_time} - } + where={"chat_id": self.chat_id}, + include=["metadatas"] ) - logger.info(f"已为 chat_id '{self.chat_id}' 触发过期记录清理") - + + if not results or not results.get('ids'): + logger.info(f"chat_id '{self.chat_id}' 没有找到任何记录,无需清理") + return + + # 2. 在内存中过滤出过期的文档 + expired_ids = [] + metadatas = results.get('metadatas', []) + ids = results.get('ids', []) + + for i, metadata in enumerate(metadatas): + if metadata and metadata.get('timestamp', float('inf')) < expire_time: + expired_ids.append(ids[i]) + + # 3. 如果有过期文档,根据 ID 进行删除 + if expired_ids: + vector_db_service.delete( + collection_name=self.collection_name, + ids=expired_ids + ) + logger.info(f"为 chat_id '{self.chat_id}' 清理了 {len(expired_ids)} 条过期记录") + else: + logger.info(f"chat_id '{self.chat_id}' 没有需要清理的过期记录") + except Exception as e: logger.error(f"清理过期记录失败: {e}") diff --git a/src/common/vector_db/base.py b/src/common/vector_db/base.py index 56494ab73..e94b74cba 100644 --- a/src/common/vector_db/base.py +++ b/src/common/vector_db/base.py @@ -93,6 +93,34 @@ class VectorDBBase(ABC): """ pass + @abstractmethod + def get( + self, + collection_name: str, + ids: Optional[List[str]] = None, + where: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + where_document: Optional[Dict[str, Any]] = None, + include: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + 根据条件从集合中获取数据。 + + Args: + collection_name (str): 目标集合的名称。 + ids (Optional[List[str]], optional): 要获取的条目的 ID 列表。Defaults to None. + where (Optional[Dict[str, Any]], optional): 基于元数据的过滤条件。Defaults to None. + limit (Optional[int], optional): 返回结果的数量限制。Defaults to None. + offset (Optional[int], optional): 返回结果的偏移量。Defaults to None. + where_document (Optional[Dict[str, Any]], optional): 基于文档内容的过滤条件。Defaults to None. + include (Optional[List[str]], optional): 指定返回的数据字段 (e.g., ["metadatas", "documents"])。Defaults to None. + + Returns: + Dict[str, Any]: 获取到的数据。 + """ + pass + @abstractmethod def count(self, collection_name: str) -> int: """ diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index a43c2a6d8..8e9313b3b 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -102,6 +102,32 @@ class ChromaDBImpl(VectorDBBase): logger.error(f"查询集合 '{collection_name}' 失败: {e}") return {} + def get( + self, + collection_name: str, + ids: Optional[List[str]] = None, + where: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + where_document: Optional[Dict[str, Any]] = None, + include: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """根据条件从集合中获取数据""" + collection = self.get_or_create_collection(collection_name) + if collection: + try: + return collection.get( + ids=ids, + where=where, + limit=limit, + offset=offset, + where_document=where_document, + include=include, + ) + except Exception as e: + logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}") + return {} + def delete( self, collection_name: str,