feat(core): 集成统一向量数据库服务并重构相关模块
本次提交引入了一个统一的、可扩展的向量数据库服务层,旨在解决代码重复、实现分散以及数据库实例泛滥的问题。

主要变更:
- 新增向量数据库抽象层:
  - 在 src/common/vector_db/ 目录下创建了 VectorDBBase 抽象基类,定义了标准化的数据库操作接口。
  - 创建了 ChromaDBImpl 作为具体的实现,并采用单例模式确保全局只有一个数据库客户端实例。
- 重构语义缓存 (CacheManager):
  - 移除了对 chromadb 库的直接依赖。
  - 改为调用统一的 vector_db_service 来进行向量的添加和查询操作。
- 重构瞬时记忆 (VectorInstantMemoryV2):
  - 彻底解决了为每个 chat_id 创建独立数据库实例的问题。
  - 现在所有记忆数据都存储在统一的 instant_memory 集合中,并通过 metadata 中的 chat_id 进行数据隔离和查询。
- 新增使用文档:
  - 在 docs/ 目录下添加了 vector_db_usage_guide.md,详细说明了如何使用新的 vector_db_service 代码接口。

带来的好处:
- 高内聚,低耦合: 业务代码与具体的向量数据库实现解耦。
- 易于维护和扩展: 未来可以轻松替换或添加新的向量数据库支持。
- 性能与资源优化: 整个应用共享一个数据库连接,显著减少了文件句柄和内存占用。
This commit is contained in:
@@ -4,13 +4,13 @@ import hashlib
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import faiss
|
||||
import chromadb
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||
from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
@@ -28,25 +28,23 @@ class CacheManager:
|
||||
cls._instance = super(CacheManager, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, default_ttl: int = 3600, chroma_path: str = "data/chroma_db"):
|
||||
def __init__(self, default_ttl: int = 3600):
|
||||
"""
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
self.default_ttl = default_ttl
|
||||
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
# L1 缓存 (内存)
|
||||
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
# 语义缓存 (ChromaDB)
|
||||
|
||||
self.chroma_client = chromadb.PersistentClient(path=chroma_path)
|
||||
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
|
||||
# L2 向量缓存 (使用新的服务)
|
||||
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
|
||||
|
||||
|
||||
# 嵌入模型
|
||||
self.embedding_model = LLMRequest(model_config.model_task_config.embedding)
|
||||
|
||||
@@ -152,18 +150,20 @@ class CacheManager:
|
||||
return self.l1_kv_cache[l1_hit_key]["data"]
|
||||
|
||||
# 步骤 2b: L2 精确缓存 (数据库)
|
||||
cache_results = await db_query(
|
||||
cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": key},
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if cache_results:
|
||||
expires_at = cache_results["expires_at"]
|
||||
if cache_results_obj:
|
||||
# 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
|
||||
expires_at = getattr(cache_results_obj, "expires_at", 0)
|
||||
if time.time() < expires_at:
|
||||
logger.info(f"命中L2键值缓存: {key}")
|
||||
data = orjson.loads(cache_results["cache_value"])
|
||||
cache_value = getattr(cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
|
||||
# 更新访问统计
|
||||
await db_query(
|
||||
@@ -172,7 +172,7 @@ class CacheManager:
|
||||
filters={"cache_key": key},
|
||||
data={
|
||||
"last_accessed": time.time(),
|
||||
"access_count": cache_results["access_count"] + 1
|
||||
"access_count": getattr(cache_results_obj, "access_count", 0) + 1
|
||||
}
|
||||
)
|
||||
|
||||
@@ -187,29 +187,35 @@ class CacheManager:
|
||||
filters={"cache_key": key}
|
||||
)
|
||||
|
||||
# 步骤 2c: L2 语义缓存 (ChromaDB)
|
||||
if query_embedding is not None and self.chroma_collection:
|
||||
# 步骤 2c: L2 语义缓存 (VectorDB Service)
|
||||
if query_embedding is not None:
|
||||
try:
|
||||
results = self.chroma_collection.query(query_embeddings=query_embedding.tolist(), n_results=1)
|
||||
if results and results['ids'] and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results['distances'] and results['distances'][0] else 'N/A'
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=1
|
||||
)
|
||||
if results and results.get('ids') and results['ids'][0]:
|
||||
distance = results['distances'][0][0] if results.get('distances') and results['distances'][0] else 'N/A'
|
||||
logger.debug(f"L2语义搜索找到最相似的结果: id={results['ids'][0]}, 距离={distance}")
|
||||
|
||||
if distance != 'N/A' and distance < 0.75:
|
||||
l2_hit_key = results['ids'][0][0] if isinstance(results['ids'][0], list) else results['ids'][0]
|
||||
logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
semantic_cache_results = await db_query(
|
||||
semantic_cache_results_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": l2_hit_key},
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if semantic_cache_results:
|
||||
expires_at = semantic_cache_results["expires_at"]
|
||||
if semantic_cache_results_obj:
|
||||
expires_at = getattr(semantic_cache_results_obj, "expires_at", 0)
|
||||
if time.time() < expires_at:
|
||||
data = orjson.loads(semantic_cache_results["cache_value"])
|
||||
cache_value = getattr(semantic_cache_results_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
logger.debug(f"L2语义缓存返回的数据: {data}")
|
||||
|
||||
# 回填 L1
|
||||
@@ -218,13 +224,13 @@ class CacheManager:
|
||||
try:
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(query_embedding)
|
||||
self.l1_vector_index.add(x=query_embedding)
|
||||
self.l1_vector_index.add(x=query_embedding) # type: ignore
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
except Exception as e:
|
||||
logger.error(f"回填L1向量索引时发生错误: {e}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"ChromaDB查询失败: {e}")
|
||||
logger.warning(f"VectorDB Service 查询失败: {e}")
|
||||
|
||||
logger.debug(f"缓存未命中: {key}")
|
||||
return None
|
||||
@@ -261,22 +267,27 @@ class CacheManager:
|
||||
)
|
||||
|
||||
# 写入语义缓存
|
||||
if semantic_query and self.embedding_model and self.chroma_collection:
|
||||
if semantic_query and self.embedding_model:
|
||||
try:
|
||||
embedding_result = await self.embedding_model.get_embedding(semantic_query)
|
||||
if embedding_result:
|
||||
# embedding_result是一个元组(embedding_vector, model_name),取第一个元素
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is not None:
|
||||
embedding = np.array([validated_embedding], dtype='float32')
|
||||
|
||||
# 写入 L1 Vector
|
||||
new_id = self.l1_vector_index.ntotal
|
||||
faiss.normalize_L2(embedding)
|
||||
self.l1_vector_index.add(x=embedding)
|
||||
self.l1_vector_index.add(x=embedding) # type: ignore
|
||||
self.l1_vector_id_to_key[new_id] = key
|
||||
# 写入 L2 Vector
|
||||
self.chroma_collection.add(embeddings=embedding.tolist(), ids=[key])
|
||||
|
||||
# 写入 L2 Vector (使用新的服务)
|
||||
vector_db_service.add(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
embeddings=embedding.tolist(),
|
||||
ids=[key]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"语义缓存写入失败: {e}")
|
||||
|
||||
@@ -298,15 +309,14 @@ class CacheManager:
|
||||
filters={} # 删除所有记录
|
||||
)
|
||||
|
||||
# 清空ChromaDB
|
||||
if self.chroma_collection:
|
||||
try:
|
||||
self.chroma_client.delete_collection(name="semantic_cache")
|
||||
self.chroma_collection = self.chroma_client.get_or_create_collection(name="semantic_cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空ChromaDB失败: {e}")
|
||||
# 清空 VectorDB
|
||||
try:
|
||||
vector_db_service.delete_collection(name=self.semantic_cache_collection_name)
|
||||
vector_db_service.get_or_create_collection(name=self.semantic_cache_collection_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"清空 VectorDB 集合失败: {e}")
|
||||
|
||||
logger.info("L2 (数据库 & ChromaDB) 缓存已清空。")
|
||||
logger.info("L2 (数据库 & VectorDB) 缓存已清空。")
|
||||
|
||||
async def clear_all(self):
|
||||
"""清空所有缓存。"""
|
||||
|
||||
19
src/common/vector_db/__init__.py
Normal file
19
src/common/vector_db/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .base import VectorDBBase
|
||||
from .chromadb_impl import ChromaDBImpl
|
||||
|
||||
def get_vector_db_service(db_path: str = "data/chroma_db") -> VectorDBBase:
    """Build and return the vector database service instance.

    The backend is currently hard-wired to ChromaDB.

    Args:
        db_path: Directory handed to ``ChromaDBImpl`` as its storage path.
            Defaults to ``"data/chroma_db"``, the value that used to be
            hard-coded here. ``ChromaDBImpl`` is a singleton that only
            initialises its client on first construction, so a different
            path passed on a later call has no effect.

    Returns:
        The shared ``ChromaDBImpl`` instance, typed as ``VectorDBBase``.
    """
    # TODO: read the database type and path from the global configuration.
    # ChromaDBImpl is a singleton, so every call yields the same instance.
    return ChromaDBImpl(path=db_path)
|
||||
|
||||
# Global vector database service instance, created once when this package is imported.
|
||||
vector_db_service: VectorDBBase = get_vector_db_service()
|
||||
|
||||
__all__ = ["vector_db_service", "VectorDBBase"]
|
||||
117
src/common/vector_db/base.py
Normal file
117
src/common/vector_db/base.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
class VectorDBBase(ABC):
    """Abstract interface that every vector database backend must implement."""

    @abstractmethod
    def __init__(self, path: str, **kwargs: Any):
        """Initialise the vector database client.

        Args:
            path: Storage location of the database files.
            **kwargs: Additional implementation-specific options.
        """

    @abstractmethod
    def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
        """Fetch a collection, creating it if needed.

        Args:
            name: Name of the collection.
            **kwargs: Additional implementation-specific options
                (for example ``metadata``).

        Returns:
            An object representing the collection.
        """

    @abstractmethod
    def add(
        self,
        collection_name: str,
        embeddings: List[List[float]],
        documents: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """Add entries to a collection.

        Args:
            collection_name: Name of the target collection.
            embeddings: List of embedding vectors.
            documents: Optional list of documents. Defaults to None.
            metadatas: Optional list of metadata dicts. Defaults to None.
            ids: Optional list of entry IDs. Defaults to None.
        """

    @abstractmethod
    def query(
        self,
        collection_name: str,
        query_embeddings: List[List[float]],
        n_results: int = 1,
        where: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Dict[str, List[Any]]:
        """Search a collection for vectors similar to the query vectors.

        Args:
            collection_name: Name of the target collection.
            query_embeddings: Vectors to search with.
            n_results: Number of results to return. Defaults to 1.
            where: Optional metadata filter. Defaults to None.
            **kwargs: Additional implementation-specific options.

        Returns:
            The query result, usually containing ``ids``, ``distances``,
            ``metadatas`` and ``documents``.
        """

    @abstractmethod
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Remove entries from a collection.

        Args:
            collection_name: Name of the target collection.
            ids: Optional list of entry IDs to delete. Defaults to None.
            where: Optional metadata-based filter. Defaults to None.
        """

    @abstractmethod
    def count(self, collection_name: str) -> int:
        """Return the total number of entries in a collection.

        Args:
            collection_name: Name of the target collection.

        Returns:
            The number of entries.
        """

    @abstractmethod
    def delete_collection(self, name: str) -> None:
        """Delete a collection.

        Args:
            name: Name of the collection to delete.
        """
|
||||
137
src/common/vector_db/chromadb_impl.py
Normal file
137
src/common/vector_db/chromadb_impl.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .base import VectorDBBase
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chromadb_impl")
|
||||
|
||||
class ChromaDBImpl(VectorDBBase):
    """ChromaDB implementation of the ``VectorDBBase`` interface.

    The class is a singleton: ``__new__`` hands back the same object on every
    construction and ``__init__`` builds the ChromaDB client only the first
    time it runs, so the whole process shares one client.

    Apart from the ``ConnectionError`` raised when the client is missing,
    backend errors are logged and swallowed: the data operations return an
    empty/neutral value (``None``, ``{}`` or ``0``) instead of raising.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    cls._instance = super(ChromaDBImpl, cls).__new__(cls)
        return cls._instance

    def __init__(self, path: str = "data/chroma_db", **kwargs: Any):
        """Initialise the ChromaDB persistent client.

        Because the class is a singleton, the body only runs once; later
        constructions return immediately and their ``path`` is ignored.

        Args:
            path: Directory passed to ``chromadb.PersistentClient``.
            **kwargs: Accepted for interface compatibility; not used.
        """
        if hasattr(self, "_initialized"):
            return

        with self._lock:
            if hasattr(self, "_initialized"):
                return

            # Created before the client so the attribute exists even when
            # client construction fails below.
            self._collections: Dict[str, Any] = {}
            try:
                self.client = chromadb.PersistentClient(
                    path=path,
                    settings=Settings(anonymized_telemetry=False),
                )
                self._initialized = True
                logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
            except Exception as e:
                logger.error(f"ChromaDB 初始化失败: {e}")
                # Leave a marker so initialisation is not retried; methods
                # that need the client raise ConnectionError instead.
                self.client = None
                self._initialized = False

    def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
        """Return the named collection, creating it when it does not exist.

        Handles are cached per name; on a cache hit ``kwargs`` are ignored.

        Args:
            name: Collection name.
            **kwargs: Forwarded to the client's ``get_or_create_collection``
                on a cache miss.

        Returns:
            The collection handle, or ``None`` if the backend call failed.

        Raises:
            ConnectionError: If the ChromaDB client was never initialised.
        """
        if self.client is None:
            raise ConnectionError("ChromaDB 客户端未初始化")

        if name in self._collections:
            return self._collections[name]

        try:
            collection = self.client.get_or_create_collection(name=name, **kwargs)
        except Exception as e:
            logger.error(f"获取或创建集合 '{name}' 失败: {e}")
            return None

        self._collections[name] = collection
        logger.info(f"成功获取或创建集合: '{name}'")
        return collection

    def add(
        self,
        collection_name: str,
        embeddings: List[List[float]],
        documents: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """Add embeddings (with optional documents/metadata/ids) to a collection.

        Failures are logged, not raised.

        Raises:
            ConnectionError: If the ChromaDB client was never initialised.
        """
        collection = self.get_or_create_collection(collection_name)
        if collection is None:
            return

        try:
            collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=ids,
            )
        except Exception as e:
            logger.error(f"向集合 '{collection_name}' 添加数据失败: {e}")

    def query(
        self,
        collection_name: str,
        query_embeddings: List[List[float]],
        n_results: int = 1,
        where: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Dict[str, List[Any]]:
        """Query a collection for the nearest neighbours of the given vectors.

        Args:
            collection_name: Collection to search.
            query_embeddings: Query vectors.
            n_results: Number of results requested.
            where: Optional metadata filter; ``None`` or an empty dict means
                "no filter".
            **kwargs: Forwarded unchanged to the collection's ``query``.

        Returns:
            Whatever the collection's ``query`` returns, or ``{}`` when the
            collection is unavailable or the query fails.

        Raises:
            ConnectionError: If the ChromaDB client was never initialised.
        """
        collection = self.get_or_create_collection(collection_name)
        if collection is None:
            return {}

        query_args: Dict[str, Any] = dict(kwargs)
        query_args["query_embeddings"] = query_embeddings
        query_args["n_results"] = n_results
        # Only send a filter when there is one. An empty dict is not a valid
        # filter for some ChromaDB releases (it fails their ``where``
        # validation), whereas omitting the argument is always accepted.
        if where:
            query_args["where"] = where

        try:
            return collection.query(**query_args)
        except Exception as e:
            logger.error(f"查询集合 '{collection_name}' 失败: {e}")
            return {}

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete entries from a collection by id and/or metadata filter.

        Failures are logged, not raised.

        Raises:
            ConnectionError: If the ChromaDB client was never initialised.
        """
        collection = self.get_or_create_collection(collection_name)
        if collection is None:
            return

        try:
            collection.delete(ids=ids, where=where)
        except Exception as e:
            logger.error(f"从集合 '{collection_name}' 删除数据失败: {e}")

    def count(self, collection_name: str) -> int:
        """Return the number of entries in a collection.

        Returns:
            The collection's count, or ``0`` when the collection is
            unavailable or the count call fails.

        Raises:
            ConnectionError: If the ChromaDB client was never initialised.
        """
        collection = self.get_or_create_collection(collection_name)
        if collection is None:
            return 0

        try:
            return collection.count()
        except Exception as e:
            logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
            return 0

    def delete_collection(self, name: str) -> None:
        """Delete a collection and drop its cached handle.

        Failures of the backend call are logged, not raised.

        Args:
            name: Name of the collection to delete.

        Raises:
            ConnectionError: If the ChromaDB client was never initialised.
        """
        if self.client is None:
            raise ConnectionError("ChromaDB 客户端未初始化")

        try:
            self.client.delete_collection(name=name)
            logger.info(f"集合 '{name}' 已被删除")
        except Exception as e:
            logger.error(f"删除集合 '{name}' 失败: {e}")
        finally:
            # Evict the cached handle even if the delete failed (e.g. the
            # collection no longer exists): a stale handle would otherwise be
            # reused, while a fresh lookup is cheap and always correct.
            self._collections.pop(name, None)
|
||||
Reference in New Issue
Block a user