re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -18,4 +18,4 @@ def get_vector_db_service() -> VectorDBBase:
# 全局向量数据库服务实例
vector_db_service: VectorDBBase = get_vector_db_service()
__all__ = ["vector_db_service", "VectorDBBase"]
__all__ = ["VectorDBBase", "vector_db_service"]

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class VectorDBBase(ABC):
@@ -36,10 +36,10 @@ class VectorDBBase(ABC):
def add(
self,
collection_name: str,
embeddings: List[List[float]],
documents: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
embeddings: list[list[float]],
documents: list[str] | None = None,
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
) -> None:
"""
向指定集合中添加数据。
@@ -57,11 +57,11 @@ class VectorDBBase(ABC):
def query(
self,
collection_name: str,
query_embeddings: List[List[float]],
query_embeddings: list[list[float]],
n_results: int = 1,
where: Optional[Dict[str, Any]] = None,
where: dict[str, Any] | None = None,
**kwargs: Any,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
"""
在指定集合中查询相似向量。
@@ -81,8 +81,8 @@ class VectorDBBase(ABC):
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
) -> None:
"""
从指定集合中删除数据。
@@ -98,13 +98,13 @@ class VectorDBBase(ABC):
def get(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[Dict[str, Any]] = None,
include: Optional[List[str]] = None,
) -> Dict[str, Any]:
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
limit: int | None = None,
offset: int | None = None,
where_document: dict[str, Any] | None = None,
include: list[str] | None = None,
) -> dict[str, Any]:
"""
根据条件从集合中获取数据。

View File

@@ -1,12 +1,13 @@
import threading
from typing import Any, Dict, List, Optional
from typing import Any
import chromadb
from chromadb.config import Settings
from .base import VectorDBBase
from src.common.logger import get_logger
from .base import VectorDBBase
logger = get_logger("chromadb_impl")
@@ -38,7 +39,7 @@ class ChromaDBImpl(VectorDBBase):
self.client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
self._collections: Dict[str, Any] = {}
self._collections: dict[str, Any] = {}
self._initialized = True
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
except Exception as e:
@@ -65,10 +66,10 @@ class ChromaDBImpl(VectorDBBase):
def add(
self,
collection_name: str,
embeddings: List[List[float]],
documents: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
embeddings: list[list[float]],
documents: list[str] | None = None,
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
) -> None:
collection = self.get_or_create_collection(collection_name)
if collection:
@@ -85,11 +86,11 @@ class ChromaDBImpl(VectorDBBase):
def query(
self,
collection_name: str,
query_embeddings: List[List[float]],
query_embeddings: list[list[float]],
n_results: int = 1,
where: Optional[Dict[str, Any]] = None,
where: dict[str, Any] | None = None,
**kwargs: Any,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
collection = self.get_or_create_collection(collection_name)
if collection:
try:
@@ -120,7 +121,7 @@ class ChromaDBImpl(VectorDBBase):
logger.error(f"回退查询也失败: {fallback_e}")
return {}
def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]:
def _process_where_condition(self, where: dict[str, Any]) -> dict[str, Any] | None:
"""
处理where条件转换为ChromaDB支持的格式
ChromaDB支持的格式
@@ -174,13 +175,13 @@ class ChromaDBImpl(VectorDBBase):
def get(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[Dict[str, Any]] = None,
include: Optional[List[str]] = None,
) -> Dict[str, Any]:
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
limit: int | None = None,
offset: int | None = None,
where_document: dict[str, Any] | None = None,
include: list[str] | None = None,
) -> dict[str, Any]:
"""根据条件从集合中获取数据"""
collection = self.get_or_create_collection(collection_name)
if collection:
@@ -217,8 +218,8 @@ class ChromaDBImpl(VectorDBBase):
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
) -> None:
collection = self.get_or_create_collection(collection_name)
if collection: