feat(chromadb): 添加全局锁以保护 ChromaDB 操作,确保线程安全
This commit is contained in:
@@ -10,11 +10,17 @@ from .base import VectorDBBase
|
||||
|
||||
logger = get_logger("chromadb_impl")
|
||||
|
||||
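# Module-level lock that serializes the ChromaDB calls made through ChromaDBImpl.
|
||||
# ChromaDB's Rust backend can hit an access violation on Windows under concurrent multi-threaded access. NOTE: this is a separate Lock object from `_chromadb_lock` in the VectorStore module, so calls made through the two modules are not serialized against each other.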
|
||||
_operation_lock = threading.Lock()
|
||||
|
||||
|
||||
class ChromaDBImpl(VectorDBBase):
|
||||
"""
|
||||
ChromaDB 的具体实现,遵循 VectorDBBase 接口。
|
||||
采用单例模式,确保全局只有一个 ChromaDB 客户端实例。
|
||||
|
||||
注意:所有操作都使用 _operation_lock 保护,以避免 Windows 上的并发访问崩溃。
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
@@ -36,9 +42,10 @@ class ChromaDBImpl(VectorDBBase):
|
||||
with self._lock:
|
||||
if not hasattr(self, "_initialized"):
|
||||
try:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=path, settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
with _operation_lock:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=path, settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
self._collections: dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
|
||||
@@ -56,7 +63,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
return self._collections[name]
|
||||
|
||||
try:
|
||||
collection = self.client.get_or_create_collection(name=name, **kwargs)
|
||||
with _operation_lock:
|
||||
collection = self.client.get_or_create_collection(name=name, **kwargs)
|
||||
self._collections[name] = collection
|
||||
logger.info(f"成功获取或创建集合: '{name}'")
|
||||
return collection
|
||||
@@ -75,12 +83,13 @@ class ChromaDBImpl(VectorDBBase):
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
collection.add(
|
||||
embeddings=embeddings,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
)
|
||||
with _operation_lock:
|
||||
collection.add(
|
||||
embeddings=embeddings,
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}")
|
||||
|
||||
@@ -107,7 +116,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
if processed_where:
|
||||
query_params["where"] = processed_where
|
||||
|
||||
return collection.query(**query_params)
|
||||
with _operation_lock:
|
||||
return collection.query(**query_params)
|
||||
except Exception as e:
|
||||
logger.error(f"查询集合 '{collection_name}' 失败: {e}")
|
||||
# 如果查询失败,尝试不使用where条件重新查询
|
||||
@@ -117,7 +127,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
"n_results": n_results,
|
||||
}
|
||||
logger.warning("使用回退查询模式(无where条件)")
|
||||
return collection.query(**fallback_params)
|
||||
with _operation_lock:
|
||||
return collection.query(**fallback_params)
|
||||
except Exception as fallback_e:
|
||||
logger.error(f"回退查询也失败: {fallback_e}")
|
||||
return {}
|
||||
@@ -192,26 +203,28 @@ class ChromaDBImpl(VectorDBBase):
|
||||
if where:
|
||||
processed_where = self._process_where_condition(where)
|
||||
|
||||
return collection.get(
|
||||
ids=ids,
|
||||
where=processed_where,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
include=include or ["documents", "metadatas", "embeddings"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
|
||||
# 如果获取失败,尝试不使用where条件重新获取
|
||||
try:
|
||||
logger.warning("使用回退获取模式(无where条件)")
|
||||
with _operation_lock:
|
||||
return collection.get(
|
||||
ids=ids,
|
||||
where=processed_where,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
include=include or ["documents", "metadatas", "embeddings"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"从集合 '{collection_name}' 获取数据失败: {e}")
|
||||
# 如果获取失败,尝试不使用where条件重新获取
|
||||
try:
|
||||
logger.warning("使用回退获取模式(无where条件)")
|
||||
with _operation_lock:
|
||||
return collection.get(
|
||||
ids=ids,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
where_document=where_document,
|
||||
include=include or ["documents", "metadatas", "embeddings"],
|
||||
)
|
||||
except Exception as fallback_e:
|
||||
logger.error(f"回退获取也失败: {fallback_e}")
|
||||
return {}
|
||||
@@ -225,7 +238,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
collection.delete(ids=ids, where=where)
|
||||
with _operation_lock:
|
||||
collection.delete(ids=ids, where=where)
|
||||
except Exception as e:
|
||||
logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}")
|
||||
|
||||
@@ -233,7 +247,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
return collection.count()
|
||||
with _operation_lock:
|
||||
return collection.count()
|
||||
except Exception as e:
|
||||
logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
|
||||
return 0
|
||||
@@ -243,7 +258,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
raise ConnectionError("ChromaDB 客户端未初始化")
|
||||
|
||||
try:
|
||||
self.client.delete_collection(name=name)
|
||||
with _operation_lock:
|
||||
self.client.delete_collection(name=name)
|
||||
if name in self._collections:
|
||||
del self._collections[name]
|
||||
logger.info(f"集合 '{name}' 已被删除")
|
||||
|
||||
@@ -3,11 +3,15 @@
|
||||
|
||||
注意:ChromaDB 是同步库,所有操作都必须使用 asyncio.to_thread() 包装
|
||||
以避免阻塞 asyncio 事件循环导致死锁。
|
||||
|
||||
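Important: ChromaDB's Rust backend can hit an access violation on Windows under concurrent multi-threaded access,
|
||||
so every ChromaDB call in this module is serialized through the module-level lock `_chromadb_lock`.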
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -18,6 +22,10 @@ from src.memory_graph.models import MemoryNode, NodeType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
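# Module-level lock that serializes the ChromaDB calls made by VectorStore in this module.
|
||||
# ChromaDB's Rust backend can hit an access violation on Windows under concurrent multi-threaded access. NOTE: this is a separate Lock object from `_operation_lock` in the ChromaDBImpl module, so calls made through the two modules are not serialized against each other.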
|
||||
_chromadb_lock = threading.Lock()
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
@@ -57,29 +65,35 @@ class VectorStore:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
# 创建持久化客户端 - 同步操作需要在线程中执行
|
||||
# 创建持久化客户端 - 同步操作需要在线程中执行,并使用锁保护
|
||||
def _create_client():
|
||||
return chromadb.PersistentClient(
|
||||
path=str(self.data_dir / "chroma"),
|
||||
settings=Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True,
|
||||
),
|
||||
)
|
||||
with _chromadb_lock:
|
||||
return chromadb.PersistentClient(
|
||||
path=str(self.data_dir / "chroma"),
|
||||
settings=Settings(
|
||||
anonymized_telemetry=False,
|
||||
allow_reset=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.client = await asyncio.to_thread(_create_client)
|
||||
|
||||
# 获取或创建集合 - 同步操作需要在线程中执行
|
||||
# 获取或创建集合 - 同步操作需要在线程中执行,并使用锁保护
|
||||
def _get_or_create_collection():
|
||||
return self.client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"description": "Memory graph node embeddings"},
|
||||
)
|
||||
with _chromadb_lock:
|
||||
return self.client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"description": "Memory graph node embeddings"},
|
||||
)
|
||||
|
||||
self.collection = await asyncio.to_thread(_get_or_create_collection)
|
||||
|
||||
# count() 也是同步操作
|
||||
count = await asyncio.to_thread(self.collection.count)
|
||||
# count() 也是同步操作,使用锁保护
|
||||
def _count():
|
||||
with _chromadb_lock:
|
||||
return self.collection.count()
|
||||
|
||||
count = await asyncio.to_thread(_count)
|
||||
logger.debug(f"ChromaDB 初始化完成,集合包含 {count} 个节点")
|
||||
|
||||
except Exception as e:
|
||||
@@ -118,14 +132,15 @@ class VectorStore:
|
||||
else:
|
||||
metadata[key] = str(value)
|
||||
|
||||
# ChromaDB add() 是同步阻塞操作,必须在线程中执行
|
||||
# ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护
|
||||
def _add_node():
|
||||
self.collection.add(
|
||||
ids=[node.id],
|
||||
embeddings=[node.embedding.tolist()],
|
||||
metadatas=[metadata],
|
||||
documents=[node.content],
|
||||
)
|
||||
with _chromadb_lock:
|
||||
self.collection.add(
|
||||
ids=[node.id],
|
||||
embeddings=[node.embedding.tolist()],
|
||||
metadatas=[metadata],
|
||||
documents=[node.content],
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_add_node)
|
||||
|
||||
@@ -171,14 +186,15 @@ class VectorStore:
|
||||
metadata[key] = str(value)
|
||||
metadatas.append(metadata)
|
||||
|
||||
# ChromaDB add() 是同步阻塞操作,必须在线程中执行
|
||||
# ChromaDB add() 是同步阻塞操作,必须在线程中执行,使用锁保护
|
||||
def _add_batch():
|
||||
self.collection.add(
|
||||
ids=[n.id for n in valid_nodes],
|
||||
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=[n.content for n in valid_nodes],
|
||||
)
|
||||
with _chromadb_lock:
|
||||
self.collection.add(
|
||||
ids=[n.id for n in valid_nodes],
|
||||
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
|
||||
metadatas=metadatas,
|
||||
documents=[n.content for n in valid_nodes],
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_add_batch)
|
||||
|
||||
@@ -214,13 +230,14 @@ class VectorStore:
|
||||
if node_types:
|
||||
where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}
|
||||
|
||||
# ChromaDB query() 是同步阻塞操作,必须在线程中执行
|
||||
# ChromaDB query() 是同步阻塞操作,必须在线程中执行,使用锁保护
|
||||
def _query():
|
||||
return self.collection.query(
|
||||
query_embeddings=[query_embedding.tolist()],
|
||||
n_results=limit,
|
||||
where=where_filter,
|
||||
)
|
||||
with _chromadb_lock:
|
||||
return self.collection.query(
|
||||
query_embeddings=[query_embedding.tolist()],
|
||||
n_results=limit,
|
||||
where=where_filter,
|
||||
)
|
||||
|
||||
results = await asyncio.to_thread(_query)
|
||||
|
||||
@@ -383,9 +400,10 @@ class VectorStore:
|
||||
raise RuntimeError("向量存储未初始化")
|
||||
|
||||
try:
|
||||
# ChromaDB get() 是同步阻塞操作,必须在线程中执行
|
||||
# ChromaDB get() 是同步阻塞操作,必须在线程中执行,使用锁保护
|
||||
def _get():
|
||||
return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
|
||||
with _chromadb_lock:
|
||||
return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
|
||||
|
||||
result = await asyncio.to_thread(_get)
|
||||
|
||||
@@ -420,9 +438,10 @@ class VectorStore:
|
||||
raise RuntimeError("向量存储未初始化")
|
||||
|
||||
try:
|
||||
# ChromaDB delete() 是同步阻塞操作,必须在线程中执行
|
||||
# ChromaDB delete() 是同步阻塞操作,必须在线程中执行,使用锁保护
|
||||
def _delete():
|
||||
self.collection.delete(ids=[node_id])
|
||||
with _chromadb_lock:
|
||||
self.collection.delete(ids=[node_id])
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
logger.debug(f"删除节点: {node_id}")
|
||||
@@ -443,9 +462,10 @@ class VectorStore:
|
||||
raise RuntimeError("向量存储未初始化")
|
||||
|
||||
try:
|
||||
# ChromaDB update() 是同步阻塞操作,必须在线程中执行
|
||||
# ChromaDB update() 是同步阻塞操作,必须在线程中执行,使用锁保护
|
||||
def _update():
|
||||
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
|
||||
with _chromadb_lock:
|
||||
self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
|
||||
|
||||
await asyncio.to_thread(_update)
|
||||
logger.debug(f"更新节点 embedding: {node_id}")
|
||||
@@ -458,13 +478,17 @@ class VectorStore:
|
||||
"""获取向量存储中的节点总数(同步方法,谨慎在 async 上下文中使用)"""
|
||||
if not self.collection:
|
||||
return 0
|
||||
return self.collection.count()
|
||||
with _chromadb_lock:
|
||||
return self.collection.count()
|
||||
|
||||
async def get_total_count_async(self) -> int:
|
||||
"""异步获取向量存储中的节点总数"""
|
||||
if not self.collection:
|
||||
return 0
|
||||
return await asyncio.to_thread(self.collection.count)
|
||||
def _count():
|
||||
with _chromadb_lock:
|
||||
return self.collection.count()
|
||||
return await asyncio.to_thread(_count)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""清空向量存储(危险操作,仅用于测试)"""
|
||||
@@ -472,13 +496,14 @@ class VectorStore:
|
||||
return
|
||||
|
||||
try:
|
||||
# ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作
|
||||
# ChromaDB delete_collection 和 get_or_create_collection 都是同步阻塞操作,使用锁保护
|
||||
def _clear():
|
||||
self.client.delete_collection(self.collection_name)
|
||||
return self.client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"description": "Memory graph node embeddings"},
|
||||
)
|
||||
with _chromadb_lock:
|
||||
self.client.delete_collection(self.collection_name)
|
||||
return self.client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"description": "Memory graph node embeddings"},
|
||||
)
|
||||
|
||||
self.collection = await asyncio.to_thread(_clear)
|
||||
logger.warning(f"向量存储已清空: {self.collection_name}")
|
||||
|
||||
Reference in New Issue
Block a user