From 28035e18f1eb9434c4d53cd13b85a54c4962a6c8 Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Sun, 9 Nov 2025 22:40:17 +0800
Subject: [PATCH] =?UTF-8?q?feat(concurrency):=20=E5=A2=9E=E5=8A=A0?=
 =?UTF-8?q?=E5=B9=B6=E5=8F=91=E6=8E=A7=E5=88=B6=EF=BC=8C=E4=BC=98=E5=8C=96?=
 =?UTF-8?q?=E4=BF=A1=E6=81=AF=E6=8F=90=E5=8F=96=E5=92=8C=E6=95=B0=E6=8D=AE?=
 =?UTF-8?q?=E5=AF=BC=E5=85=A5=E6=80=A7=E8=83=BD=E9=85=8D=E7=BD=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/lpmm_learning_tool.py | 42 +++++++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/scripts/lpmm_learning_tool.py b/scripts/lpmm_learning_tool.py
index f09fe48ff..db0fdbd73 100644
--- a/scripts/lpmm_learning_tool.py
+++ b/scripts/lpmm_learning_tool.py
@@ -37,6 +37,26 @@ RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
 
+# ========== 性能配置参数 ==========
+#
+# 知识提取(步骤2:txt转json)并发控制
+# - 控制同时进行的LLM提取请求数量
+# - 推荐值: 3-10,取决于API速率限制
+# - 过高可能触发429错误(速率限制)
+MAX_EXTRACTION_CONCURRENCY = 5
+
+# 数据导入(步骤3:生成embedding)性能配置
+# - max_workers: 并发批次数(每批次并行处理)
+# - chunk_size: 每批次包含的字符串数
+# - 理论并发 = max_workers × chunk_size
+# - 推荐配置:
+#   * 高性能API(OpenAI): max_workers=20-30, chunk_size=30-50
+#   * 中等API: max_workers=10-15, chunk_size=20-30
+#   * 本地/慢速API: max_workers=5-10, chunk_size=10-20
+EMBEDDING_MAX_WORKERS = 20  # 并发批次数
+EMBEDDING_CHUNK_SIZE = 30  # 每批次字符串数
+# ===================================
+
 
 # --- 缓存清理 ---
 
@@ -217,6 +237,9 @@ async def extract_information(paragraphs_dict, model_set):
     2. 更高效地利用 I/O 资源
     3. 与我们优化的 LLM 请求层无缝集成
 
+    并发控制:
+    - 使用信号量限制最大并发数为 5,防止触发 API 速率限制
+
     Args:
         paragraphs_dict: {hash: paragraph} 字典
         model_set: 模型配置
@@ -229,16 +252,26 @@ async def extract_information(paragraphs_dict, model_set):
 
     # 🔧 关键修复:创建单个 LLM 请求实例,复用连接
     llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
+
+    # 🔧 并发控制:限制最大并发数,防止速率限制
+    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
+
+    async def extract_with_semaphore(pg_hash, paragraph):
+        """带信号量控制的提取函数"""
+        async with semaphore:
+            return await extract_info_async(pg_hash, paragraph, llm_api)
 
-    # 创建所有异步任务
+    # 创建所有异步任务(带并发控制)
     tasks = [
-        extract_info_async(p_hash, paragraph, llm_api)
+        extract_with_semaphore(p_hash, paragraph)
         for p_hash, paragraph in paragraphs_dict.items()
     ]
 
     total = len(tasks)
     completed = 0
 
+    logger.info(f"开始提取 {total} 个段落的信息(最大并发: {MAX_EXTRACTION_CONCURRENCY})")
+
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
@@ -297,9 +330,10 @@ async def import_data(openie_obj: OpenIE | None = None):
         默认为 None.
     """
     logger.info("--- 步骤 3: 开始数据导入 ---")
-    # 使用更高的并发参数以加速 embedding 生成
+    # 使用配置的并发参数以加速 embedding 生成
     # max_workers: 并发批次数,chunk_size: 每批次处理的字符串数
-    embed_manager, kg_manager = EmbeddingManager(max_workers=20, chunk_size=30), KGManager()
+    embed_manager = EmbeddingManager(max_workers=EMBEDDING_MAX_WORKERS, chunk_size=EMBEDDING_CHUNK_SIZE)
+    kg_manager = KGManager()
 
     logger.info("正在加载现有的 Embedding 库...")
     try: