From e7cb04bfdde573096d32060c72a4b643bf2e66da Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 18:58:07 +0800 Subject: [PATCH] =?UTF-8?q?feat(chromadb):=20=E6=B7=BB=E5=8A=A0=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E9=94=81=E4=BB=A5=E4=BF=9D=E6=8A=A4=20ChromaDB=20?= =?UTF-8?q?=E6=93=8D=E4=BD=9C=EF=BC=8C=E7=A1=AE=E4=BF=9D=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E5=AE=89=E5=85=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/vector_db/chromadb_impl.py | 72 +++++++------ src/memory_graph/storage/vector_store.py | 123 ++++++++++++++--------- 2 files changed, 118 insertions(+), 77 deletions(-) diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py index e4a2911ae..487bf6246 100644 --- a/src/common/vector_db/chromadb_impl.py +++ b/src/common/vector_db/chromadb_impl.py @@ -10,11 +10,17 @@ from .base import VectorDBBase logger = get_logger("chromadb_impl") +# 全局操作锁,用于保护 ChromaDB 的所有操作 +# ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation +_operation_lock = threading.Lock() + class ChromaDBImpl(VectorDBBase): """ ChromaDB 的具体实现,遵循 VectorDBBase 接口。 采用单例模式,确保全局只有一个 ChromaDB 客户端实例。 + + 注意:所有操作都使用 _operation_lock 保护,以避免 Windows 上的并发访问崩溃。 """ _instance = None @@ -36,9 +42,10 @@ class ChromaDBImpl(VectorDBBase): with self._lock: if not hasattr(self, "_initialized"): try: - self.client = chromadb.PersistentClient( - path=path, settings=Settings(anonymized_telemetry=False) - ) + with _operation_lock: + self.client = chromadb.PersistentClient( + path=path, settings=Settings(anonymized_telemetry=False) + ) self._collections: dict[str, Any] = {} self._initialized = True logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}") @@ -56,7 +63,8 @@ class ChromaDBImpl(VectorDBBase): return self._collections[name] try: - collection = self.client.get_or_create_collection(name=name, **kwargs) + with _operation_lock: + collection = self.client.get_or_create_collection(name=name, **kwargs) self._collections[name] = collection logger.info(f"成功获取或创建集合: '{name}'") return collection @@ -75,12 +83,13 @@ class ChromaDBImpl(VectorDBBase): collection = self.get_or_create_collection(collection_name) if collection: try: - collection.add( - embeddings=embeddings, - documents=documents, - metadatas=metadatas, - ids=ids, - ) + with _operation_lock: + collection.add( + embeddings=embeddings, + documents=documents, + metadatas=metadatas, + ids=ids, + ) except Exception as e: logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}") @@ -107,7 +116,8 @@ class ChromaDBImpl(VectorDBBase): if processed_where: query_params["where"] = processed_where - return collection.query(**query_params) + with _operation_lock: + return collection.query(**query_params) except Exception as e: logger.error(f"查询集合 '{collection_name}' 失败: {e}") # 如果查询失败,尝试不使用where条件重新查询 @@ -117,7 +127,8 @@ class ChromaDBImpl(VectorDBBase): "n_results": n_results, } logger.warning("使用回退查询模式(无where条件)") - return collection.query(**fallback_params) + with _operation_lock: + return collection.query(**fallback_params) except Exception as fallback_e: logger.error(f"回退查询也失败: {fallback_e}") return {} @@ -192,26 +203,28 @@ class ChromaDBImpl(VectorDBBase): if where: processed_where = self._process_where_condition(where) - return collection.get( - ids=ids, - where=processed_where, - limit=limit, - offset=offset, - where_document=where_document, - include=include or ["documents", "metadatas", "embeddings"], - ) - except Exception as e: - logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}") - # 如果获取失败,尝试不使用where条件重新获取 - try: - logger.warning("使用回退获取模式(无where条件)") + with _operation_lock: return collection.get( ids=ids, + where=processed_where, limit=limit, offset=offset, where_document=where_document, include=include or ["documents", "metadatas", "embeddings"], ) + except Exception as e: + logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}") + # 如果获取失败,尝试不使用where条件重新获取 + try: + logger.warning("使用回退获取模式(无where条件)") + with _operation_lock: + return collection.get( + ids=ids, + limit=limit, + offset=offset, + where_document=where_document, + include=include or ["documents", "metadatas", "embeddings"], + ) except Exception as fallback_e: logger.error(f"回退获取也失败: {fallback_e}") return {} @@ -225,7 +238,8 @@ class ChromaDBImpl(VectorDBBase): collection = self.get_or_create_collection(collection_name) if collection: try: - collection.delete(ids=ids, where=where) + with _operation_lock: + collection.delete(ids=ids, where=where) except Exception as e: logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}") @@ -233,7 +247,8 @@ class ChromaDBImpl(VectorDBBase): collection = self.get_or_create_collection(collection_name) if collection: try: - return collection.count() + with _operation_lock: + return collection.count() except Exception as e: logger.error(f"获取集合 '{collection_name}' 计数失败: {e}") return 0 @@ -243,7 +258,8 @@ class ChromaDBImpl(VectorDBBase): raise ConnectionError("ChromaDB 客户端未初始化") try: - self.client.delete_collection(name=name) + with _operation_lock: + self.client.delete_collection(name=name) if name in self._collections: del self._collections[name] logger.info(f"集合 '{name}' 已被删除") diff --git a/src/memory_graph/storage/vector_store.py b/src/memory_graph/storage/vector_store.py index b59ea1b83..6e602a732 100644 --- a/src/memory_graph/storage/vector_store.py +++ b/src/memory_graph/storage/vector_store.py @@ -3,11 +3,15 @@ 注意:ChromaDB 是同步库,所有操作都必须使用 asyncio.to_thread() 包装 以避免阻塞 asyncio 事件循环导致死锁。 + +重要:ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation, +因此所有操作都需要通过全局锁保护以确保串行执行。 """ from __future__ import annotations import asyncio +import threading from pathlib import Path from typing import Any @@ -18,6 +22,10 @@ from src.memory_graph.models import MemoryNode, NodeType logger = get_logger(__name__) +# 全局锁,用于保护 ChromaDB 的所有操作 +# ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation +_chromadb_lock = threading.Lock() + class VectorStore: """ @@ -57,29 +65,35 @@ class VectorStore: import chromadb from chromadb.config import Settings - # 创建持久化客户端 - 同步操作需要在线程中执行 + # 创建持久化客户端 - 同步操作需要在线程中执行,并使用锁保护 def _create_client(): - return chromadb.PersistentClient( - path=str(self.data_dir / "chroma"), - settings=Settings( - anonymized_telemetry=False, - allow_reset=True, - ), - ) + with _chromadb_lock: + return chromadb.PersistentClient( + path=str(self.data_dir / "chroma"), + settings=Settings( + anonymized_telemetry=False, + allow_reset=True, + ), + ) self.client = await asyncio.to_thread(_create_client) - # 获取或创建集合 - 同步操作需要在线程中执行 + # 获取或创建集合 - 同步操作需要在线程中执行,并使用锁保护 def _get_or_create_collection(): - return self.client.get_or_create_collection( - name=self.collection_name, - metadata={"description": "Memory graph node embeddings"}, - ) + with _chromadb_lock: + return self.client.get_or_create_collection( + name=self.collection_name, + metadata={"description": "Memory graph node embeddings"}, + ) self.collection = await asyncio.to_thread(_get_or_create_collection) - # count() 也是同步操作 - count = await asyncio.to_thread(self.collection.count) + # count() 也是同步操作,使用锁保护 + def _count(): + with _chromadb_lock: + return self.collection.count() + + count = await asyncio.to_thread(_count) logger.debug(f"ChromaDB 初始化完成,集合包含 {count} 个节点") except Exception as e: @@ -118,14 +132,15 @@ class VectorStore: else: metadata[key] = str(value) - # ChromaDB add() 是同步阻塞操作,必须在线程中执行 + # ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护 def _add_node(): - self.collection.add( - ids=[node.id], - embeddings=[node.embedding.tolist()], - metadatas=[metadata], - documents=[node.content], - ) + with _chromadb_lock: + self.collection.add( + ids=[node.id], + embeddings=[node.embedding.tolist()], + metadatas=[metadata], + documents=[node.content], + ) await asyncio.to_thread(_add_node) @@ -171,14 +186,15 @@ class VectorStore: metadata[key] = str(value) metadatas.append(metadata) - # ChromaDB add() 是同步阻塞操作,必须在线程中执行 + # ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护 def _add_batch(): - self.collection.add( - ids=[n.id for n in valid_nodes], - embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore - metadatas=metadatas, - documents=[n.content for n in valid_nodes], - ) + with _chromadb_lock: + self.collection.add( + ids=[n.id for n in valid_nodes], + embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore + metadatas=metadatas, + documents=[n.content for n in valid_nodes], + ) await asyncio.to_thread(_add_batch) @@ -214,13 +230,14 @@ class VectorStore: if node_types: where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}} - # ChromaDB query() 是同步阻塞操作,必须在线程中执行 + # ChromaDB query() 是同步阻塞操作,必须在线程中执行,使用锁保护 def _query(): - return self.collection.query( - query_embeddings=[query_embedding.tolist()], - n_results=limit, - where=where_filter, - ) + with _chromadb_lock: + return self.collection.query( + query_embeddings=[query_embedding.tolist()], + n_results=limit, + where=where_filter, + ) results = await asyncio.to_thread(_query) @@ -383,9 +400,10 @@ class VectorStore: raise RuntimeError("向量存储未初始化") try: - # ChromaDB get() 是同步阻塞操作,必须在线程中执行 + # ChromaDB get() 是同步阻塞操作,必须在线程中执行,使用锁保护 def _get(): - return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"]) + with _chromadb_lock: + return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"]) result = await asyncio.to_thread(_get) @@ -420,9 +438,10 @@ class VectorStore: raise RuntimeError("向量存储未初始化") try: - # ChromaDB delete() 是同步阻塞操作,必须在线程中执行 + # ChromaDB delete() 是同步阻塞操作,必须在线程中执行,使用锁保护 def _delete(): - self.collection.delete(ids=[node_id]) + with _chromadb_lock: + self.collection.delete(ids=[node_id]) await asyncio.to_thread(_delete) logger.debug(f"删除节点: {node_id}") @@ -443,9 +462,10 @@ class VectorStore: raise RuntimeError("向量存储未初始化") try: - # ChromaDB update() 是同步阻塞操作,必须在线程中执行 + # ChromaDB update() 是同步阻塞操作,必须在线程中执行,使用锁保护 def _update(): - self.collection.update(ids=[node_id], embeddings=[embedding.tolist()]) + with _chromadb_lock: + self.collection.update(ids=[node_id], embeddings=[embedding.tolist()]) await asyncio.to_thread(_update) logger.debug(f"更新节点 embedding: {node_id}") @@ -458,13 +478,17 @@ class VectorStore: """获取向量存储中的节点总数(同步方法,谨慎在 async 上下文中使用)""" if not self.collection: return 0 - return self.collection.count() + with _chromadb_lock: + return self.collection.count() async def get_total_count_async(self) -> int: """异步获取向量存储中的节点总数""" if not self.collection: return 0 - return await asyncio.to_thread(self.collection.count) + def _count(): + with _chromadb_lock: + return self.collection.count() + return await asyncio.to_thread(_count) async def clear(self) -> None: """清空向量存储(危险操作,仅用于测试)""" @@ -472,13 +496,14 @@ class VectorStore: return try: - # ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作 + # ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作,使用锁保护 def _clear(): - self.client.delete_collection(self.collection_name) - return self.client.get_or_create_collection( - name=self.collection_name, - metadata={"description": "Memory graph node embeddings"}, - ) + with _chromadb_lock: + self.client.delete_collection(self.collection_name) + return self.client.get_or_create_collection( + name=self.collection_name, + metadata={"description": "Memory graph node embeddings"}, + ) self.collection = await asyncio.to_thread(_clear) logger.warning(f"向量存储已清空: {self.collection_name}")