From 02dfed673f55bb4187c79ccfe30186e2f2002838 Mon Sep 17 00:00:00 2001
From: tt-P607 <68868379+tt-P607@users.noreply.github.com>
Date: Thu, 23 Oct 2025 11:42:35 +0800
Subject: [PATCH] fix(knowledge): fix concurrency handling in knowledge-base embedding generation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The previous multi-threaded (`ThreadPoolExecutor`) embedding generation has
been refactored into a fully asynchronous concurrency model.

The old implementation created a new `asyncio` event loop in every thread to
run its coroutines; this pattern was inefficient, unstable, and prone to
concurrency problems that are hard to debug.

The new implementation handles all asynchronous tasks in a single event loop,
caps the concurrency level with an `asyncio.Semaphore`, and executes batched
embedding requests efficiently via `asyncio.gather`. This change significantly
improves the stability, performance, and maintainability of the code.

BREAKING CHANGE: Several core methods of `EmbeddingStore` and
`EmbeddingManager` (e.g. `store_new_data_set`,
`check_embedding_model_consistency`, `batch_insert_strs`) have been changed
from synchronous to asynchronous. All calls to these methods must now use
`await`.
---
 scripts/lpmm_learning_tool.py         |   2 +-
 src/chat/knowledge/embedding_store.py | 235 +++++++-------------------
 2 files changed, 66 insertions(+), 171 deletions(-)

diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py
index 1b7ebb2b1..c09139939 100644
--- a/scripts/lpmm_learning_tool.py
+++ b/scripts/lpmm_learning_tool.py
@@ -302,7 +302,7 @@ async def import_data(openie_obj: OpenIE | None = None):
         else:
             logger.info(f"去重完成,发现 {len(new_raw_paragraphs)} 个新段落。")
         logger.info("开始生成 Embedding...")
-        embed_manager.store_new_data_set(new_raw_paragraphs, new_triple_list_data)
+        await embed_manager.store_new_data_set(new_raw_paragraphs, new_triple_list_data)
         embed_manager.rebuild_faiss_index()
         embed_manager.save_to_file()
         logger.info("Embedding 处理完成!")
diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py
index dd9251230..8131415e8 100644
--- a/src/chat/knowledge/embedding_store.py
+++ b/src/chat/knowledge/embedding_store.py
@@ -30,7 +30,7 @@ from .utils.hash import get_sha256
 install(extra_lines=3)
 
 # 多线程embedding配置常量
-DEFAULT_MAX_WORKERS = 3  # 默认最大线程数
+DEFAULT_MAX_WORKERS = 1  # 默认最大线程数
 DEFAULT_CHUNK_SIZE = 5  # 默认每个线程处理的数据块大小
 MIN_CHUNK_SIZE = 1  # 最小分块大小
 MAX_CHUNK_SIZE = 50  # 最大分块大小
@@ -125,160 +125,63 @@ class EmbeddingStore:
         self.idx2hash = None
 
     @staticmethod
-    def _get_embedding(s: str) -> list[float]:
-        """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
-        # 创建新的事件循环并在完成后立即关闭
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
+    async def _get_embedding_async(llm, s: str) -> list[float]:
+        """异步、安全地获取单个字符串的嵌入向量"""
         try:
-            # 创建新的LLMRequest实例
-            from src.config.config import model_config
-            from src.llm_models.utils_model import LLMRequest
-
-            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
-
-            # 使用新的事件循环运行异步方法
-            embedding, _ = loop.run_until_complete(llm.get_embedding(s))
-
+            embedding, _ = await llm.get_embedding(s)
             if embedding and len(embedding) > 0:
                 return embedding
             else:
                 logger.error(f"获取嵌入失败: {s}")
                 return []
-
         except Exception as e:
             logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
             return []
-        finally:
-            # 确保事件循环被正确关闭
-            try:
-                loop.close()
-            except Exception:
-                ...
-    @staticmethod
-    def _get_embeddings_batch_threaded(
-        strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
+    @staticmethod
+    async def _get_embeddings_batch_async(
+        strs: list[str], chunk_size: int = 10, max_workers: int = 4, progress_callback=None
     ) -> list[tuple[str, list[float]]]:
-        """使用多线程批量获取嵌入向量
-
-        Args:
-            strs: 要获取嵌入的字符串列表
-            chunk_size: 每个线程处理的数据块大小
-            max_workers: 最大线程数
-            progress_callback: 进度回调函数,接收一个参数表示完成的数量
-
-        Returns:
-            包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
+        """
+        异步、并发地批量获取嵌入向量。
+        使用asyncio.Semaphore来控制并发数,确保所有操作在同一个事件循环中。
         """
         if not strs:
             return []
 
-        # 分块
-        chunks = []
-        for i in range(0, len(strs), chunk_size):
-            chunk = strs[i : i + chunk_size]
-            chunks.append((i, chunk))  # 保存起始索引以维持顺序
+        from src.config.config import model_config
+        from src.llm_models.utils_model import LLMRequest
 
-        # 结果存储,使用字典按索引存储以保证顺序
+        semaphore = asyncio.Semaphore(max_workers)
+        llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
         results = {}
 
-        def process_chunk(chunk_data):
-            """处理单个数据块的函数"""
-            start_idx, chunk_strs = chunk_data
-            chunk_results = []
+        async def _get_embedding_with_semaphore(s: str):
+            async with semaphore:
+                embedding = await EmbeddingStore._get_embedding_async(llm, s)
+                results[s] = embedding
+                if progress_callback:
+                    progress_callback(1)
 
-            # 为每个线程创建独立的LLMRequest实例
-            from src.config.config import model_config
-            from src.llm_models.utils_model import LLMRequest
+        tasks = [_get_embedding_with_semaphore(s) for s in strs]
+        await asyncio.gather(*tasks)
 
-            try:
-                # 创建线程专用的LLM实例
-                llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
-
-                for i, s in enumerate(chunk_strs):
-                    try:
-                        # 在线程中创建独立的事件循环
-                        loop = asyncio.new_event_loop()
-                        asyncio.set_event_loop(loop)
-                        try:
-                            embedding = loop.run_until_complete(llm.get_embedding(s))
-                        finally:
-                            loop.close()
-
-                        if embedding and len(embedding) > 0:
-                            chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] 是实际的向量
-                        else:
-                            logger.error(f"获取嵌入失败: {s}")
-                            chunk_results.append((start_idx + i, s, []))
-
-                        # 每完成一个嵌入立即更新进度
-                        if progress_callback:
-                            progress_callback(1)
-
-                    except Exception as e:
-                        logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
-                        chunk_results.append((start_idx + i, s, []))
-
-                        # 即使失败也要更新进度
-                        if progress_callback:
-                            progress_callback(1)
-
-            except Exception as e:
-                logger.error(f"创建LLM实例失败: {e}")
-                # 如果创建LLM实例失败,返回空结果
-                for i, s in enumerate(chunk_strs):
-                    chunk_results.append((start_idx + i, s, []))
-                    # 即使失败也要更新进度
-                    if progress_callback:
-                        progress_callback(1)
-
-            return chunk_results
-
-        # 使用线程池处理
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # 提交所有任务
-            future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
-
-            # 收集结果(进度已在process_chunk中实时更新)
-            for future in as_completed(future_to_chunk):
-                try:
-                    chunk_results = future.result()
-                    for idx, s, embedding in chunk_results:
-                        results[idx] = (s, embedding)
-                except Exception as e:
-                    chunk = future_to_chunk[future]
-                    logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
-                    # 为失败的块添加空结果
-                    start_idx, chunk_strs = chunk
-                    for i, s in enumerate(chunk_strs):
-                        results[start_idx + i] = (s, [])
-
-        # 按原始顺序返回结果
-        ordered_results = []
-        for i in range(len(strs)):
-            if i in results:
-                ordered_results.append(results[i])
-            else:
-                # 防止遗漏
-                ordered_results.append((strs[i], []))
-
-        return ordered_results
+        # 按照原始顺序返回结果
+        return [(s, results.get(s, [])) for s in strs]
 
     @staticmethod
     def get_test_file_path():
         return EMBEDDING_TEST_FILE
 
-    def save_embedding_test_vectors(self):
-        """保存测试字符串的嵌入到本地(使用多线程优化)"""
+    async def save_embedding_test_vectors(self):
+        """保存测试字符串的嵌入到本地(异步单线程)"""
         logger.info("开始保存测试字符串的嵌入向量...")
-        # 使用多线程批量获取测试字符串的嵌入
-        embedding_results = self._get_embeddings_batch_threaded(
+        embedding_results = await self._get_embeddings_batch_async(
             EMBEDDING_TEST_STRINGS,
-            chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
-            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
+            chunk_size=self.chunk_size,
+            max_workers=self.max_workers,
         )
 
         # 构建测试向量字典
@@ -288,8 +191,9 @@ class EmbeddingStore:
                 test_vectors[str(idx)] = embedding
             else:
                 logger.error(f"获取测试字符串嵌入失败: {s}")
-                # 使用原始单线程方法作为后备
-                test_vectors[str(idx)] = self._get_embedding(s)
+                # Since _get_embedding is problematic, we just fail here
+                test_vectors[str(idx)] = []
+
         with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
             f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
 
@@ -304,28 +208,27 @@ class EmbeddingStore:
         with open(path, encoding="utf-8") as f:
             return orjson.loads(f.read())
 
-    def check_embedding_model_consistency(self):
-        """校验当前模型与本地嵌入模型是否一致(使用多线程优化)"""
+    async def check_embedding_model_consistency(self):
+        """校验当前模型与本地嵌入模型是否一致(异步单线程)"""
         local_vectors = self.load_embedding_test_vectors()
         if local_vectors is None:
             logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
-            self.save_embedding_test_vectors()
+            await self.save_embedding_test_vectors()
             return True
 
         # 检查本地向量完整性
         for idx in range(len(EMBEDDING_TEST_STRINGS)):
             if local_vectors.get(str(idx)) is None:
                 logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
-                self.save_embedding_test_vectors()
+                await self.save_embedding_test_vectors()
                 return True
 
         logger.info("开始检验嵌入模型一致性...")
-        # 使用多线程批量获取当前模型的嵌入
-        embedding_results = self._get_embeddings_batch_threaded(
+        embedding_results = await self._get_embeddings_batch_async(
             EMBEDDING_TEST_STRINGS,
-            chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
-            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
+            chunk_size=self.chunk_size,
+            max_workers=self.max_workers,
         )
 
         # 检查一致性
@@ -343,8 +246,8 @@ class EmbeddingStore:
         logger.info("嵌入模型一致性校验通过。")
         return True
 
-    def batch_insert_strs(self, strs: list[str], times: int) -> None:
-        """向库中存入字符串(使用多线程优化)"""
+    async def batch_insert_strs(self, strs: list[str], times: int) -> None:
+        """向库中存入字符串(异步单线程)"""
         if not strs:
             return
 
@@ -383,33 +286,18 @@ class EmbeddingStore:
             progress.update(task, advance=already_processed)
 
             if new_strs:
-                # 使用实例配置的参数,智能调整分块和线程数
-                optimal_chunk_size = max(
-                    MIN_CHUNK_SIZE,
-                    min(
-                        self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
-                    ),
-                )
-                optimal_max_workers = min(
-                    self.max_workers,
-                    max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
-                )
-
-                logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
-
                 # 定义进度更新回调函数
                 def update_progress(count):
                     progress.update(task, advance=count)
 
-                # 批量获取嵌入,并实时更新进度
-                embedding_results = self._get_embeddings_batch_threaded(
+                embedding_results = await self._get_embeddings_batch_async(
                     new_strs,
-                    chunk_size=optimal_chunk_size,
-                    max_workers=optimal_max_workers,
+                    chunk_size=self.chunk_size,
+                    max_workers=self.max_workers,
                     progress_callback=update_progress,
                 )
 
-                # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
+                # 存入结果
                 for s, embedding in embedding_results:
                     item_hash = self.namespace + "-" + get_sha256(s)
                     if embedding:  # 只有成功获取到嵌入才存入
@@ -499,6 +387,13 @@ class EmbeddingStore:
         for key in self.store:
             array.append(self.store[key].embedding)
             self.idx2hash[str(len(array) - 1)] = key
+
+        if not array:
+            logger.warning(f"在 {self.namespace} 中没有找到可用于构建Faiss索引的嵌入向量。")
+            embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension) or 1
+            self.faiss_index = faiss.IndexFlatIP(embedding_dim)
+            return
+
         embeddings = np.array(array, dtype=np.float32)
         # L2归一化
         faiss.normalize_L2(embeddings)
@@ -569,30 +464,30 @@ class EmbeddingManager:
         )
         self.stored_pg_hashes = set()
 
-    def check_all_embedding_model_consistency(self):
+    async def check_all_embedding_model_consistency(self):
         """对所有嵌入库做模型一致性校验"""
-        return self.paragraphs_embedding_store.check_embedding_model_consistency()
+        return await self.paragraphs_embedding_store.check_embedding_model_consistency()
 
-    def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
+    async def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
        """将段落编码存入Embedding库"""
-        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
+        await self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
 
-    def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
+    async def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
         """将实体编码存入Embedding库"""
         entities = set()
         for triple_list in triple_list_data.values():
             for triple in triple_list:
                 entities.add(triple[0])
                 entities.add(triple[2])
-        self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
+        await self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
 
-    def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
+    async def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
         """将关系编码存入Embedding库"""
         graph_triples = []  # a list of unique relation triple (in tuple) from all chunks
         for triples in triple_list_data.values():
             graph_triples.extend([tuple(t) for t in triples])
         graph_triples = list(set(graph_triples))
-        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
+        await self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
 
     def load_from_file(self):
         """从文件加载"""
@@ -602,17 +497,17 @@ class EmbeddingManager:
         # 从段落库中获取已存储的hash
         self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
 
-    def store_new_data_set(
+    async def store_new_data_set(
         self,
         raw_paragraphs: dict[str, str],
         triple_list_data: dict[str, list[list[str]]],
     ):
-        if not self.check_all_embedding_model_consistency():
+        if not await self.check_all_embedding_model_consistency():
             raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
         """存储新的数据集"""
-        self._store_pg_into_embedding(raw_paragraphs)
-        self._store_ent_into_embedding(triple_list_data)
-        self._store_rel_into_embedding(triple_list_data)
+        await self._store_pg_into_embedding(raw_paragraphs)
+        await self._store_ent_into_embedding(triple_list_data)
+        await self._store_rel_into_embedding(triple_list_data)
         self.stored_pg_hashes.update(raw_paragraphs.keys())
 
     def save_to_file(self):
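
The concurrency model introduced by this patch, reduced to a minimal, self-contained sketch. `fetch_embedding` and `get_embeddings_batch` below are hypothetical stand-ins, not code from this repository; only the single event loop, the `asyncio.Semaphore` cap, and `asyncio.gather` mirror the actual change, and the `asyncio.run(...)` entry point illustrates how formerly synchronous callers (cf. the `scripts/lpmm_learning_tool.py` hunk) must now await these methods because of the BREAKING CHANGE.

    import asyncio


    async def fetch_embedding(text: str) -> list[float]:
        # Hypothetical stand-in for LLMRequest.get_embedding: pretend to call the embedding API.
        await asyncio.sleep(0.01)
        return [0.0] * 8


    async def get_embeddings_batch(texts: list[str], max_workers: int = 4) -> list[tuple[str, list[float]]]:
        # Single event loop: a semaphore caps in-flight requests, gather runs them all concurrently.
        semaphore = asyncio.Semaphore(max_workers)
        results: dict[str, list[float]] = {}

        async def worker(text: str) -> None:
            async with semaphore:
                results[text] = await fetch_embedding(text)

        await asyncio.gather(*(worker(t) for t in texts))
        # Return results in the original input order; missing entries fall back to an empty list.
        return [(t, results.get(t, [])) for t in texts]


    if __name__ == "__main__":
        # Formerly synchronous callers must now enter the event loop themselves,
        # e.g. via asyncio.run(), because the store methods are coroutines.
        print(asyncio.run(get_embeddings_batch(["paragraph one", "paragraph two"])))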