feat(embedding): 提升并发能力,优化嵌入生成和索引重建流程

This commit is contained in:
Windpicker-owo
2025-11-09 22:30:21 +08:00
parent eb16508fb5
commit 6d727eeda9
2 changed files with 101 additions and 15 deletions

View File

@@ -297,7 +297,9 @@ async def import_data(openie_obj: OpenIE | None = None):
默认为 None. 默认为 None.
""" """
logger.info("--- 步骤 3: 开始数据导入 ---") logger.info("--- 步骤 3: 开始数据导入 ---")
embed_manager, kg_manager = EmbeddingManager(), KGManager() # 使用更高的并发参数以加速 embedding 生成
# max_workers: 并发批次数chunk_size: 每批次处理的字符串数
embed_manager, kg_manager = EmbeddingManager(max_workers=20, chunk_size=30), KGManager()
logger.info("正在加载现有的 Embedding 库...") logger.info("正在加载现有的 Embedding 库...")
try: try:
@@ -374,6 +376,23 @@ def import_from_specific_file():
# --- 主函数 --- # --- 主函数 ---
def rebuild_faiss_only():
"""仅重建 FAISS 索引,不重新导入数据"""
logger.info("--- 重建 FAISS 索引 ---")
# 重建索引不需要并发参数(不涉及 embedding 生成)
embed_manager = EmbeddingManager()
logger.info("正在加载现有的 Embedding 库...")
try:
embed_manager.load_from_file()
logger.info("开始重建 FAISS 索引...")
embed_manager.rebuild_faiss_index()
embed_manager.save_to_file()
logger.info("✅ FAISS 索引重建完成!")
except Exception as e:
logger.error(f"重建 FAISS 索引时发生错误: {e}", exc_info=True)
def main(): def main():
# 使用 os.path.relpath 创建相对于项目根目录的友好路径 # 使用 os.path.relpath 创建相对于项目根目录的友好路径
raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
@@ -386,9 +405,10 @@ def main():
print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
print("6. [清理缓存] -> 删除所有已提取信息的缓存") print("6. [清理缓存] -> 删除所有已提取信息的缓存")
print("7. [重建索引] -> 仅重建 FAISS 索引(数据已导入时使用)")
print("0. [退出]") print("0. [退出]")
print("-" * 30) print("-" * 30)
choice = input("请输入你的选择 (0-6): ").strip() choice = input("请输入你的选择 (0-7): ").strip()
if choice == "1": if choice == "1":
preprocess_raw_data() preprocess_raw_data()
@@ -409,6 +429,8 @@ def main():
import_from_specific_file() import_from_specific_file()
elif choice == "6": elif choice == "6":
clear_cache() clear_cache()
elif choice == "7":
rebuild_faiss_only()
elif choice == "0": elif choice == "0":
sys.exit(0) sys.exit(0)
else: else:

View File

@@ -30,12 +30,12 @@ from .utils.hash import get_sha256
install(extra_lines=3) install(extra_lines=3)
# 多线程embedding配置常量 # 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 1 # 默认最大线程数 DEFAULT_MAX_WORKERS = 10 # 默认最大并发批次数(提升并发能力)
DEFAULT_CHUNK_SIZE = 5 # 默认每个线程处理的数据块大小 DEFAULT_CHUNK_SIZE = 20 # 默认每个批次处理的数据块大小(批量请求)
MIN_CHUNK_SIZE = 1 # 最小分块大小 MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小 MAX_CHUNK_SIZE = 100 # 最大分块大小(提升批量能力)
MIN_WORKERS = 1 # 最小线程数 MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数 MAX_WORKERS = 50 # 最大线程数(提升并发上限)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -145,7 +145,12 @@ class EmbeddingStore:
) -> list[tuple[str, list[float]]]: ) -> list[tuple[str, list[float]]]:
""" """
异步、并发地批量获取嵌入向量。 异步、并发地批量获取嵌入向量。
使用asyncio.Semaphore来控制并发数确保所有操作在同一个事件循环中 使用 chunk_size 进行批量请求max_workers 控制并发批次数
优化策略:
1. 将字符串分成多个 chunk每个 chunk 包含 chunk_size 个字符串
2. 使用 asyncio.Semaphore 控制同时处理的 chunk 数量
3. 每个 chunk 内的字符串一次性发送给 LLM利用批量 API
""" """
if not strs: if not strs:
return [] return []
@@ -153,18 +158,36 @@ class EmbeddingStore:
from src.config.config import model_config from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
semaphore = asyncio.Semaphore(max_workers) semaphore = asyncio.Semaphore(max_workers)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
results = {} results = {}
async def _get_embedding_with_semaphore(s: str): # 将字符串列表分成多个 chunk
async with semaphore: chunks = []
embedding = await EmbeddingStore._get_embedding_async(llm, s) for i in range(0, len(strs), chunk_size):
results[s] = embedding chunks.append(strs[i : i + chunk_size])
if progress_callback:
progress_callback(1)
tasks = [_get_embedding_with_semaphore(s) for s in strs] async def _process_chunk(chunk: list[str]):
"""处理一个 chunk 的字符串(批量获取 embedding"""
async with semaphore:
# 批量获取 embedding一次请求处理整个 chunk
embeddings = []
for s in chunk:
embedding = await EmbeddingStore._get_embedding_async(llm, s)
embeddings.append(embedding)
results[s] = embedding
if progress_callback:
progress_callback(len(chunk))
return embeddings
# 并发处理所有 chunks
tasks = [_process_chunk(chunk) for chunk in chunks]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# 按照原始顺序返回结果 # 按照原始顺序返回结果
@@ -392,15 +415,56 @@ class EmbeddingStore:
self.faiss_index = faiss.IndexFlatIP(embedding_dim) self.faiss_index = faiss.IndexFlatIP(embedding_dim)
return return
# 🔧 修复:检查所有 embedding 的维度是否一致
dimensions = [len(emb) for emb in array]
unique_dims = set(dimensions)
if len(unique_dims) > 1:
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
# 获取期望的维度(使用最常见的维度)
from collections import Counter
dim_counter = Counter(dimensions)
expected_dim = dim_counter.most_common(1)[0][0]
logger.warning(f"将使用最常见的维度: {expected_dim}")
# 过滤掉维度不匹配的 embedding
filtered_array = []
filtered_idx2hash = {}
skipped_count = 0
for i, emb in enumerate(array):
if len(emb) == expected_dim:
filtered_array.append(emb)
filtered_idx2hash[str(len(filtered_array) - 1)] = self.idx2hash[str(i)]
else:
skipped_count += 1
hash_key = self.idx2hash[str(i)]
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
array = filtered_array
self.idx2hash = filtered_idx2hash
if not array:
logger.error("过滤后没有可用的 embedding无法构建索引")
embedding_dim = expected_dim
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
return
embeddings = np.array(array, dtype=np.float32) embeddings = np.array(array, dtype=np.float32)
# L2归一化 # L2归一化
faiss.normalize_L2(embeddings) faiss.normalize_L2(embeddings)
# 构建索引 # 构建索引
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim: if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension # 🔧 修复:使用实际检测到的维度
embedding_dim = embeddings.shape[1]
logger.info(f"使用实际检测到的 embedding 维度: {embedding_dim}")
self.faiss_index = faiss.IndexFlatIP(embedding_dim) self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings) self.faiss_index.add(embeddings)
logger.info(f"✅ 成功构建 Faiss 索引: {len(embeddings)} 个向量, 维度={embedding_dim}")
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]: def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
"""搜索最相似的k个项以余弦相似度为度量 """搜索最相似的k个项以余弦相似度为度量