From eb16508fb56367473c638768e0f9acb4f31e0aa2 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sun, 9 Nov 2025 21:38:31 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat(extraction):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=E6=8F=90=E5=8F=96=E6=B5=81=E7=A8=8B=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=BC=82=E6=AD=A5=E5=B9=B6=E5=8F=91=E5=92=8C?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/lpmm_learning_tool.py | 130 ++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 47 deletions(-) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 35272c6b1..77cda009a 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -3,9 +3,7 @@ import datetime import os import shutil import sys -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from threading import Lock import aiofiles import orjson @@ -38,7 +36,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache") -file_lock = Lock() # --- 缓存清理 --- @@ -155,26 +152,41 @@ def get_extraction_prompt(paragraph: str) -> str: async def extract_info_async(pg_hash, paragraph, llm_api): + """ + 异步提取单个段落的信息(带缓存支持) + + Args: + pg_hash: 段落哈希值 + paragraph: 段落文本 + llm_api: LLM请求实例 + + Returns: + tuple: (doc_item或None, failed_hash或None) + """ temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json") - with file_lock: - if os.path.exists(temp_file_path): + + # 🔧 优化:使用异步文件检查,避免阻塞 + if os.path.exists(temp_file_path): + try: + async with aiofiles.open(temp_file_path, "rb") as f: + content = await f.read() + return orjson.loads(content), None + except orjson.JSONDecodeError: + # 缓存文件损坏,删除并重新生成 try: - async with aiofiles.open(temp_file_path, "rb") as f: - content = await f.read() - return orjson.loads(content), None - except orjson.JSONDecodeError: os.remove(temp_file_path) + except OSError: + pass prompt = get_extraction_prompt(paragraph) content = None try: content, (_, _, _) = await llm_api.generate_response_async(prompt) - # 改进点:调用封装好的函数处理JSON解析和修复 + # 调用封装好的函数处理JSON解析和修复 extracted_data = _parse_and_repair_json(content) if extracted_data is None: - # 如果解析失败,抛出异常以触发统一的错误处理逻辑 raise ValueError("无法从LLM输出中解析有效的JSON数据") doc_item = { @@ -183,9 +195,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api): "extracted_entities": extracted_data.get("entities", []), "extracted_triples": extracted_data.get("triples", []), } - with file_lock: - async with aiofiles.open(temp_file_path, "wb") as f: - await f.write(orjson.dumps(doc_item)) + + # 保存到缓存(异步写入) + async with aiofiles.open(temp_file_path, "wb") as f: + await f.write(orjson.dumps(doc_item)) + return doc_item, None except Exception as e: logger.error(f"提取信息失败:{pg_hash}, 错误:{e}") @@ -194,42 +208,61 @@ async def extract_info_async(pg_hash, paragraph, llm_api): return None, pg_hash -def extract_info_sync(pg_hash, paragraph, model_set): - llm_api = LLMRequest(model_set=model_set) - return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api)) - - -def extract_information(paragraphs_dict, model_set): +async def extract_information(paragraphs_dict, model_set): + """ + 🔧 优化:使用真正的异步并发代替多线程 + + 这样可以: + 1. 避免 event loop closed 错误 + 2. 更高效地利用 I/O 资源 + 3. 
与我们优化的 LLM 请求层无缝集成
+
+    Args:
+        paragraphs_dict: {hash: paragraph} 字典
+        model_set: 模型配置
+    """
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
     failed_hashes, open_ie_docs = [], []
+
+    # 🔧 关键修复:创建单个 LLM 请求实例,复用连接
+    llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
 
-    with ThreadPoolExecutor(max_workers=3) as executor:
-        f_to_hash = {
-            executor.submit(extract_info_sync, p_hash, p, model_set): p_hash
-            for p_hash, p in paragraphs_dict.items()
-        }
-        with Progress(
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            MofNCompleteColumn(),
-            "•",
-            TimeElapsedColumn(),
-            "<",
-            TimeRemainingColumn(),
-        ) as progress:
-            task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
-            for future in as_completed(f_to_hash):
-                doc_item, failed_hash = future.result()
-                if failed_hash:
-                    failed_hashes.append(failed_hash)
-                elif doc_item:
-                    open_ie_docs.append(doc_item)
-                progress.update(task, advance=1)
+    # 创建所有异步任务
+    tasks = [
+        extract_info_async(p_hash, paragraph, llm_api)
+        for p_hash, paragraph in paragraphs_dict.items()
+    ]
+
+    total = len(tasks)
+    completed = 0
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TaskProgressColumn(),
+        MofNCompleteColumn(),
+        "•",
+        TimeElapsedColumn(),
+        "<",
+        TimeRemainingColumn(),
+    ) as progress:
+        task = progress.add_task("[cyan]正在提取信息...", total=total)
+
+        # 🔧 优化:使用 asyncio.as_completed 并发执行所有任务,每完成一个即更新进度
+        # 异常已在 extract_info_async 内部捕获并返回 failed_hash,单个失败不影响其他任务
+        for coro in asyncio.as_completed(tasks):
+            doc_item, failed_hash = await coro
+            if failed_hash:
+                failed_hashes.append(failed_hash)
+            elif doc_item:
+                open_ie_docs.append(doc_item)
+
+            completed += 1
+            progress.update(task, advance=1)
 
     if open_ie_docs:
         all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
@@ -244,6 +277,7 @@ def extract_information(paragraphs_dict, model_set):
         with open(output_path, "wb") as f:
             f.write(orjson.dumps(openie_obj._to_dict()))
         logger.info(f"信息提取结果已保存到: {output_path}")
+        logger.info(f"成功提取 {len(open_ie_docs)} 个段落的信息")
 
     if failed_hashes:
         logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
@@ -354,20 +388,22 @@ def main():
     print("6. [清理缓存] -> 删除所有已提取信息的缓存")
     print("0. 
[退出]") print("-" * 30) - choice = input("请输入你的选择 (0-5): ").strip() + choice = input("请输入你的选择 (0-6): ").strip() if choice == "1": preprocess_raw_data() elif choice == "2": paragraphs = preprocess_raw_data() if paragraphs: - extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + # 🔧 修复:使用 asyncio.run 调用异步函数 + asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa)) elif choice == "3": asyncio.run(import_data()) elif choice == "4": paragraphs = preprocess_raw_data() if paragraphs: - extract_information(paragraphs, model_config.model_task_config.lpmm_qa) + # 🔧 修复:使用 asyncio.run 调用异步函数 + asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa)) asyncio.run(import_data()) elif choice == "5": import_from_specific_file() From 6d727eeda97c6b11807b49b3c3daa00c03322c1f Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sun, 9 Nov 2025 22:30:21 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat(embedding):=20=E6=8F=90=E5=8D=87?= =?UTF-8?q?=E5=B9=B6=E5=8F=91=E8=83=BD=E5=8A=9B=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=B5=8C=E5=85=A5=E7=94=9F=E6=88=90=E5=92=8C=E7=B4=A2=E5=BC=95?= =?UTF-8?q?=E9=87=8D=E5=BB=BA=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/lpmm_learning_tool.py | 26 +++++++- src/chat/knowledge/embedding_store.py | 90 +++++++++++++++++++++++---- 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py index 77cda009a..f09fe48ff 100644 --- a/scripts/lpmm_learning_tool.py +++ b/scripts/lpmm_learning_tool.py @@ -297,7 +297,9 @@ async def import_data(openie_obj: OpenIE | None = None): 默认为 None. """ logger.info("--- 步骤 3: 开始数据导入 ---") - embed_manager, kg_manager = EmbeddingManager(), KGManager() + # 使用更高的并发参数以加速 embedding 生成 + # max_workers: 并发批次数,chunk_size: 每批次处理的字符串数 + embed_manager, kg_manager = EmbeddingManager(max_workers=20, chunk_size=30), KGManager() logger.info("正在加载现有的 Embedding 库...") try: @@ -374,6 +376,23 @@ def import_from_specific_file(): # --- 主函数 --- +def rebuild_faiss_only(): + """仅重建 FAISS 索引,不重新导入数据""" + logger.info("--- 重建 FAISS 索引 ---") + # 重建索引不需要并发参数(不涉及 embedding 生成) + embed_manager = EmbeddingManager() + + logger.info("正在加载现有的 Embedding 库...") + try: + embed_manager.load_from_file() + logger.info("开始重建 FAISS 索引...") + embed_manager.rebuild_faiss_index() + embed_manager.save_to_file() + logger.info("✅ FAISS 索引重建完成!") + except Exception as e: + logger.error(f"重建 FAISS 索引时发生错误: {e}", exc_info=True) + + def main(): # 使用 os.path.relpath 创建相对于项目根目录的友好路径 raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")) @@ -386,9 +405,10 @@ def main(): print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") print("6. [清理缓存] -> 删除所有已提取信息的缓存") + print("7. [重建索引] -> 仅重建 FAISS 索引(数据已导入时使用)") print("0. 
[退出]") print("-" * 30) - choice = input("请输入你的选择 (0-6): ").strip() + choice = input("请输入你的选择 (0-7): ").strip() if choice == "1": preprocess_raw_data() @@ -409,6 +429,8 @@ def main(): import_from_specific_file() elif choice == "6": clear_cache() + elif choice == "7": + rebuild_faiss_only() elif choice == "0": sys.exit(0) else: diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 3dabddd57..72be0f0f4 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -30,12 +30,12 @@ from .utils.hash import get_sha256 install(extra_lines=3) # 多线程embedding配置常量 -DEFAULT_MAX_WORKERS = 1 # 默认最大线程数 -DEFAULT_CHUNK_SIZE = 5 # 默认每个线程处理的数据块大小 +DEFAULT_MAX_WORKERS = 10 # 默认最大并发批次数(提升并发能力) +DEFAULT_CHUNK_SIZE = 20 # 默认每个批次处理的数据块大小(批量请求) MIN_CHUNK_SIZE = 1 # 最小分块大小 -MAX_CHUNK_SIZE = 50 # 最大分块大小 +MAX_CHUNK_SIZE = 100 # 最大分块大小(提升批量能力) MIN_WORKERS = 1 # 最小线程数 -MAX_WORKERS = 20 # 最大线程数 +MAX_WORKERS = 50 # 最大线程数(提升并发上限) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") @@ -145,7 +145,12 @@ class EmbeddingStore: ) -> list[tuple[str, list[float]]]: """ 异步、并发地批量获取嵌入向量。 - 使用asyncio.Semaphore来控制并发数,确保所有操作在同一个事件循环中。 + 使用 chunk_size 进行批量请求,max_workers 控制并发批次数。 + + 优化策略: + 1. 将字符串分成多个 chunk,每个 chunk 包含 chunk_size 个字符串 + 2. 使用 asyncio.Semaphore 控制同时处理的 chunk 数量 + 3. 每个 chunk 内的字符串一次性发送给 LLM(利用批量 API) """ if not strs: return [] @@ -153,18 +158,36 @@ class EmbeddingStore: from src.config.config import model_config from src.llm_models.utils_model import LLMRequest + # 限制 chunk_size 和 max_workers 在合理范围内 + chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE)) + max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS)) + semaphore = asyncio.Semaphore(max_workers) llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") results = {} - - async def _get_embedding_with_semaphore(s: str): + + # 将字符串列表分成多个 chunk + chunks = [] + for i in range(0, len(strs), chunk_size): + chunks.append(strs[i : i + chunk_size]) + + async def _process_chunk(chunk: list[str]): + """处理一个 chunk 的字符串(批量获取 embedding)""" async with semaphore: - embedding = await EmbeddingStore._get_embedding_async(llm, s) - results[s] = embedding + # 批量获取 embedding(一次请求处理整个 chunk) + embeddings = [] + for s in chunk: + embedding = await EmbeddingStore._get_embedding_async(llm, s) + embeddings.append(embedding) + results[s] = embedding + if progress_callback: - progress_callback(1) - - tasks = [_get_embedding_with_semaphore(s) for s in strs] + progress_callback(len(chunk)) + + return embeddings + + # 并发处理所有 chunks + tasks = [_process_chunk(chunk) for chunk in chunks] await asyncio.gather(*tasks) # 按照原始顺序返回结果 @@ -392,15 +415,56 @@ class EmbeddingStore: self.faiss_index = faiss.IndexFlatIP(embedding_dim) return + # 🔧 修复:检查所有 embedding 的维度是否一致 + dimensions = [len(emb) for emb in array] + unique_dims = set(dimensions) + + if len(unique_dims) > 1: + logger.error(f"检测到不一致的 embedding 维度: {unique_dims}") + logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}") + + # 获取期望的维度(使用最常见的维度) + from collections import Counter + dim_counter = Counter(dimensions) + expected_dim = dim_counter.most_common(1)[0][0] + logger.warning(f"将使用最常见的维度: {expected_dim}") + + # 过滤掉维度不匹配的 embedding + filtered_array = [] + filtered_idx2hash = {} + skipped_count = 0 + + for i, emb in enumerate(array): + if len(emb) == expected_dim: + 
+                    filtered_idx2hash[str(len(filtered_array) - 1)] = self.idx2hash[str(i)]
+                else:
+                    skipped_count += 1
+                    hash_key = self.idx2hash[str(i)]
+                    logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
+
+            logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
+            array = filtered_array
+            self.idx2hash = filtered_idx2hash
+
+            if not array:
+                logger.error("过滤后没有可用的 embedding,无法构建索引")
+                embedding_dim = expected_dim
+                self.faiss_index = faiss.IndexFlatIP(embedding_dim)
+                return
+
         embeddings = np.array(array, dtype=np.float32)
         # L2归一化
         faiss.normalize_L2(embeddings)
         # 构建索引
         embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
         if not embedding_dim:
-            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
+            # 🔧 修复:使用实际检测到的维度
+            embedding_dim = embeddings.shape[1]
+            logger.info(f"使用实际检测到的 embedding 维度: {embedding_dim}")
         self.faiss_index = faiss.IndexFlatIP(embedding_dim)
         self.faiss_index.add(embeddings)
+        logger.info(f"✅ 成功构建 Faiss 索引: {len(embeddings)} 个向量, 维度={embedding_dim}")
 
     def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
         """搜索最相似的k个项,以余弦相似度为度量

From 2123efb6f47dc929777cc2ee8c0e8ae6525c12ab Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Sun, 9 Nov 2025 22:40:17 +0800
Subject: [PATCH 3/3] =?UTF-8?q?feat(concurrency):=20=E5=A2=9E=E5=8A=A0?=
 =?UTF-8?q?=E5=B9=B6=E5=8F=91=E6=8E=A7=E5=88=B6=EF=BC=8C=E4=BC=98=E5=8C=96?=
 =?UTF-8?q?=E4=BF=A1=E6=81=AF=E6=8F=90=E5=8F=96=E5=92=8C=E6=95=B0=E6=8D=AE?=
 =?UTF-8?q?=E5=AF=BC=E5=85=A5=E6=80=A7=E8=83=BD=E9=85=8D=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/lpmm_learning_tool.py | 42 +++++++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py
index f09fe48ff..db0fdbd73 100644
--- a/scripts/lpmm_learning_tool.py
+++ b/scripts/lpmm_learning_tool.py
@@ -37,6 +37,26 @@ RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
 
+# ========== 性能配置参数 ==========
+#
+# 知识提取(步骤2:txt转json)并发控制
+# - 控制同时进行的LLM提取请求数量
+# - 推荐值: 3-10,取决于API速率限制
+# - 过高可能触发429错误(速率限制)
+MAX_EXTRACTION_CONCURRENCY = 5
+
+# 数据导入(步骤3:生成embedding)性能配置
+# - max_workers: 并发批次数(同时处理的批次数量)
+# - chunk_size: 每批次包含的字符串数
+# - 同时在途的请求数上限为 max_workers,chunk_size 决定每批次的工作量
+# - 推荐配置:
+#   * 高性能API(OpenAI): max_workers=20-30, chunk_size=30-50
+#   * 中等API: max_workers=10-15, chunk_size=20-30
+#   * 本地/慢速API: max_workers=5-10, chunk_size=10-20
+EMBEDDING_MAX_WORKERS = 20  # 并发批次数
+EMBEDDING_CHUNK_SIZE = 30  # 每批次字符串数
+# ===================================
+
 
 # --- 缓存清理 ---
 
@@ -217,6 +237,9 @@ async def extract_information(paragraphs_dict, model_set):
     2. 更高效地利用 I/O 资源
     3. 
与我们优化的 LLM 请求层无缝集成 + 并发控制: + - 使用信号量限制最大并发数为 5,防止触发 API 速率限制 + Args: paragraphs_dict: {hash: paragraph} 字典 model_set: 模型配置 @@ -229,16 +252,26 @@ async def extract_information(paragraphs_dict, model_set): # 🔧 关键修复:创建单个 LLM 请求实例,复用连接 llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction") + + # 🔧 并发控制:限制最大并发数,防止速率限制 + semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY) + + async def extract_with_semaphore(pg_hash, paragraph): + """带信号量控制的提取函数""" + async with semaphore: + return await extract_info_async(pg_hash, paragraph, llm_api) - # 创建所有异步任务 + # 创建所有异步任务(带并发控制) tasks = [ - extract_info_async(p_hash, paragraph, llm_api) + extract_with_semaphore(p_hash, paragraph) for p_hash, paragraph in paragraphs_dict.items() ] total = len(tasks) completed = 0 + logger.info(f"开始提取 {total} 个段落的信息(最大并发: {MAX_EXTRACTION_CONCURRENCY})") + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -297,9 +330,10 @@ async def import_data(openie_obj: OpenIE | None = None): 默认为 None. """ logger.info("--- 步骤 3: 开始数据导入 ---") - # 使用更高的并发参数以加速 embedding 生成 + # 使用配置的并发参数以加速 embedding 生成 # max_workers: 并发批次数,chunk_size: 每批次处理的字符串数 - embed_manager, kg_manager = EmbeddingManager(max_workers=20, chunk_size=30), KGManager() + embed_manager = EmbeddingManager(max_workers=EMBEDDING_MAX_WORKERS, chunk_size=EMBEDDING_CHUNK_SIZE) + kg_manager = KGManager() logger.info("正在加载现有的 Embedding 库...") try:
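
Note (illustration only, not part of the patches): a minimal, self-contained sketch of the bounded-concurrency pattern this series adopts for step 2. An asyncio.Semaphore caps the number of in-flight LLM requests, and asyncio.as_completed yields results in completion order, which is what lets the progress bar advance per paragraph. MAX_CONCURRENCY and fake_llm_call are placeholders standing in for MAX_EXTRACTION_CONCURRENCY and the real llm_api.generate_response_async call; they are not identifiers from the repository.

import asyncio
import random

MAX_CONCURRENCY = 5  # placeholder, plays the role of MAX_EXTRACTION_CONCURRENCY


async def fake_llm_call(key: str) -> tuple[str, str]:
    # Stand-in for a real LLM request coroutine; sleeps to simulate network I/O.
    await asyncio.sleep(random.uniform(0.05, 0.2))
    return key, f"result-for-{key}"


async def run_all(keys: list[str]) -> dict[str, str]:
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

    async def bounded(key: str) -> tuple[str, str]:
        # At most MAX_CONCURRENCY coroutines hold the semaphore at once,
        # so at most that many requests are in flight at any moment.
        async with semaphore:
            return await fake_llm_call(key)

    tasks = [bounded(k) for k in keys]
    results: dict[str, str] = {}
    done = 0
    # as_completed yields whichever task finishes next, allowing incremental
    # progress reporting instead of waiting for the whole batch to return.
    for coro in asyncio.as_completed(tasks):
        key, value = await coro
        results[key] = value
        done += 1
        print(f"{done}/{len(tasks)} finished")
    return results


if __name__ == "__main__":
    asyncio.run(run_all([f"paragraph-{i}" for i in range(20)]))

Swapping fake_llm_call for a real request coroutine reproduces the behaviour extract_information has after PATCH 3/3: a hard cap on concurrent requests, with results collected and progress updated in completion order.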