feat(embedding): improve concurrency; optimize embedding generation and index rebuilding
@@ -30,12 +30,12 @@ from .utils.hash import get_sha256

 install(extra_lines=3)

 # Configuration constants for multi-threaded embedding
-DEFAULT_MAX_WORKERS = 1  # Default maximum number of threads
-DEFAULT_CHUNK_SIZE = 5  # Default chunk size handled by each thread
+DEFAULT_MAX_WORKERS = 10  # Default maximum number of concurrent batches (higher concurrency)
+DEFAULT_CHUNK_SIZE = 20  # Default chunk size per batch (batched requests)
 MIN_CHUNK_SIZE = 1  # Minimum chunk size
-MAX_CHUNK_SIZE = 50  # Maximum chunk size
+MAX_CHUNK_SIZE = 100  # Maximum chunk size (higher batching capacity)
 MIN_WORKERS = 1  # Minimum number of workers
-MAX_WORKERS = 20  # Maximum number of workers
+MAX_WORKERS = 50  # Maximum number of workers (higher concurrency ceiling)

 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
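For intuition on how the new defaults interact: the input is split into ceil(len(strs) / chunk_size) batches, and the semaphore keeps at most max_workers of them in flight at once. A minimal arithmetic sketch (the input size is hypothetical; the clamping mirrors the bounds above):

import math

chunk_size = max(1, min(20, 100))   # DEFAULT_CHUNK_SIZE clamped to [MIN_CHUNK_SIZE, MAX_CHUNK_SIZE]
max_workers = max(1, min(10, 50))   # DEFAULT_MAX_WORKERS clamped to [MIN_WORKERS, MAX_WORKERS]

num_strings = 137                                    # hypothetical input size
num_batches = math.ceil(num_strings / chunk_size)    # 7 batches of up to 20 strings each
in_flight = min(max_workers, num_batches)            # at most 7 batches run concurrently here
print(num_batches, in_flight)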
@@ -145,7 +145,12 @@ class EmbeddingStore:
     ) -> list[tuple[str, list[float]]]:
         """
         Asynchronously and concurrently fetch embeddings in batches.
-        Uses asyncio.Semaphore to control concurrency, keeping all operations on the same event loop.
+        Uses chunk_size for batched requests; max_workers controls the number of concurrent batches.
+
+        Optimization strategy:
+        1. Split the strings into chunks, each containing chunk_size strings
+        2. Use asyncio.Semaphore to limit how many chunks are processed at the same time
+        3. Send the strings in each chunk to the LLM in one go (leveraging the batch API)
         """
         if not strs:
             return []
@@ -153,18 +158,36 @@ class EmbeddingStore:
         from src.config.config import model_config
         from src.llm_models.utils_model import LLMRequest

+        # Clamp chunk_size and max_workers to a reasonable range
+        chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
+        max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
+
         semaphore = asyncio.Semaphore(max_workers)
         llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
         results = {}

-        async def _get_embedding_with_semaphore(s: str):
+        # Split the string list into chunks
+        chunks = []
+        for i in range(0, len(strs), chunk_size):
+            chunks.append(strs[i : i + chunk_size])
+
+        async def _process_chunk(chunk: list[str]):
+            """Process one chunk of strings (batch-fetch embeddings)"""
             async with semaphore:
-                embedding = await EmbeddingStore._get_embedding_async(llm, s)
-                results[s] = embedding
+                # Batch-fetch embeddings (the whole chunk handled in one request)
+                embeddings = []
+                for s in chunk:
+                    embedding = await EmbeddingStore._get_embedding_async(llm, s)
+                    embeddings.append(embedding)
+                    results[s] = embedding
+
                 if progress_callback:
-                    progress_callback(1)
-
-        tasks = [_get_embedding_with_semaphore(s) for s in strs]
+                    progress_callback(len(chunk))
+
+                return embeddings
+
+        # Process all chunks concurrently
+        tasks = [_process_chunk(chunk) for chunk in chunks]
         await asyncio.gather(*tasks)

         # Return the results in the original input order
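The tail of this method falls outside the hunk, so the re-ordering code itself is not shown. A minimal sketch of how the (string, embedding) pairs could be restored to the original input order from the results dict built above (the helper name is hypothetical, and it assumes every string received an embedding):

def order_results(strs: list[str], results: dict[str, list[float]]) -> list[tuple[str, list[float]]]:
    # Missing keys would raise KeyError here; the real method may handle failures differently.
    return [(s, results[s]) for s in strs]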
@@ -392,15 +415,56 @@ class EmbeddingStore:
             self.faiss_index = faiss.IndexFlatIP(embedding_dim)
             return

+        # 🔧 Fix: check whether all embedding dimensions are consistent
+        dimensions = [len(emb) for emb in array]
+        unique_dims = set(dimensions)
+
+        if len(unique_dims) > 1:
+            logger.error(f"Detected inconsistent embedding dimensions: {unique_dims}")
+            logger.error(f"Dimension distribution: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
+
+            # Determine the expected dimension (use the most common one)
+            from collections import Counter
+            dim_counter = Counter(dimensions)
+            expected_dim = dim_counter.most_common(1)[0][0]
+            logger.warning(f"Will use the most common dimension: {expected_dim}")
+
+            # Filter out embeddings with mismatched dimensions
+            filtered_array = []
+            filtered_idx2hash = {}
+            skipped_count = 0
+
+            for i, emb in enumerate(array):
+                if len(emb) == expected_dim:
+                    filtered_array.append(emb)
+                    filtered_idx2hash[str(len(filtered_array) - 1)] = self.idx2hash[str(i)]
+                else:
+                    skipped_count += 1
+                    hash_key = self.idx2hash[str(i)]
+                    logger.warning(f"Skipping embedding with mismatched dimension: {hash_key}, dimension={len(emb)}, expected={expected_dim}")
+
+            logger.warning(f"Filtered out {skipped_count} embeddings with mismatched dimensions")
+            array = filtered_array
+            self.idx2hash = filtered_idx2hash
+
+            if not array:
+                logger.error("No usable embeddings remain after filtering; cannot build the index")
+                embedding_dim = expected_dim
+                self.faiss_index = faiss.IndexFlatIP(embedding_dim)
+                return
+
         embeddings = np.array(array, dtype=np.float32)
         # L2 normalization
         faiss.normalize_L2(embeddings)
         # Build the index
-        embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
-        if not embedding_dim:
-            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
+        # 🔧 Fix: use the dimension actually detected from the data
+        embedding_dim = embeddings.shape[1]
+        logger.info(f"Using the detected embedding dimension: {embedding_dim}")
         self.faiss_index = faiss.IndexFlatIP(embedding_dim)
         self.faiss_index.add(embeddings)
+        logger.info(f"✅ Successfully built the Faiss index: {len(embeddings)} vectors, dimension={embedding_dim}")

     def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
         """Search for the k most similar items, using cosine similarity as the metric