refactor: 优化嵌入生成逻辑,失败时返回 None,简化错误处理;更新调度器任务管理逻辑
This commit is contained in:
@@ -417,8 +417,7 @@ class SchedulerDispatcher:
|
|||||||
stream_id: 流ID
|
stream_id: 流ID
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 从追踪中移除(因为是一次性任务)
|
old_schedule_id = self.stream_schedules.get(stream_id)
|
||||||
old_schedule_id = self.stream_schedules.pop(stream_id, None)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"⏰ Schedule 触发: 流={stream_id[:8]}..., "
|
f"⏰ Schedule 触发: 流={stream_id[:8]}..., "
|
||||||
@@ -445,14 +444,8 @@ class SchedulerDispatcher:
|
|||||||
if not success:
|
if not success:
|
||||||
self.stats["total_failures"] += 1
|
self.stats["total_failures"] += 1
|
||||||
|
|
||||||
# 处理完成后,检查是否需要创建新的 schedule
|
self.stream_schedules.pop(stream_id, None)
|
||||||
if stream_id in self.stream_schedules:
|
|
||||||
logger.info(
|
|
||||||
f"⚠️ 处理完成时发现已有新 schedule: 流={stream_id[:8]}..., "
|
|
||||||
f"可能是打断创建的,跳过创建新 schedule"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 检查缓存中是否有待处理的消息
|
# 检查缓存中是否有待处理的消息
|
||||||
from src.chat.message_manager.message_manager import message_manager
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
return nodes, edges
|
return nodes, edges
|
||||||
|
|
||||||
async def _generate_embedding(self, text: str) -> np.ndarray:
|
async def _generate_embedding(self, text: str) -> np.ndarray | None:
|
||||||
"""
|
"""
|
||||||
生成文本的嵌入向量
|
生成文本的嵌入向量
|
||||||
|
|
||||||
@@ -326,17 +326,17 @@ class MemoryBuilder:
|
|||||||
text: 文本内容
|
text: 文本内容
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
嵌入向量
|
嵌入向量,失败时返回 None
|
||||||
"""
|
"""
|
||||||
if self.embedding_generator:
|
if self.embedding_generator:
|
||||||
try:
|
try:
|
||||||
embedding = await self.embedding_generator.generate(text)
|
embedding = await self.embedding_generator.generate(text)
|
||||||
return embedding
|
return embedding
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"嵌入生成失败,使用随机向量: {e}")
|
logger.warning(f"嵌入生成失败,跳过: {e}")
|
||||||
|
|
||||||
# 回退:生成随机向量(仅用于测试)
|
# 嵌入生成失败,返回 None
|
||||||
return np.random.rand(384).astype(np.float32)
|
return None
|
||||||
|
|
||||||
async def _find_existing_node(
|
async def _find_existing_node(
|
||||||
self, content: str, node_type: NodeType
|
self, content: str, node_type: NodeType
|
||||||
@@ -367,7 +367,7 @@ class MemoryBuilder:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _find_similar_topic(
|
async def _find_similar_topic(
|
||||||
self, content: str, embedding: np.ndarray
|
self, content: str, embedding: np.ndarray | None
|
||||||
) -> MemoryNode | None:
|
) -> MemoryNode | None:
|
||||||
"""
|
"""
|
||||||
查找相似的主题节点(基于语义相似度)
|
查找相似的主题节点(基于语义相似度)
|
||||||
@@ -379,6 +379,11 @@ class MemoryBuilder:
|
|||||||
Returns:
|
Returns:
|
||||||
相似节点,如果没有则返回 None
|
相似节点,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
|
# 如果嵌入为空,无法进行相似性搜索
|
||||||
|
if embedding is None:
|
||||||
|
logger.debug("嵌入向量为空,跳过相似节点搜索")
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 搜索相似节点(阈值 0.95)
|
# 搜索相似节点(阈值 0.95)
|
||||||
similar_nodes = await self.vector_store.search_similar_nodes(
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
@@ -412,7 +417,7 @@ class MemoryBuilder:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _find_similar_object(
|
async def _find_similar_object(
|
||||||
self, content: str, embedding: np.ndarray
|
self, content: str, embedding: np.ndarray | None
|
||||||
) -> MemoryNode | None:
|
) -> MemoryNode | None:
|
||||||
"""
|
"""
|
||||||
查找相似的客体节点(基于语义相似度)
|
查找相似的客体节点(基于语义相似度)
|
||||||
@@ -424,6 +429,11 @@ class MemoryBuilder:
|
|||||||
Returns:
|
Returns:
|
||||||
相似节点,如果没有则返回 None
|
相似节点,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
|
# 如果嵌入为空,无法进行相似性搜索
|
||||||
|
if embedding is None:
|
||||||
|
logger.debug("嵌入向量为空,跳过相似节点搜索")
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 搜索相似节点(阈值 0.95)
|
# 搜索相似节点(阈值 0.95)
|
||||||
similar_nodes = await self.vector_store.search_similar_nodes(
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
|
|||||||
@@ -506,16 +506,18 @@ class MemoryTools:
|
|||||||
try:
|
try:
|
||||||
query_embedding = await self.builder.embedding_generator.generate(query)
|
query_embedding = await self.builder.embedding_generator.generate(query)
|
||||||
|
|
||||||
# 使用共享的图扩展工具函数
|
# 只有在嵌入生成成功时才进行语义扩展
|
||||||
expanded_results = await expand_memories_with_semantic_filter(
|
if query_embedding is not None:
|
||||||
graph_store=self.graph_store,
|
# 使用共享的图扩展工具函数
|
||||||
vector_store=self.vector_store,
|
expanded_results = await expand_memories_with_semantic_filter(
|
||||||
initial_memory_ids=list(initial_memory_ids),
|
graph_store=self.graph_store,
|
||||||
query_embedding=query_embedding,
|
vector_store=self.vector_store,
|
||||||
max_depth=expand_depth,
|
initial_memory_ids=list(initial_memory_ids),
|
||||||
semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值
|
query_embedding=query_embedding,
|
||||||
max_expanded=top_k * 2
|
max_depth=expand_depth,
|
||||||
)
|
semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值
|
||||||
|
max_expanded=top_k * 2
|
||||||
|
)
|
||||||
|
|
||||||
# 合并扩展结果
|
# 合并扩展结果
|
||||||
expanded_memory_scores.update(dict(expanded_results))
|
expanded_memory_scores.update(dict(expanded_results))
|
||||||
@@ -714,12 +716,14 @@ class MemoryTools:
|
|||||||
相似节点列表 [(node_id, similarity, metadata), ...]
|
相似节点列表 [(node_id, similarity, metadata), ...]
|
||||||
"""
|
"""
|
||||||
# 生成查询嵌入
|
# 生成查询嵌入
|
||||||
|
query_embedding = None
|
||||||
if self.builder.embedding_generator:
|
if self.builder.embedding_generator:
|
||||||
query_embedding = await self.builder.embedding_generator.generate(query)
|
query_embedding = await self.builder.embedding_generator.generate(query)
|
||||||
else:
|
|
||||||
logger.warning("未配置嵌入生成器,使用随机向量")
|
# 如果嵌入生成失败,无法进行向量搜索
|
||||||
import numpy as np
|
if query_embedding is None:
|
||||||
query_embedding = np.random.rand(384).astype(np.float32)
|
logger.warning("嵌入生成失败,跳过节点搜索")
|
||||||
|
return []
|
||||||
|
|
||||||
# 向量搜索
|
# 向量搜索
|
||||||
similar_nodes = await self.vector_store.search_similar_nodes(
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
@@ -766,8 +770,14 @@ class MemoryTools:
|
|||||||
|
|
||||||
for sub_query, weight in multi_queries:
|
for sub_query, weight in multi_queries:
|
||||||
embedding = await self.builder.embedding_generator.generate(sub_query)
|
embedding = await self.builder.embedding_generator.generate(sub_query)
|
||||||
query_embeddings.append(embedding)
|
if embedding is not None:
|
||||||
query_weights.append(weight)
|
query_embeddings.append(embedding)
|
||||||
|
query_weights.append(weight)
|
||||||
|
|
||||||
|
# 如果所有嵌入都生成失败,回退到单查询模式
|
||||||
|
if not query_embeddings:
|
||||||
|
logger.warning("所有查询嵌入生成失败,回退到单查询模式")
|
||||||
|
return await self._single_query_search(query, top_k)
|
||||||
|
|
||||||
# 3. 多查询融合搜索
|
# 3. 多查询融合搜索
|
||||||
similar_nodes = await self.vector_store.search_with_multiple_queries(
|
similar_nodes = await self.vector_store.search_with_multiple_queries(
|
||||||
@@ -806,11 +816,14 @@ class MemoryTools:
|
|||||||
找到的记忆,如果没有则返回 None
|
找到的记忆,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
# 使用语义搜索查找最相关的记忆
|
# 使用语义搜索查找最相关的记忆
|
||||||
|
query_embedding = None
|
||||||
if self.builder.embedding_generator:
|
if self.builder.embedding_generator:
|
||||||
query_embedding = await self.builder.embedding_generator.generate(description)
|
query_embedding = await self.builder.embedding_generator.generate(description)
|
||||||
else:
|
|
||||||
import numpy as np
|
# 如果嵌入生成失败,无法进行语义搜索
|
||||||
query_embedding = np.random.rand(384).astype(np.float32)
|
if query_embedding is None:
|
||||||
|
logger.debug("嵌入生成失败,跳过描述搜索")
|
||||||
|
return None
|
||||||
|
|
||||||
# 搜索相似节点
|
# 搜索相似节点
|
||||||
similar_nodes = await self.vector_store.search_similar_nodes(
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
嵌入向量生成器:优先使用配置的 embedding API,sentence-transformers 作为备选
|
嵌入向量生成器:优先使用配置的 embedding API,失败时跳过向量生成
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -19,39 +19,33 @@ class EmbeddingGenerator:
|
|||||||
|
|
||||||
策略:
|
策略:
|
||||||
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
||||||
2. 如果 API 不可用,回退到本地 sentence-transformers
|
2. 如果 API 不可用或失败,跳过向量生成,返回 None
|
||||||
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
|
3. 不再使用本地 sentence-transformers 模型,避免向量维度不匹配
|
||||||
|
|
||||||
优点:
|
优点:
|
||||||
- 降低本地运算负载
|
- 完全避免本地运算负载
|
||||||
- 即使未安装 sentence-transformers 也可正常运行
|
- 避免向量维度不匹配问题
|
||||||
|
- 简化错误处理逻辑
|
||||||
- 保持与现有系统的一致性
|
- 保持与现有系统的一致性
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_api: bool = True,
|
use_api: bool = True,
|
||||||
fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化嵌入生成器
|
初始化嵌入生成器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_api: 是否优先使用 API(默认 True)
|
use_api: 是否使用 API(默认 True)
|
||||||
fallback_model_name: 回退本地模型名称
|
|
||||||
"""
|
"""
|
||||||
self.use_api = use_api
|
self.use_api = use_api
|
||||||
self.fallback_model_name = fallback_model_name
|
|
||||||
|
|
||||||
# API 相关
|
# API 相关
|
||||||
self._llm_request = None
|
self._llm_request = None
|
||||||
self._api_available = False
|
self._api_available = False
|
||||||
self._api_dimension = None
|
self._api_dimension = None
|
||||||
|
|
||||||
# 本地模型相关
|
|
||||||
self._local_model = None
|
|
||||||
self._local_model_loaded = False
|
|
||||||
|
|
||||||
async def _initialize_api(self):
|
async def _initialize_api(self):
|
||||||
"""初始化 embedding API"""
|
"""初始化 embedding API"""
|
||||||
if self._api_available:
|
if self._api_available:
|
||||||
@@ -78,67 +72,39 @@ class EmbeddingGenerator:
|
|||||||
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
||||||
self._api_available = False
|
self._api_available = False
|
||||||
|
|
||||||
def _load_local_model(self):
|
|
||||||
"""延迟加载本地模型"""
|
async def generate(self, text: str) -> np.ndarray | None:
|
||||||
if not self._local_model_loaded:
|
|
||||||
try:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
logger.info(f"📦 加载本地嵌入模型: {self.fallback_model_name}")
|
|
||||||
self._local_model = SentenceTransformer(self.fallback_model_name)
|
|
||||||
self._local_model_loaded = True
|
|
||||||
logger.info("✅ 本地嵌入模型加载成功")
|
|
||||||
except ImportError:
|
|
||||||
logger.warning(
|
|
||||||
"⚠️ sentence-transformers 未安装,将使用随机向量(仅测试用)\n"
|
|
||||||
" 安装方法: pip install sentence-transformers"
|
|
||||||
)
|
|
||||||
self._local_model_loaded = False
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 本地模型加载失败: {e}")
|
|
||||||
self._local_model_loaded = False
|
|
||||||
|
|
||||||
async def generate(self, text: str) -> np.ndarray:
|
|
||||||
"""
|
"""
|
||||||
生成单个文本的嵌入向量
|
生成单个文本的嵌入向量
|
||||||
|
|
||||||
策略:
|
策略:
|
||||||
1. 优先使用 API
|
1. 使用 API 生成向量
|
||||||
2. API 失败则使用本地模型
|
2. API 失败则返回 None,跳过向量生成
|
||||||
3. 本地模型不可用则使用随机向量
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
嵌入向量
|
嵌入向量,失败时返回 None
|
||||||
"""
|
"""
|
||||||
if not text or not text.strip():
|
if not text or not text.strip():
|
||||||
logger.warning("输入文本为空,返回零向量")
|
logger.debug("输入文本为空,返回 None")
|
||||||
dim = self._get_dimension()
|
return None
|
||||||
return np.zeros(dim, dtype=np.float32)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 策略 1: 使用 API
|
# 使用 API 生成嵌入
|
||||||
if self.use_api:
|
if self.use_api:
|
||||||
embedding = await self._generate_with_api(text)
|
embedding = await self._generate_with_api(text)
|
||||||
if embedding is not None:
|
if embedding is not None:
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
# 策略 2: 使用本地模型
|
# API 失败,记录日志并返回 None
|
||||||
embedding = await self._generate_with_local_model(text)
|
logger.debug(f"⚠️ 嵌入生成失败,跳过: {text[:30]}...")
|
||||||
if embedding is not None:
|
return None
|
||||||
return embedding
|
|
||||||
|
|
||||||
# 策略 3: 随机向量(仅测试)
|
|
||||||
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
|
|
||||||
dim = self._get_dimension()
|
|
||||||
return np.random.rand(dim).astype(np.float32)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ 嵌入生成失败: {e}", exc_info=True)
|
logger.error(f"❌ 嵌入生成异常: {e}", exc_info=True)
|
||||||
dim = self._get_dimension()
|
return None
|
||||||
return np.random.rand(dim).astype(np.float32)
|
|
||||||
|
|
||||||
async def _generate_with_api(self, text: str) -> np.ndarray | None:
|
async def _generate_with_api(self, text: str) -> np.ndarray | None:
|
||||||
"""使用 API 生成嵌入"""
|
"""使用 API 生成嵌入"""
|
||||||
@@ -164,51 +130,16 @@ class EmbeddingGenerator:
|
|||||||
logger.debug(f"API 嵌入生成失败: {e}")
|
logger.debug(f"API 嵌入生成失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
|
|
||||||
"""使用本地模型生成嵌入"""
|
|
||||||
try:
|
|
||||||
# 加载本地模型
|
|
||||||
if not self._local_model_loaded:
|
|
||||||
self._load_local_model()
|
|
||||||
|
|
||||||
if not self._local_model_loaded or not self._local_model:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 在线程池中运行
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
embedding = await loop.run_in_executor(None, self._encode_single_local, text)
|
|
||||||
|
|
||||||
logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"本地模型嵌入生成失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _encode_single_local(self, text: str) -> np.ndarray:
|
|
||||||
"""同步编码单个文本(本地模型)"""
|
|
||||||
if self._local_model is None:
|
|
||||||
raise RuntimeError("本地模型未加载")
|
|
||||||
embedding = self._local_model.encode(text, convert_to_numpy=True) # type: ignore
|
|
||||||
return embedding.astype(np.float32)
|
|
||||||
|
|
||||||
def _get_dimension(self) -> int:
|
def _get_dimension(self) -> int:
|
||||||
"""获取嵌入维度"""
|
"""获取嵌入维度"""
|
||||||
# 优先使用 API 维度
|
# 优先使用 API 维度
|
||||||
if self._api_dimension:
|
if self._api_dimension:
|
||||||
return self._api_dimension
|
return self._api_dimension
|
||||||
|
|
||||||
# 其次使用本地模型维度
|
raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API")
|
||||||
if self._local_model_loaded and self._local_model:
|
|
||||||
try:
|
|
||||||
return self._local_model.get_sentence_embedding_dimension()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 默认 384(sentence-transformers 常用维度)
|
async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]:
|
||||||
return 384
|
|
||||||
|
|
||||||
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
|
|
||||||
"""
|
"""
|
||||||
批量生成嵌入向量
|
批量生成嵌入向量
|
||||||
|
|
||||||
@@ -216,7 +147,7 @@ class EmbeddingGenerator:
|
|||||||
texts: 文本列表
|
texts: 文本列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
嵌入向量列表
|
嵌入向量列表,失败的项目为 None
|
||||||
"""
|
"""
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
@@ -225,9 +156,8 @@ class EmbeddingGenerator:
|
|||||||
# 过滤空文本
|
# 过滤空文本
|
||||||
valid_texts = [t for t in texts if t and t.strip()]
|
valid_texts = [t for t in texts if t and t.strip()]
|
||||||
if not valid_texts:
|
if not valid_texts:
|
||||||
logger.warning("所有文本为空,返回零向量列表")
|
logger.debug("所有文本为空,返回 None 列表")
|
||||||
dim = self._get_dimension()
|
return [None for _ in texts]
|
||||||
return [np.zeros(dim, dtype=np.float32) for _ in texts]
|
|
||||||
|
|
||||||
# 使用 API 批量生成(如果可用)
|
# 使用 API 批量生成(如果可用)
|
||||||
if self.use_api:
|
if self.use_api:
|
||||||
@@ -241,15 +171,15 @@ class EmbeddingGenerator:
|
|||||||
embedding = await self.generate(text)
|
embedding = await self.generate(text)
|
||||||
results.append(embedding)
|
results.append(embedding)
|
||||||
|
|
||||||
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
|
success_count = sum(1 for r in results if r is not None)
|
||||||
|
logger.debug(f"✅ 批量生成嵌入: {success_count}/{len(texts)} 个成功")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
|
logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
|
||||||
dim = self._get_dimension()
|
return [None for _ in texts]
|
||||||
return [np.random.rand(dim).astype(np.float32) for _ in texts]
|
|
||||||
|
|
||||||
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
|
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None:
|
||||||
"""使用 API 批量生成"""
|
"""使用 API 批量生成"""
|
||||||
try:
|
try:
|
||||||
# 对于大多数 API,批量调用就是多次单独调用
|
# 对于大多数 API,批量调用就是多次单独调用
|
||||||
@@ -257,9 +187,7 @@ class EmbeddingGenerator:
|
|||||||
results = []
|
results = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
embedding = await self._generate_with_api(text)
|
embedding = await self._generate_with_api(text)
|
||||||
if embedding is None:
|
results.append(embedding) # 失败的项目为 None,不中断整个批量处理
|
||||||
return None # 如果任何一个失败,返回 None 触发回退
|
|
||||||
results.append(embedding)
|
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"API 批量生成失败: {e}")
|
logger.debug(f"API 批量生成失败: {e}")
|
||||||
@@ -276,22 +204,17 @@ _global_generator: EmbeddingGenerator | None = None
|
|||||||
|
|
||||||
def get_embedding_generator(
|
def get_embedding_generator(
|
||||||
use_api: bool = True,
|
use_api: bool = True,
|
||||||
fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
|
|
||||||
) -> EmbeddingGenerator:
|
) -> EmbeddingGenerator:
|
||||||
"""
|
"""
|
||||||
获取全局嵌入生成器单例
|
获取全局嵌入生成器单例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_api: 是否优先使用 API
|
use_api: 是否使用 API
|
||||||
fallback_model_name: 回退本地模型名称
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EmbeddingGenerator 实例
|
EmbeddingGenerator 实例
|
||||||
"""
|
"""
|
||||||
global _global_generator
|
global _global_generator
|
||||||
if _global_generator is None:
|
if _global_generator is None:
|
||||||
_global_generator = EmbeddingGenerator(
|
_global_generator = EmbeddingGenerator(use_api=use_api)
|
||||||
use_api=use_api,
|
|
||||||
fallback_model_name=fallback_model_name
|
|
||||||
)
|
|
||||||
return _global_generator
|
return _global_generator
|
||||||
|
|||||||
Reference in New Issue
Block a user