feat(concurrency): add concurrency control and performance tuning options for information extraction and data import
@@ -37,6 +37,26 @@ RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
+
+# ========== Performance configuration ==========
+#
+# Knowledge extraction (step 2: txt -> json) concurrency control
+# - Limits the number of concurrent LLM extraction requests
+# - Recommended: 3-10, depending on API rate limits
+# - Setting this too high may trigger 429 (rate limit) errors
+MAX_EXTRACTION_CONCURRENCY = 5
+
+# Data import (step 3: embedding generation) performance settings
+# - max_workers: number of concurrent batches (each batch is processed in parallel)
+# - chunk_size: number of strings per batch
+# - Theoretical concurrency = max_workers × chunk_size
+# - Recommended settings:
+#   * High-throughput APIs (OpenAI): max_workers=20-30, chunk_size=30-50
+#   * Mid-tier APIs: max_workers=10-15, chunk_size=20-30
+#   * Local/slow APIs: max_workers=5-10, chunk_size=10-20
+EMBEDDING_MAX_WORKERS = 20  # number of concurrent batches
+EMBEDDING_CHUNK_SIZE = 30  # strings per batch
+# ===================================

 # --- Cache cleanup ---
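The two knobs multiply: with EMBEDDING_MAX_WORKERS batches in flight and EMBEDDING_CHUNK_SIZE strings per batch, up to 20 × 30 = 600 strings can be embedded at once under the defaults. Below is a minimal sketch of that batching model; EmbeddingManager's internals are not part of this diff, so `embed_batch` and `embed_all` are hypothetical stand-ins:

```python
# Hypothetical sketch of the max_workers × chunk_size model; the real
# EmbeddingManager internals are not shown in this commit.
from concurrent.futures import ThreadPoolExecutor

EMBEDDING_MAX_WORKERS = 20
EMBEDDING_CHUNK_SIZE = 30

def embed_batch(chunk):
    # Placeholder for one batched embedding API call covering `chunk` strings.
    return [f"vector({s})" for s in chunk]

def embed_all(strings):
    # Split the inputs into chunks of EMBEDDING_CHUNK_SIZE strings each...
    chunks = [strings[i:i + EMBEDDING_CHUNK_SIZE]
              for i in range(0, len(strings), EMBEDDING_CHUNK_SIZE)]
    # ...and process up to EMBEDDING_MAX_WORKERS chunks in parallel, so at most
    # EMBEDDING_MAX_WORKERS * EMBEDDING_CHUNK_SIZE (20 * 30 = 600) strings are
    # in flight at once.
    with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as pool:
        results = pool.map(embed_batch, chunks)
    return [v for batch in results for v in batch]
```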
@@ -217,6 +237,9 @@ async def extract_information(paragraphs_dict, model_set):
     2. Makes more efficient use of I/O resources
     3. Integrates seamlessly with our optimized LLM request layer
+
+    Concurrency control:
+    - A semaphore caps the maximum concurrency at 5 to avoid triggering API rate limits

     Args:
         paragraphs_dict: {hash: paragraph} dict
         model_set: model configuration
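The semaphore pattern this docstring describes, distilled into a self-contained sketch; names such as `fake_llm_call` are illustrative, not from the codebase:

```python
# Self-contained illustration of the semaphore pattern described above;
# fake_llm_call is a stand-in for the real LLM extraction request.
import asyncio

MAX_EXTRACTION_CONCURRENCY = 5

async def fake_llm_call(i: int) -> int:
    await asyncio.sleep(0.1)  # simulate LLM/network latency
    return i

async def main():
    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)

    async def limited(i: int) -> int:
        # At most MAX_EXTRACTION_CONCURRENCY coroutines proceed past this
        # point at once; the rest queue here instead of flooding the API.
        async with semaphore:
            return await fake_llm_call(i)

    results = await asyncio.gather(*(limited(i) for i in range(20)))
    print(results)

asyncio.run(main())
```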
@@ -229,16 +252,26 @@ async def extract_information(paragraphs_dict, model_set):
     # 🔧 Key fix: create a single LLM request instance and reuse the connection
     llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")

+    # 🔧 Concurrency control: cap the maximum concurrency to avoid rate limiting
+    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
+
+    async def extract_with_semaphore(pg_hash, paragraph):
+        """Extraction wrapper gated by the semaphore."""
+        async with semaphore:
+            return await extract_info_async(pg_hash, paragraph, llm_api)
+
-    # Create all async tasks
+    # Create all async tasks (with concurrency control)
     tasks = [
-        extract_info_async(p_hash, paragraph, llm_api)
+        extract_with_semaphore(p_hash, paragraph)
         for p_hash, paragraph in paragraphs_dict.items()
     ]

     total = len(tasks)
     completed = 0

+    logger.info(f"Starting extraction of {total} paragraphs (max concurrency: {MAX_EXTRACTION_CONCURRENCY})")

     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
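The hunk ends before the tasks are awaited. One common way to consume such tasks while advancing a rich progress bar is asyncio.as_completed; this is a sketch under that assumption, not the loop actually used in this file:

```python
# Hedged sketch: drive the tasks with asyncio.as_completed and tick a rich
# progress bar as each extraction finishes. The real loop may differ.
import asyncio
from rich.progress import Progress, SpinnerColumn, TextColumn

async def run_with_progress(coros):
    results = []
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    ) as progress:
        bar = progress.add_task("Extracting...", total=len(coros))
        for fut in asyncio.as_completed(coros):
            results.append(await fut)
            progress.advance(bar)  # one tick per completed paragraph
    return results

async def main():
    async def work(i):
        await asyncio.sleep(0.05)  # stands in for one LLM extraction call
        return i
    print(await run_with_progress([work(i) for i in range(10)]))

asyncio.run(main())
```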
@@ -297,9 +330,10 @@ async def import_data(openie_obj: OpenIE | None = None):
             Defaults to None.
     """
     logger.info("--- Step 3: Starting data import ---")
-    # Use higher concurrency parameters to speed up embedding generation
+    # Use the configured concurrency parameters to speed up embedding generation
+    # max_workers: number of concurrent batches; chunk_size: strings processed per batch
-    embed_manager, kg_manager = EmbeddingManager(max_workers=20, chunk_size=30), KGManager()
+    embed_manager = EmbeddingManager(max_workers=EMBEDDING_MAX_WORKERS, chunk_size=EMBEDDING_CHUNK_SIZE)
+    kg_manager = KGManager()

     logger.info("Loading the existing Embedding library...")
     try: