re-style: 格式化代码
This commit is contained in:
@@ -18,4 +18,4 @@ def get_vector_db_service() -> VectorDBBase:
|
||||
# 全局向量数据库服务实例
|
||||
vector_db_service: VectorDBBase = get_vector_db_service()
|
||||
|
||||
__all__ = ["vector_db_service", "VectorDBBase"]
|
||||
__all__ = ["VectorDBBase", "vector_db_service"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class VectorDBBase(ABC):
|
||||
@@ -36,10 +36,10 @@ class VectorDBBase(ABC):
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str] | None = None,
|
||||
metadatas: list[dict[str, Any]] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
向指定集合中添加数据。
|
||||
@@ -57,11 +57,11 @@ class VectorDBBase(ABC):
|
||||
def query(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_embeddings: List[List[float]],
|
||||
query_embeddings: list[list[float]],
|
||||
n_results: int = 1,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, List[Any]]:
|
||||
) -> dict[str, list[Any]]:
|
||||
"""
|
||||
在指定集合中查询相似向量。
|
||||
|
||||
@@ -81,8 +81,8 @@ class VectorDBBase(ABC):
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
从指定集合中删除数据。
|
||||
@@ -98,13 +98,13 @@ class VectorDBBase(ABC):
|
||||
def get(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[Dict[str, Any]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
where_document: dict[str, Any] | None = None,
|
||||
include: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
根据条件从集合中获取数据。
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .base import VectorDBBase
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base import VectorDBBase
|
||||
|
||||
logger = get_logger("chromadb_impl")
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ class ChromaDBImpl(VectorDBBase):
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=path, settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._collections: dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
|
||||
except Exception as e:
|
||||
@@ -65,10 +66,10 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str] | None = None,
|
||||
metadatas: list[dict[str, Any]] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
) -> None:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
@@ -85,11 +86,11 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def query(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_embeddings: List[List[float]],
|
||||
query_embeddings: list[list[float]],
|
||||
n_results: int = 1,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, List[Any]]:
|
||||
) -> dict[str, list[Any]]:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
@@ -120,7 +121,7 @@ class ChromaDBImpl(VectorDBBase):
|
||||
logger.error(f"回退查询也失败: {fallback_e}")
|
||||
return {}
|
||||
|
||||
def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
def _process_where_condition(self, where: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
处理where条件,转换为ChromaDB支持的格式
|
||||
ChromaDB支持的格式:
|
||||
@@ -174,13 +175,13 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def get(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[Dict[str, Any]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
where_document: dict[str, Any] | None = None,
|
||||
include: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""根据条件从集合中获取数据"""
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
@@ -217,8 +218,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
|
||||
Reference in New Issue
Block a user