Files
Mofox-Core/src/memory_graph/utils/embeddings.py
明天好像没什么 26d22e5dd8 ruff
2025-11-19 23:36:06 +08:00

219 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Embedding vector generator: prefers the configured embedding API and skips
vector generation when the API is unavailable or fails.
"""
from __future__ import annotations

import numpy as np

from src.common.logger import get_logger

# Module-level logger used by every component in this file.
logger = get_logger(__name__)
class EmbeddingGenerator:
    """
    Embedding vector generator.

    Strategy:
    1. Prefer the configured embedding API (through ``LLMRequest``).
    2. If the API is unavailable or a call fails, skip vector generation and
       return ``None`` for the affected text.
    3. Local sentence-transformers models are no longer used, which avoids
       vector-dimension mismatches between backends.

    Benefits:
    - No local compute load.
    - No vector-dimension mismatch between backends.
    - Simpler error handling.
    - Consistent with the rest of the system.
    """

    def __init__(
        self,
        use_api: bool = True,
    ):
        """
        Initialize the embedding generator.

        Args:
            use_api: Whether to use the embedding API (default ``True``). When
                ``False``, no embeddings are produced and every generation call
                returns ``None``.
        """
        self.use_api = use_api
        # API state; populated lazily by _initialize_api() on first use.
        self._llm_request = None
        self._api_available = False
        # Embedding dimension: taken from the config if present, otherwise
        # learned from the first successful embedding.
        self._api_dimension = None

    async def _initialize_api(self):
        """
        Set up the embedding ``LLMRequest`` from the model configuration.

        On success, marks the API as available. If the config provides a
        non-empty ``embedding_dimension``, it is recorded as the vector
        dimension. Any failure is logged as a warning and leaves the API
        marked unavailable, so a later call will try to initialize again.
        """
        if self._api_available:
            return

        try:
            # Imported inside the try block so that a missing/broken config or
            # LLM layer is caught and logged here instead of failing the call.
            from src.config.config import model_config
            from src.llm_models.utils_model import LLMRequest

            embedding_config = model_config.model_task_config.embedding
            self._llm_request = LLMRequest(
                model_set=embedding_config,
                request_type="memory_graph.embedding",
            )

            # Record the configured dimension, if the config provides one.
            configured_dimension = getattr(embedding_config, "embedding_dimension", None)
            if configured_dimension:
                self._api_dimension = configured_dimension

            self._api_available = True
            logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
        except Exception as e:
            logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
            self._api_available = False

    async def generate(self, text: str) -> np.ndarray | None:
        """
        Generate the embedding vector for a single text.

        Strategy:
        1. Generate the vector through the API.
        2. If the API is disabled or fails, return ``None`` (skip the vector).

        Args:
            text: Input text. Empty or whitespace-only text yields ``None``.

        Returns:
            The embedding as a float32 ``np.ndarray``, or ``None`` when the text
            is empty, the API is disabled/unavailable, or generation fails.
        """
        if not text or not text.strip():
            logger.debug("输入文本为空,返回 None")
            return None

        try:
            if self.use_api:
                embedding = await self._generate_with_api(text)
                if embedding is not None:
                    return embedding

            # API disabled or failed: log and skip vector generation.
            logger.debug(f"⚠️ 嵌入生成失败,跳过: {text[:30]}...")
            return None
        except Exception as e:
            logger.error(f"❌ 嵌入生成异常: {e}", exc_info=True)
            return None

    async def _generate_with_api(self, text: str) -> np.ndarray | None:
        """
        Generate an embedding through the configured API.

        Initializes the API on demand. If no dimension was configured, the
        length of the first successful embedding is recorded as the dimension,
        so that ``get_embedding_dimension()`` can answer afterwards.

        Returns:
            A float32 vector, or ``None`` if the API is unavailable, returns an
            empty result, or raises (failures are logged at debug level).
        """
        try:
            if not self._api_available:
                await self._initialize_api()

            if not self._api_available or not self._llm_request:
                return None

            embedding_list, model_name = await self._llm_request.get_embedding(text)
            # Explicit None/length check: plain truthiness is ambiguous for
            # array-like results.
            if embedding_list is None or len(embedding_list) == 0:
                return None

            embedding = np.array(embedding_list, dtype=np.float32)
            if not self._api_dimension:
                # The config did not specify a dimension; learn it from the API.
                self._api_dimension = len(embedding)
            logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
            return embedding
        except Exception as e:
            logger.debug(f"API 嵌入生成失败: {e}")
            return None

    def _get_dimension(self) -> int:
        """
        Return the embedding dimension.

        Raises:
            ValueError: If the dimension is not known yet, i.e. it was not
                configured and no embedding has been generated successfully.
        """
        if self._api_dimension:
            return self._api_dimension
        raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API")

    async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]:
        """
        Generate embedding vectors for a batch of texts.

        The result is aligned with ``texts``: ``result[i]`` is the embedding of
        ``texts[i]``, or ``None`` if that text is empty/whitespace-only or its
        embedding could not be generated. If an unexpected error occurs, every
        entry is ``None``.

        Args:
            texts: Texts to embed.

        Returns:
            A list with the same length as ``texts``.
        """
        if not texts:
            return []

        try:
            # Remember where the non-empty texts sit so their results can be
            # written back at the original positions (empty texts stay None).
            valid_indices = [i for i, t in enumerate(texts) if t and t.strip()]
            if not valid_indices:
                logger.debug("所有文本为空,返回 None 列表")
                return [None for _ in texts]

            valid_texts = [texts[i] for i in valid_indices]

            valid_results: list[np.ndarray | None] | None = None
            if self.use_api:
                valid_results = await self._generate_batch_with_api(valid_texts)
            if not valid_results:
                # Batch API path unavailable or failed: generate one by one.
                valid_results = [await self.generate(t) for t in valid_texts]

            results: list[np.ndarray | None] = [None] * len(texts)
            for index, embedding in zip(valid_indices, valid_results):
                results[index] = embedding

            success_count = sum(1 for r in results if r is not None)
            logger.debug(f"✅ 批量生成嵌入: {success_count}/{len(texts)} 个成功")
            return results
        except Exception as e:
            logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
            return [None for _ in texts]

    async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None:
        """
        Generate embeddings for ``texts`` through the API, one call per text.

        A failure for an individual text yields ``None`` at that position
        without aborting the batch. Returns ``None`` only if something
        unexpected raises.
        """
        try:
            # Batching is just repeated single calls here; keep it sequential.
            return [await self._generate_with_api(text) for text in texts]
        except Exception as e:
            logger.debug(f"API 批量生成失败: {e}")
            return None

    def get_embedding_dimension(self) -> int:
        """
        Return the embedding vector dimension.

        Raises:
            ValueError: If the dimension cannot be determined yet.
        """
        return self._get_dimension()
# Module-level singleton, created lazily by get_embedding_generator().
_global_generator: EmbeddingGenerator | None = None
def get_embedding_generator(
    use_api: bool = True,
) -> EmbeddingGenerator:
    """
    Return the shared EmbeddingGenerator, creating it on first use.

    Args:
        use_api: Whether the generator should use the embedding API. Only the
            value passed on the very first call matters; later calls return
            the existing instance regardless of this argument.

    Returns:
        The module-wide EmbeddingGenerator instance.
    """
    global _global_generator
    if _global_generator is not None:
        # Already built: hand back the existing instance unchanged.
        return _global_generator
    _global_generator = EmbeddingGenerator(use_api=use_api)
    return _global_generator