feat(chromadb): 添加全局锁以保护 ChromaDB 操作,确保线程安全

This commit is contained in:
Windpicker-owo
2025-12-04 18:58:07 +08:00
parent 0949e7fa3f
commit e7cb04bfdd
2 changed files with 118 additions and 77 deletions

View File

@@ -10,11 +10,17 @@ from .base import VectorDBBase
logger = get_logger("chromadb_impl")
# 全局操作锁,用于保护 ChromaDB 的所有操作
# ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation
_operation_lock = threading.Lock()
class ChromaDBImpl(VectorDBBase):
"""
ChromaDB 的具体实现,遵循 VectorDBBase 接口。
采用单例模式,确保全局只有一个 ChromaDB 客户端实例。
注意:所有操作都使用 _operation_lock 保护,以避免 Windows 上的并发访问崩溃。
"""
_instance = None
@@ -36,9 +42,10 @@ class ChromaDBImpl(VectorDBBase):
with self._lock:
if not hasattr(self, "_initialized"):
try:
self.client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
with _operation_lock:
self.client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
self._collections: dict[str, Any] = {}
self._initialized = True
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
@@ -56,7 +63,8 @@ class ChromaDBImpl(VectorDBBase):
return self._collections[name]
try:
collection = self.client.get_or_create_collection(name=name, **kwargs)
with _operation_lock:
collection = self.client.get_or_create_collection(name=name, **kwargs)
self._collections[name] = collection
logger.info(f"成功获取或创建集合: '{name}'")
return collection
@@ -75,12 +83,13 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name)
if collection:
try:
collection.add(
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids,
)
with _operation_lock:
collection.add(
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids,
)
except Exception as e:
logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}")
@@ -107,7 +116,8 @@ class ChromaDBImpl(VectorDBBase):
if processed_where:
query_params["where"] = processed_where
return collection.query(**query_params)
with _operation_lock:
return collection.query(**query_params)
except Exception as e:
logger.error(f"查询集合 '{collection_name}' 失败: {e}")
# 如果查询失败,尝试不使用 where 条件重新查询
@@ -117,7 +127,8 @@ class ChromaDBImpl(VectorDBBase):
"n_results": n_results,
}
logger.warning("使用回退查询模式无where条件")
return collection.query(**fallback_params)
with _operation_lock:
return collection.query(**fallback_params)
except Exception as fallback_e:
logger.error(f"回退查询也失败: {fallback_e}")
return {}
@@ -192,26 +203,28 @@ class ChromaDBImpl(VectorDBBase):
if where:
processed_where = self._process_where_condition(where)
return collection.get(
ids=ids,
where=processed_where,
limit=limit,
offset=offset,
where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as e:
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
# 如果获取失败尝试不使用where条件重新获取
try:
logger.warning("使用回退获取模式无where条件")
with _operation_lock:
return collection.get(
ids=ids,
where=processed_where,
limit=limit,
offset=offset,
where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as e:
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
# 如果获取失败,尝试不使用 where 条件重新获取
try:
logger.warning("使用回退获取模式无where条件")
with _operation_lock:
return collection.get(
ids=ids,
limit=limit,
offset=offset,
where_document=where_document,
include=include or ["documents", "metadatas", "embeddings"],
)
except Exception as fallback_e:
logger.error(f"回退获取也失败: {fallback_e}")
return {}
@@ -225,7 +238,8 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name)
if collection:
try:
collection.delete(ids=ids, where=where)
with _operation_lock:
collection.delete(ids=ids, where=where)
except Exception as e:
logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}")
@@ -233,7 +247,8 @@ class ChromaDBImpl(VectorDBBase):
collection = self.get_or_create_collection(collection_name)
if collection:
try:
return collection.count()
with _operation_lock:
return collection.count()
except Exception as e:
logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
return 0
@@ -243,7 +258,8 @@ class ChromaDBImpl(VectorDBBase):
raise ConnectionError("ChromaDB 客户端未初始化")
try:
self.client.delete_collection(name=name)
with _operation_lock:
self.client.delete_collection(name=name)
if name in self._collections:
del self._collections[name]
logger.info(f"集合 '{name}' 已被删除")

View File

@@ -3,11 +3,15 @@
注意:ChromaDB 是同步库,所有操作都必须使用 asyncio.to_thread() 包装,
以避免阻塞 asyncio 事件循环导致死锁。
重要:ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation,
因此所有操作都需要通过全局锁保护以确保串行执行。
"""
from __future__ import annotations
import asyncio
import threading
from pathlib import Path
from typing import Any
@@ -18,6 +22,10 @@ from src.memory_graph.models import MemoryNode, NodeType
logger = get_logger(__name__)
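# 模块级锁,用于串行化本模块发起的所有 ChromaDB 操作。
# ChromaDB 的 Rust 后端在 Windows 上多线程并发访问时可能导致 access violation。
# 注意:该锁与 ChromaDBImpl 所在模块的 _operation_lock 是两把相互独立的锁,
# 两个模块之间的并发访问并不会因此互斥;如需跨模块串行化,应共用同一把锁。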
_chromadb_lock = threading.Lock()
class VectorStore:
"""
@@ -57,29 +65,35 @@ class VectorStore:
import chromadb
from chromadb.config import Settings
# 创建持久化客户端 - 同步操作需要在线程中执行
# 创建持久化客户端 - 同步操作需要在线程中执行,并使用锁保护
def _create_client():
return chromadb.PersistentClient(
path=str(self.data_dir / "chroma"),
settings=Settings(
anonymized_telemetry=False,
allow_reset=True,
),
)
with _chromadb_lock:
return chromadb.PersistentClient(
path=str(self.data_dir / "chroma"),
settings=Settings(
anonymized_telemetry=False,
allow_reset=True,
),
)
self.client = await asyncio.to_thread(_create_client)
# 获取或创建集合 - 同步操作需要在线程中执行
# 获取或创建集合 - 同步操作需要在线程中执行,并使用锁保护
def _get_or_create_collection():
return self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings"},
)
with _chromadb_lock:
return self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings"},
)
self.collection = await asyncio.to_thread(_get_or_create_collection)
# count() 也是同步操作
count = await asyncio.to_thread(self.collection.count)
# count() 也是同步操作,使用锁保护
def _count():
with _chromadb_lock:
return self.collection.count()
count = await asyncio.to_thread(_count)
logger.debug(f"ChromaDB 初始化完成,集合包含 {count} 个节点")
except Exception as e:
@@ -118,14 +132,15 @@ class VectorStore:
else:
metadata[key] = str(value)
# ChromaDB add() 是同步阻塞操作,必须在线程中执行
# ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _add_node():
self.collection.add(
ids=[node.id],
embeddings=[node.embedding.tolist()],
metadatas=[metadata],
documents=[node.content],
)
with _chromadb_lock:
self.collection.add(
ids=[node.id],
embeddings=[node.embedding.tolist()],
metadatas=[metadata],
documents=[node.content],
)
await asyncio.to_thread(_add_node)
@@ -171,14 +186,15 @@ class VectorStore:
metadata[key] = str(value)
metadatas.append(metadata)
# ChromaDB add() 是同步阻塞操作,必须在线程中执行
# ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _add_batch():
self.collection.add(
ids=[n.id for n in valid_nodes],
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
metadatas=metadatas,
documents=[n.content for n in valid_nodes],
)
with _chromadb_lock:
self.collection.add(
ids=[n.id for n in valid_nodes],
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
metadatas=metadatas,
documents=[n.content for n in valid_nodes],
)
await asyncio.to_thread(_add_batch)
@@ -214,13 +230,14 @@ class VectorStore:
if node_types:
where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}
# ChromaDB query() 是同步阻塞操作,必须在线程中执行
# ChromaDB query() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _query():
return self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=limit,
where=where_filter,
)
with _chromadb_lock:
return self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=limit,
where=where_filter,
)
results = await asyncio.to_thread(_query)
@@ -383,9 +400,10 @@ class VectorStore:
raise RuntimeError("向量存储未初始化")
try:
# ChromaDB get() 是同步阻塞操作,必须在线程中执行
# ChromaDB get() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _get():
return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
with _chromadb_lock:
return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
result = await asyncio.to_thread(_get)
@@ -420,9 +438,10 @@ class VectorStore:
raise RuntimeError("向量存储未初始化")
try:
# ChromaDB delete() 是同步阻塞操作,必须在线程中执行
# ChromaDB delete() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _delete():
self.collection.delete(ids=[node_id])
with _chromadb_lock:
self.collection.delete(ids=[node_id])
await asyncio.to_thread(_delete)
logger.debug(f"删除节点: {node_id}")
@@ -443,9 +462,10 @@ class VectorStore:
raise RuntimeError("向量存储未初始化")
try:
# ChromaDB update() 是同步阻塞操作,必须在线程中执行
# ChromaDB update() 是同步阻塞操作,必须在线程中执行,使用锁保护
def _update():
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
with _chromadb_lock:
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
await asyncio.to_thread(_update)
logger.debug(f"更新节点 embedding: {node_id}")
@@ -458,13 +478,17 @@ class VectorStore:
"""获取向量存储中的节点总数(同步方法,谨慎在 async 上下文中使用)"""
if not self.collection:
return 0
return self.collection.count()
with _chromadb_lock:
return self.collection.count()
async def get_total_count_async(self) -> int:
"""异步获取向量存储中的节点总数"""
if not self.collection:
return 0
return await asyncio.to_thread(self.collection.count)
def _count():
with _chromadb_lock:
return self.collection.count()
return await asyncio.to_thread(_count)
async def clear(self) -> None:
"""清空向量存储(危险操作,仅用于测试)"""
@@ -472,13 +496,14 @@ class VectorStore:
return
try:
# ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作
# ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作,使用锁保护
def _clear():
self.client.delete_collection(self.collection_name)
return self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings"},
)
with _chromadb_lock:
self.client.delete_collection(self.collection_name)
return self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Memory graph node embeddings"},
)
self.collection = await asyncio.to_thread(_clear)
logger.warning(f"向量存储已清空: {self.collection_name}")