# File listing metadata (code-viewer residue): 219 lines, 6.5 KiB, Python.
"""
|
||
嵌入向量生成器:优先使用配置的 embedding API,失败时跳过向量生成
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import numpy as np
|
||
|
||
from src.common.logger import get_logger
|
||
|
||
# Module-level logger, obtained from the project's get_logger helper and keyed by this module's name.
logger = get_logger(__name__)
|
||
|
||
|
||
class EmbeddingGenerator:
    """Generate text embedding vectors through the configured embedding API.

    Strategy:
        1. Use the configured embedding API (via ``LLMRequest``).
        2. If the API is unavailable or a call fails, skip vector generation
           and return ``None`` for that text.
        3. No local sentence-transformers model is used, which avoids
           embedding-dimension mismatches.

    Benefits:
        - No local compute load.
        - No embedding-dimension mismatch problems.
        - Simpler error handling.
        - Consistent with the rest of the system.
    """

    def __init__(
        self,
        use_api: bool = True,
    ):
        """Initialize the generator.

        Args:
            use_api: Whether to use the embedding API (default ``True``).
                When ``False`` every generation request yields ``None``.
        """
        self.use_api = use_api

        # API state; populated lazily by _initialize_api().
        self._llm_request = None
        self._api_available = False
        self._api_dimension = None

    async def _initialize_api(self):
        """Create the embedding ``LLMRequest`` from the project model config.

        Does nothing when the API is already marked available. Any exception
        raised while importing or building the request is logged as a warning
        and leaves the API marked unavailable, so a later call tries again.
        """
        if self._api_available:
            return

        try:
            # Imported inside the method so that a failing import is handled
            # by the except clause below instead of breaking module import.
            from src.config.config import model_config
            from src.llm_models.utils_model import LLMRequest

            embedding_config = model_config.model_task_config.embedding
            self._llm_request = LLMRequest(
                model_set=embedding_config,
                request_type="memory_graph.embedding",
            )

            # Pick up the embedding dimension when the config declares one.
            configured_dimension = getattr(embedding_config, "embedding_dimension", None)
            if configured_dimension:
                self._api_dimension = configured_dimension

            self._api_available = True
            logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")

        except Exception as e:
            logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
            self._api_available = False

    async def generate(self, text: str) -> np.ndarray | None:
        """Generate the embedding vector for a single text.

        Args:
            text: Input text.

        Returns:
            The embedding as a NumPy array, or ``None`` when the text is
            empty/blank, the API is disabled, or generation fails.
        """
        if not text or not text.strip():
            logger.debug("输入文本为空,返回 None")
            return None

        try:
            if self.use_api:
                embedding = await self._generate_with_api(text)
                if embedding is not None:
                    return embedding

            # API disabled or failed: log and skip this text.
            logger.debug(f"⚠️ 嵌入生成失败,跳过: {text[:30]}...")
            return None

        except Exception as e:
            logger.error(f"❌ 嵌入生成异常: {e}", exc_info=True)
            return None

    async def _generate_with_api(self, text: str) -> np.ndarray | None:
        """Request one embedding from the API.

        Initializes the API on first use. Returns ``None`` when the API is
        unavailable, returns an empty result, or raises.
        """
        try:
            if not self._api_available:
                await self._initialize_api()

            if not self._api_available or not self._llm_request:
                return None

            embedding_list, model_name = await self._llm_request.get_embedding(text)

            if embedding_list:
                embedding = np.array(embedding_list, dtype=np.float32)
                logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
                return embedding

            return None

        except Exception as e:
            # Best effort: a failed call is reported to the caller as None.
            logger.debug(f"API 嵌入生成失败: {e}")
            return None

    def _get_dimension(self) -> int:
        """Return the embedding dimension taken from the API configuration.

        Raises:
            ValueError: If no dimension is known.
        """
        if self._api_dimension:
            return self._api_dimension

        raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API")

    async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]:
        """Generate embedding vectors for a list of texts.

        The returned list always has the same length as ``texts`` and is
        aligned with it by position: entry ``i`` is the embedding for
        ``texts[i]``, or ``None`` when that text is empty/blank or its
        embedding could not be generated.

        Args:
            texts: List of input texts.

        Returns:
            List of embeddings, with ``None`` for skipped or failed items.
        """
        if not texts:
            return []

        try:
            # Remember where each non-blank text came from so the results can
            # be placed back at their original positions.
            valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
            if not valid_indices:
                logger.debug("所有文本为空,返回 None 列表")
                return [None for _ in texts]

            valid_texts = [texts[i] for i in valid_indices]

            embeddings = None
            if self.use_api:
                embeddings = await self._generate_batch_with_api(valid_texts)

            if not embeddings:
                # Fall back to generating one by one.
                embeddings = []
                for text in valid_texts:
                    embeddings.append(await self.generate(text))

                success_count = sum(1 for r in embeddings if r is not None)
                logger.debug(f"✅ 批量生成嵌入: {success_count}/{len(texts)} 个成功")

            # Scatter back: blank inputs keep their None placeholder.
            results: list[np.ndarray | None] = [None] * len(texts)
            for index, embedding in zip(valid_indices, embeddings):
                results[index] = embedding
            return results

        except Exception as e:
            logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
            return [None for _ in texts]

    async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None:
        """Generate embeddings for ``texts`` by calling the API once per text.

        Returns:
            One entry per text (``None`` for a failed item, which does not
            abort the batch), or ``None`` if the loop itself raises.
        """
        try:
            results = []
            for text in texts:
                embedding = await self._generate_with_api(text)
                results.append(embedding)
            return results
        except Exception as e:
            logger.debug(f"API 批量生成失败: {e}")
            return None

    def get_embedding_dimension(self) -> int:
        """Return the embedding vector dimension.

        Raises:
            ValueError: If no dimension is known.
        """
        return self._get_dimension()
|
||
|
||
|
||
# Global singleton
_global_generator: EmbeddingGenerator | None = None


def get_embedding_generator(
    use_api: bool = True,
) -> EmbeddingGenerator:
    """Return the global ``EmbeddingGenerator`` singleton.

    The instance is created on the first call; later calls return that same
    instance, so ``use_api`` only takes effect on the call that creates it.

    Args:
        use_api: Whether the generator should use the API.

    Returns:
        The shared ``EmbeddingGenerator`` instance.
    """
    global _global_generator
    if _global_generator is not None:
        return _global_generator

    _global_generator = EmbeddingGenerator(use_api=use_api)
    return _global_generator
|