fix: 修复代码质量问题 - 更正异常处理和导入语句
Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
This commit is contained in:
committed by
Windpicker-owo
parent
ea724eb5d4
commit
f8e58ef229
@@ -5,8 +5,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -18,12 +16,12 @@ logger = get_logger(__name__)
|
||||
class EmbeddingGenerator:
|
||||
"""
|
||||
嵌入向量生成器
|
||||
|
||||
|
||||
策略:
|
||||
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
||||
2. 如果 API 不可用,回退到本地 sentence-transformers
|
||||
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
|
||||
|
||||
|
||||
优点:
|
||||
- 降低本地运算负载
|
||||
- 即使未安装 sentence-transformers 也可正常运行
|
||||
@@ -37,19 +35,19 @@ class EmbeddingGenerator:
|
||||
):
|
||||
"""
|
||||
初始化嵌入生成器
|
||||
|
||||
|
||||
Args:
|
||||
use_api: 是否优先使用 API(默认 True)
|
||||
fallback_model_name: 回退本地模型名称
|
||||
"""
|
||||
self.use_api = use_api
|
||||
self.fallback_model_name = fallback_model_name
|
||||
|
||||
|
||||
# API 相关
|
||||
self._llm_request = None
|
||||
self._api_available = False
|
||||
self._api_dimension = None
|
||||
|
||||
|
||||
# 本地模型相关
|
||||
self._local_model = None
|
||||
self._local_model_loaded = False
|
||||
@@ -58,24 +56,24 @@ class EmbeddingGenerator:
|
||||
"""初始化 embedding API"""
|
||||
if self._api_available:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
embedding_config = model_config.model_task_config.embedding
|
||||
self._llm_request = LLMRequest(
|
||||
model_set=embedding_config,
|
||||
request_type="memory_graph.embedding"
|
||||
)
|
||||
|
||||
|
||||
# 获取嵌入维度
|
||||
if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
|
||||
self._api_dimension = embedding_config.embedding_dimension
|
||||
|
||||
|
||||
self._api_available = True
|
||||
logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
||||
self._api_available = False
|
||||
@@ -103,15 +101,15 @@ class EmbeddingGenerator:
|
||||
async def generate(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
生成单个文本的嵌入向量
|
||||
|
||||
|
||||
策略:
|
||||
1. 优先使用 API
|
||||
2. API 失败则使用本地模型
|
||||
3. 本地模型不可用则使用随机向量
|
||||
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
|
||||
Returns:
|
||||
嵌入向量
|
||||
"""
|
||||
@@ -126,12 +124,12 @@ class EmbeddingGenerator:
|
||||
embedding = await self._generate_with_api(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
|
||||
# 策略 2: 使用本地模型
|
||||
embedding = await self._generate_with_local_model(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
|
||||
# 策略 3: 随机向量(仅测试)
|
||||
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
|
||||
dim = self._get_dimension()
|
||||
@@ -142,47 +140,47 @@ class EmbeddingGenerator:
|
||||
dim = self._get_dimension()
|
||||
return np.random.rand(dim).astype(np.float32)
|
||||
|
||||
async def _generate_with_api(self, text: str) -> np.ndarray | None:
    """Try to embed ``text`` through the configured embedding API.

    The API client is initialized lazily on first use. Any exception raised
    along the way is logged at debug level and converted into ``None`` so
    the caller can fall through to the next strategy.

    Args:
        text: Input text to embed.

    Returns:
        The embedding as a ``float32`` array, or ``None`` when the API is
        unavailable, returns an empty result, or raises.
    """
    try:
        # Lazy initialization: only attempted while the API is not marked available.
        if not self._api_available:
            await self._initialize_api()

        api_ready = self._api_available and self._llm_request
        if not api_ready:
            return None

        # get_embedding yields a (vector, model name) pair.
        embedding_list, model_name = await self._llm_request.get_embedding(text)
        if not embedding_list or len(embedding_list) == 0:
            return None

        vector = np.array(embedding_list, dtype=np.float32)
        logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(vector)}维 (模型: {model_name})")
        return vector

    except Exception as exc:
        logger.debug(f"API 嵌入生成失败: {exc}")
        return None
|
||||
|
||||
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
    """Try to embed ``text`` with the local fallback model.

    The model is loaded on first use. Encoding is dispatched to the event
    loop's default executor so the blocking call does not stall the loop.
    Any exception is logged at debug level and converted into ``None``.

    Args:
        text: Input text to embed.

    Returns:
        The value produced by ``_encode_single_local`` (presumably a numpy
        vector; confirm against that helper), or ``None`` when the model
        is unavailable or encoding fails.
    """
    try:
        # Load the local model on first use.
        if not self._local_model_loaded:
            self._load_local_model()

        if not self._local_model_loaded or not self._local_model:
            return None

        # Inside a coroutine the running loop is always available;
        # get_running_loop() is the preferred accessor over get_event_loop().
        loop = asyncio.get_running_loop()
        embedding = await loop.run_in_executor(None, self._encode_single_local, text)

        logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
        return embedding

    except Exception as e:
        logger.debug(f"本地模型嵌入生成失败: {e}")
        return None
|
||||
@@ -199,24 +197,24 @@ class EmbeddingGenerator:
|
||||
# 优先使用 API 维度
|
||||
if self._api_dimension:
|
||||
return self._api_dimension
|
||||
|
||||
|
||||
# 其次使用本地模型维度
|
||||
if self._local_model_loaded and self._local_model:
|
||||
try:
|
||||
return self._local_model.get_sentence_embedding_dimension()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 默认 384(sentence-transformers 常用维度)
|
||||
return 384
|
||||
|
||||
async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
|
||||
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||
"""
|
||||
批量生成嵌入向量
|
||||
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
"""
|
||||
@@ -236,13 +234,13 @@ class EmbeddingGenerator:
|
||||
results = await self._generate_batch_with_api(valid_texts)
|
||||
if results:
|
||||
return results
|
||||
|
||||
|
||||
# 回退到逐个生成
|
||||
results = []
|
||||
for text in valid_texts:
|
||||
embedding = await self.generate(text)
|
||||
results.append(embedding)
|
||||
|
||||
|
||||
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
|
||||
return results
|
||||
|
||||
@@ -251,7 +249,7 @@ class EmbeddingGenerator:
|
||||
dim = self._get_dimension()
|
||||
return [np.random.rand(dim).astype(np.float32) for _ in texts]
|
||||
|
||||
async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
|
||||
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
|
||||
"""使用 API 批量生成"""
|
||||
try:
|
||||
# 对于大多数 API,批量调用就是多次单独调用
|
||||
@@ -273,7 +271,7 @@ class EmbeddingGenerator:
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_generator: Optional[EmbeddingGenerator] = None
|
||||
_global_generator: EmbeddingGenerator | None = None
|
||||
|
||||
|
||||
def get_embedding_generator(
|
||||
@@ -282,11 +280,11 @@ def get_embedding_generator(
|
||||
) -> EmbeddingGenerator:
|
||||
"""
|
||||
获取全局嵌入生成器单例
|
||||
|
||||
|
||||
Args:
|
||||
use_api: 是否优先使用 API
|
||||
fallback_model_name: 回退本地模型名称
|
||||
|
||||
|
||||
Returns:
|
||||
EmbeddingGenerator 实例
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user