feat(chromadb): 添加全局锁以保护 ChromaDB 操作,确保线程安全

This commit is contained in:
Windpicker-owo
2025-12-04 18:58:07 +08:00
parent 0949e7fa3f
commit e7cb04bfdd
2 changed files with 118 additions and 77 deletions

View File

@@ -10,11 +10,17 @@ from .base import VectorDBBase
logger = get_logger("chromadb_impl") logger = get_logger("chromadb_impl")
# 全局操作锁,用于保护 ChromaDB 的所有操作
# ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation
_operation_lock = threading.Lock()
class ChromaDBImpl(VectorDBBase): class ChromaDBImpl(VectorDBBase):
""" """
ChromaDB 的具体实现,遵循 VectorDBBase 接口。 ChromaDB 的具体实现,遵循 VectorDBBase 接口。
采用单例模式,确保全局只有一个 ChromaDB 客户端实例。 采用单例模式,确保全局只有一个 ChromaDB 客户端实例。
注意:所有操作都使用 _operation_lock 保护,以避免 Windows 上的并发访问崩溃。
""" """
_instance = None _instance = None
@@ -36,9 +42,10 @@ class ChromaDBImpl(VectorDBBase):
with self._lock: with self._lock:
if not hasattr(self, "_initialized"): if not hasattr(self, "_initialized"):
try: try:
self.client = chromadb.PersistentClient( with _operation_lock:
path=path, settings=Settings(anonymized_telemetry=False) self.client = chromadb.PersistentClient(
) path=path, settings=Settings(anonymized_telemetry=False)
)
self._collections: dict[str, Any] = {} self._collections: dict[str, Any] = {}
self._initialized = True self._initialized = True
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}") logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
@@ -56,7 +63,8 @@ class ChromaDBImpl(VectorDBBase):
return self._collections[name] return self._collections[name]
try: try:
collection = self.client.get_or_create_collection(name=name, **kwargs) with _operation_lock:
collection = self.client.get_or_create_collection(name=name, **kwargs)
self._collections[name] = collection self._collections[name] = collection
logger.info(f"成功获取或创建集合: '{name}'") logger.info(f"成功获取或创建集合: '{name}'")
return collection return collection
@@ -75,12 +83,13 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name) collection = self.get_or_create_collection(collection_name)
if collection: if collection:
try: try:
collection.add( with _operation_lock:
embeddings=embeddings, collection.add(
documents=documents, embeddings=embeddings,
metadatas=metadatas, documents=documents,
ids=ids, metadatas=metadatas,
) ids=ids,
)
except Exception as e: except Exception as e:
logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}") logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}")
@@ -107,7 +116,8 @@ class ChromaDBImpl(VectorDBBase):
if processed_where: if processed_where:
query_params["where"] = processed_where query_params["where"] = processed_where
return collection.query(**query_params) with _operation_lock:
return collection.query(**query_params)
except Exception as e: except Exception as e:
logger.error(f"查询集合 '{collection_name}' 失败: {e}") logger.error(f"查询集合 '{collection_name}' 失败: {e}")
# 如果查询失败尝试不使用where条件重新查询 # 如果查询失败尝试不使用where条件重新查询
@@ -117,7 +127,8 @@ class ChromaDBImpl(VectorDBBase):
"n_results": n_results, "n_results": n_results,
} }
logger.warning("使用回退查询模式无where条件") logger.warning("使用回退查询模式无where条件")
return collection.query(**fallback_params) with _operation_lock:
return collection.query(**fallback_params)
except Exception as fallback_e: except Exception as fallback_e:
logger.error(f"回退查询也失败: {fallback_e}") logger.error(f"回退查询也失败: {fallback_e}")
return {} return {}
@@ -192,26 +203,28 @@ class ChromaDBImpl(VectorDBBase):
if where: if where:
processed_where = self._process_where_condition(where) processed_where = self._process_where_condition(where)
return collection.get( with _operation_lock:
ids=ids,
where=processed_where,
limit=limit,
offset=offset,
where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as e:
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
# 如果获取失败尝试不使用where条件重新获取
try:
logger.warning("使用回退获取模式无where条件")
return collection.get( return collection.get(
ids=ids, ids=ids,
where=processed_where,
limit=limit, limit=limit,
offset=offset, offset=offset,
where_document=where_document, where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"], include=include or ["documents", "metadatas", "embeddings"],
) )
except Exception as e:
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
# 如果获取失败尝试不使用where条件重新获取
try:
logger.warning("使用回退获取模式无where条件")
with _operation_lock:
return collection.get(
ids=ids,
limit=limit,
offset=offset,
where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as fallback_e: except Exception as fallback_e:
logger.error(f"回退获取也失败: {fallback_e}") logger.error(f"回退获取也失败: {fallback_e}")
return {} return {}
@@ -225,7 +238,8 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name) collection = self.get_or_create_collection(collection_name)
if collection: if collection:
try: try:
collection.delete(ids=ids, where=where) with _operation_lock:
collection.delete(ids=ids, where=where)
except Exception as e: except Exception as e:
logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}") logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}")
@@ -233,7 +247,8 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name) collection = self.get_or_create_collection(collection_name)
if collection: if collection:
try: try:
return collection.count() with _operation_lock:
return collection.count()
except Exception as e: except Exception as e:
logger.error(f"获取集合 '{collection_name}' 计数失败: {e}") logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
return 0 return 0
@@ -243,7 +258,8 @@ class ChromaDBImpl(VectorDBBase):
raise ConnectionError("ChromaDB 客户端未初始化") raise ConnectionError("ChromaDB 客户端未初始化")
try: try:
self.client.delete_collection(name=name) with _operation_lock:
self.client.delete_collection(name=name)
if name in self._collections: if name in self._collections:
del self._collections[name] del self._collections[name]
logger.info(f"集合 '{name}' 已被删除") logger.info(f"集合 '{name}' 已被删除")

View File

@@ -3,11 +3,15 @@
注意ChromaDB 是同步库,所有操作都必须使用 asyncio.to_thread() 包装 注意ChromaDB 是同步库,所有操作都必须使用 asyncio.to_thread() 包装
以避免阻塞 asyncio 事件循环导致死锁。 以避免阻塞 asyncio 事件循环导致死锁。
重要ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation
因此所有操作都需要通过全局锁保护以确保串行执行。
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import threading
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -18,6 +22,10 @@ from src.memory_graph.models import MemoryNode, NodeType
logger = get_logger(__name__) logger = get_logger(__name__)
# 全局锁,用于保护 ChromaDB 的所有操作
# ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation
_chromadb_lock = threading.Lock()
class VectorStore: class VectorStore:
""" """
@@ -57,29 +65,35 @@ class VectorStore:
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
# 创建持久化客户端 - 同步操作需要在线程中执行 # 创建持久化客户端 - 同步操作需要在线程中执行,并使用锁保护
def _create_client(): def _create_client():
return chromadb.PersistentClient( with _chromadb_lock:
path=str(self.data_dir / "chroma"), return chromadb.PersistentClient(
settings=Settings( path=str(self.data_dir / "chroma"),
anonymized_telemetry=False, settings=Settings(
allow_reset=True, anonymized_telemetry=False,
), allow_reset=True,
) ),
)
self.client = await asyncio.to_thread(_create_client) self.client = await asyncio.to_thread(_create_client)
# 获取或创建集合 - 同步操作需要在线程中执行 # 获取或创建集合 - 同步操作需要在线程中执行,并使用锁保护
def _get_or_create_collection(): def _get_or_create_collection():
return self.client.get_or_create_collection( with _chromadb_lock:
name=self.collection_name, return self.client.get_or_create_collection(
metadata={"description": "Memory graph node embeddings"}, name=self.collection_name,
) metadata={"description": "Memory graph node embeddings"},
)
self.collection = await asyncio.to_thread(_get_or_create_collection) self.collection = await asyncio.to_thread(_get_or_create_collection)
# count() 也是同步操作 # count() 也是同步操作,使用锁保护
count = await asyncio.to_thread(self.collection.count) def _count():
with _chromadb_lock:
return self.collection.count()
count = await asyncio.to_thread(_count)
logger.debug(f"ChromaDB 初始化完成,集合包含 {count} 个节点") logger.debug(f"ChromaDB 初始化完成,集合包含 {count} 个节点")
except Exception as e: except Exception as e:
@@ -118,14 +132,15 @@ class VectorStore:
else: else:
metadata[key] = str(value) metadata[key] = str(value)
# ChromaDB add() 是同步阻塞操作,必须在线程中执行 # ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _add_node(): def _add_node():
self.collection.add( with _chromadb_lock:
ids=[node.id], self.collection.add(
embeddings=[node.embedding.tolist()], ids=[node.id],
metadatas=[metadata], embeddings=[node.embedding.tolist()],
documents=[node.content], metadatas=[metadata],
) documents=[node.content],
)
await asyncio.to_thread(_add_node) await asyncio.to_thread(_add_node)
@@ -171,14 +186,15 @@ class VectorStore:
metadata[key] = str(value) metadata[key] = str(value)
metadatas.append(metadata) metadatas.append(metadata)
# ChromaDB add() 是同步阻塞操作,必须在线程中执行 # ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _add_batch(): def _add_batch():
self.collection.add( with _chromadb_lock:
ids=[n.id for n in valid_nodes], self.collection.add(
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore ids=[n.id for n in valid_nodes],
metadatas=metadatas, embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
documents=[n.content for n in valid_nodes], metadatas=metadatas,
) documents=[n.content for n in valid_nodes],
)
await asyncio.to_thread(_add_batch) await asyncio.to_thread(_add_batch)
@@ -214,13 +230,14 @@ class VectorStore:
if node_types: if node_types:
where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}} where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}
# ChromaDB query() 是同步阻塞操作,必须在线程中执行 # ChromaDB query() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _query(): def _query():
return self.collection.query( with _chromadb_lock:
query_embeddings=[query_embedding.tolist()], return self.collection.query(
n_results=limit, query_embeddings=[query_embedding.tolist()],
where=where_filter, n_results=limit,
) where=where_filter,
)
results = await asyncio.to_thread(_query) results = await asyncio.to_thread(_query)
@@ -383,9 +400,10 @@ class VectorStore:
raise RuntimeError("向量存储未初始化") raise RuntimeError("向量存储未初始化")
try: try:
# ChromaDB get() 是同步阻塞操作,必须在线程中执行 # ChromaDB get() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _get(): def _get():
return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"]) with _chromadb_lock:
return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
result = await asyncio.to_thread(_get) result = await asyncio.to_thread(_get)
@@ -420,9 +438,10 @@ class VectorStore:
raise RuntimeError("向量存储未初始化") raise RuntimeError("向量存储未初始化")
try: try:
# ChromaDB delete() 是同步阻塞操作,必须在线程中执行 # ChromaDB delete() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _delete(): def _delete():
self.collection.delete(ids=[node_id]) with _chromadb_lock:
self.collection.delete(ids=[node_id])
await asyncio.to_thread(_delete) await asyncio.to_thread(_delete)
logger.debug(f"删除节点: {node_id}") logger.debug(f"删除节点: {node_id}")
@@ -443,9 +462,10 @@ class VectorStore:
raise RuntimeError("向量存储未初始化") raise RuntimeError("向量存储未初始化")
try: try:
# ChromaDB update() 是同步阻塞操作,必须在线程中执行 # ChromaDB update() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _update(): def _update():
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()]) with _chromadb_lock:
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
await asyncio.to_thread(_update) await asyncio.to_thread(_update)
logger.debug(f"更新节点 embedding: {node_id}") logger.debug(f"更新节点 embedding: {node_id}")
@@ -458,13 +478,17 @@ class VectorStore:
"""获取向量存储中的节点总数(同步方法,谨慎在 async 上下文中使用)""" """获取向量存储中的节点总数(同步方法,谨慎在 async 上下文中使用)"""
if not self.collection: if not self.collection:
return 0 return 0
return self.collection.count() with _chromadb_lock:
return self.collection.count()
async def get_total_count_async(self) -> int: async def get_total_count_async(self) -> int:
"""异步获取向量存储中的节点总数""" """异步获取向量存储中的节点总数"""
if not self.collection: if not self.collection:
return 0 return 0
return await asyncio.to_thread(self.collection.count) def _count():
with _chromadb_lock:
return self.collection.count()
return await asyncio.to_thread(_count)
async def clear(self) -> None: async def clear(self) -> None:
"""清空向量存储(危险操作,仅用于测试)""" """清空向量存储(危险操作,仅用于测试)"""
@@ -472,13 +496,14 @@ class VectorStore:
return return
try: try:
# ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作 # ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作,使用锁保护
def _clear(): def _clear():
self.client.delete_collection(self.collection_name) with _chromadb_lock:
return self.client.get_or_create_collection( self.client.delete_collection(self.collection_name)
name=self.collection_name, return self.client.get_or_create_collection(
metadata={"description": "Memory graph node embeddings"}, name=self.collection_name,
) metadata={"description": "Memory graph node embeddings"},
)
self.collection = await asyncio.to_thread(_clear) self.collection = await asyncio.to_thread(_clear)
logger.warning(f"向量存储已清空: {self.collection_name}") logger.warning(f"向量存储已清空: {self.collection_name}")