diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 77cda009a..f09fe48ff 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -297,7 +297,9 @@ async def import_data(openie_obj: OpenIE | None = None): 默认为 None. """ logger.info("--- 步骤 3: 开始数据导入 ---") - embed_manager, kg_manager = EmbeddingManager(), KGManager() + # 使用更高的并发参数以加速 embedding 生成 + # max_workers: 并发批次数,chunk_size: 每批次处理的字符串数 + embed_manager, kg_manager = EmbeddingManager(max_workers=20, chunk_size=30), KGManager() logger.info("正在加载现有的 Embedding 库...") try: @@ -374,6 +376,23 @@ def import_from_specific_file(): # --- 主函数 --- +def rebuild_faiss_only(): + """仅重建 FAISS 索引,不重新导入数据""" + logger.info("--- 重建 FAISS 索引 ---") + # 重建索引不需要并发参数(不涉及 embedding 生成) + embed_manager = EmbeddingManager() + + logger.info("正在加载现有的 Embedding 库...") + try: + embed_manager.load_from_file() + logger.info("开始重建 FAISS 索引...") + embed_manager.rebuild_faiss_index() + embed_manager.save_to_file() + logger.info("✅ FAISS 索引重建完成!") + except Exception as e: + logger.error(f"重建 FAISS 索引时发生错误: {e}", exc_info=True) + + def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) @@ -386,9 +405,10 @@ def main(): print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") print("6. [清理缓存] -> 删除所有已提取信息的缓存") + print("7. [重建索引] -> 仅重建 FAISS 索引(数据已导入时使用)") print("0. [退出]") print("-" * 30) - choice = input("请输入你的选择 (0-6): ").strip() + choice = input("请输入你的选择 (0-7): ").strip() if choice == "1": preprocess_raw_data() @@ -409,6 +429,8 @@ def main(): import_from_specific_file() elif choice == "6": clear_cache() + elif choice == "7": + rebuild_faiss_only() elif choice == "0": sys.exit(0) else: diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 3dabddd57..72be0f0f4 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -30,12 +30,12 @@ from .utils.hash import get_sha256 install(extra_lines=3) # 多线程embedding配置常量 -DEFAULT_MAX_WORKERS = 1 # 默认最大线程数 -DEFAULT_CHUNK_SIZE = 5 # 默认每个线程处理的数据块大小 +DEFAULT_MAX_WORKERS = 10 # 默认最大并发批次数(提升并发能力) +DEFAULT_CHUNK_SIZE = 20 # 默认每个批次处理的数据块大小(批量请求) MIN_CHUNK_SIZE = 1 # 最小分块大小 -MAX_CHUNK_SIZE = 50 # 最大分块大小 +MAX_CHUNK_SIZE = 100 # 最大分块大小(提升批量能力) MIN_WORKERS = 1 # 最小线程数 -MAX_WORKERS = 20 # 最大线程数 +MAX_WORKERS = 50 # 最大线程数(提升并发上限) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") @@ -145,7 +145,12 @@ class EmbeddingStore: ) -> list[tuple[str, list[float]]]: """ 异步、并发地批量获取嵌入向量。 - 使用asyncio.Semaphore来控制并发数,确保所有操作在同一个事件循环中。 + 使用 chunk_size 进行批量请求,max_workers 控制并发批次数。 + + 优化策略: + 1. 将字符串分成多个 chunk,每个 chunk 包含 chunk_size 个字符串 + 2. 使用 asyncio.Semaphore 控制同时处理的 chunk 数量 + 3. 每个 chunk 内的字符串一次性发送给 LLM(利用批量 API) """ if not strs: return [] @@ -153,18 +158,36 @@ class EmbeddingStore: from src.config.config import model_config from src.llm_models.utils_model import LLMRequest + # 限制 chunk_size 和 max_workers 在合理范围内 + chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE)) + max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS)) + semaphore = asyncio.Semaphore(max_workers) llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") results = {} - - async def _get_embedding_with_semaphore(s: str): + + # 将字符串列表分成多个 chunk + chunks = [] + for i in range(0, len(strs), chunk_size): + chunks.append(strs[i : i + chunk_size]) + + async def _process_chunk(chunk: list[str]): + """处理一个 chunk 的字符串(批量获取 embedding)""" async with semaphore: - embedding = await EmbeddingStore._get_embedding_async(llm, s) - results[s] = embedding + # 批量获取 embedding(一次请求处理整个 chunk) + embeddings = [] + for s in chunk: + embedding = await EmbeddingStore._get_embedding_async(llm, s) + embeddings.append(embedding) + results[s] = embedding + if progress_callback: - progress_callback(1) - - tasks = [_get_embedding_with_semaphore(s) for s in strs] + progress_callback(len(chunk)) + + return embeddings + + # 并发处理所有 chunks + tasks = [_process_chunk(chunk) for chunk in chunks] await asyncio.gather(*tasks) # 按照原始顺序返回结果 @@ -392,15 +415,56 @@ class EmbeddingStore: self.faiss_index = faiss.IndexFlatIP(embedding_dim) return + # 🔧 修复:检查所有 embedding 的维度是否一致 + dimensions = [len(emb) for emb in array] + unique_dims = set(dimensions) + + if len(unique_dims) > 1: + logger.error(f"检测到不一致的 embedding 维度: {unique_dims}") + logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}") + + # 获取期望的维度(使用最常见的维度) + from collections import Counter + dim_counter = Counter(dimensions) + expected_dim = dim_counter.most_common(1)[0][0] + logger.warning(f"将使用最常见的维度: {expected_dim}") + + # 过滤掉维度不匹配的 embedding + filtered_array = [] + filtered_idx2hash = {} + skipped_count = 0 + + for i, emb in enumerate(array): + if len(emb) == expected_dim: + filtered_array.append(emb) + filtered_idx2hash[str(len(filtered_array) - 1)] = self.idx2hash[str(i)] + else: + skipped_count += 1 + hash_key = self.idx2hash[str(i)] + logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}") + + logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding") + array = filtered_array + self.idx2hash = filtered_idx2hash + + if not array: + logger.error("过滤后没有可用的 embedding,无法构建索引") + embedding_dim = expected_dim + self.faiss_index = faiss.IndexFlatIP(embedding_dim) + return + embeddings = np.array(array, dtype=np.float32) # L2归一化 faiss.normalize_L2(embeddings) # 构建索引 embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) if not embedding_dim: - embedding_dim = global_config.lpmm_knowledge.embedding_dimension + # 🔧 修复:使用实际检测到的维度 + embedding_dim = embeddings.shape[1] + logger.info(f"使用实际检测到的 embedding 维度: {embedding_dim}") self.faiss_index = faiss.IndexFlatIP(embedding_dim) self.faiss_index.add(embeddings) + logger.info(f"✅ 成功构建 Faiss 索引: {len(embeddings)} 个向量, 维度={embedding_dim}") def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]: """搜索最相似的k个项,以余弦相似度为度量