分组切分并行

当模型参数达到百亿或者千亿时, 传统的数据并行训练可能会遇到显存瓶颈。 在数据并行训练中,每个 gpu worker 都有一份完整模型参数和优化器状态副本。 《ZeRO: Memory Optimizations Toward Training Trillion Parameter Models》 指出在每个 GPU 上都保存一份模型参数和优化器状态副本是冗余的。 我们可以通过将上述参数和副本划分到不同 GPU 中, 在每个 GPU 只保存部分副本,来减少每张 GPU 上显存的占用,从而可以支持更大模型的训练。

一、原理介绍

1.1 GroupSharded

GroupSharded 实现了类似 ZeRO-DP 的训练策略,将模型状态包括:模型参数(parameter)、参数梯度(gradient)、参数对应的优化器状态(以 Adam 为例 moment 和 varience)切分到每一张 GPU 上。让模型参数部分所占的显存随并行卡数的增加而减少。 通过 paddle.distributed.sharding.group_sharded_parallel 提供的简单易用接口, 用户只需要添加几行代码就可将策略加入到原有的训练中。

模型训练过程中的显存消耗主要由两大部分组成:模型参数及优化器状态、训练产生的中间变量(activations)。 GroupSharded 策略可以根据用户配置支持,分别切分模型参数、对应参数梯度和优化器状态,因此模型状态所消耗的显存可以随着并行 GPU 数量增加而线性减少; 但是每张 GPU 上仍然维护着模型完整的前向和反向,所以每张 GPU 依然需要存放模型的训练过程中的产生的全部的中间变量,这部分显存消耗 不会随着 GPU 数量的增加而减少。 用户可以通过结合 recompute 策略来减少 activation 这部分的显存消耗。

通过 GroupSharded 和增加并行 GPU 数量,用户可以在 A100-40G 设备下 8 卡训练 16.25B 参量的模型 (需要结合 recompute, amp 策略)。

1.2 GroupSharded-hybrid-dp

GroupSharded hybrid 数据并行策略,在 GroupSharded 并行的基础上再增加一层数据并行逻辑。 该策略的目的是通过 限制 GroupSharded 通信的节点数增加多路数据并行 来提高训练吞吐。 如果一个模型在普通 GroupSharded 训练时需要 M 张 GPU,则则开启 hybrid-dp 至少需要 N*M GPU (N>= 2)。

GroupSharded-hybrid-dp 适用的场景如下:

  • 当前有 4 个 8 卡 A100 节点

  • 目标模型 A 在 GroupSharded 训练时至少需要 8 卡 A100 (一个完整的 8 卡 A100 节点)

  • 希望利用全部的 4 个节点来加速训练

上述情况如果直接使用全部的 4 个节点 进行普通的 GroupSharded 训练, 那么全部的 32 gpus 之间组成一个完整 GroupSharded parallelism。这样会因为通信瓶颈造成训练速度非常慢:

  • GroupSharded 中的 allgather 通信 会涉及全部的 32 张卡,且为跨节点通信。

  • GroupSharded 中的 allreduce 通信 会涉及全部的 32 张卡,且为跨节点通信。

开启 hybrid-dp 并设置 GroupSharded_group_size = 8 后, 每个节点内的 8 张卡组成一个完整的 GroupSharded parallelism,4 个节点构成 4 路 hybrid data parallelism:

  • GroupSharded 中的 allgather 通信被限制在每个节点内的 8 张 GPU 之间, 没有跨节点通信。

  • GroupSharded 中的 allreduce 可以先进行机内通信,再跨节点通信,且每张 GPU 每次仅需要 allreduce 通信 1/8 的模型参数。

GroupSharded-hybrid-dp 通过上述措施,可以较大程度 减少 GroupSharded 训练 从 1 节点扩展到 4 节点时的(跨节点)通信量。提高节点增加时的加速比,提高训练吞吐。

P.S. hybrid dp 是因为 GroupSharded parallelism 本身内含一层 data parallelism 逻辑, hybrid dp 是在 GroupSharded parallelism 之上再增加新的一层 data parallelism 逻辑。

二、功能效果

下面表格将对比 GroupSharded 策略对显存的影响。

模型为 GPT(11.375B),试验环境为 A100 (40GB), recompute = ON, amp(O2) = ON, hybrid-dp = OFF。 模型不变,单卡 batch size 不变,当并行 GPU 数量增加时,显存的消耗将减小。 省下的显存可以用来增大模型。

setting

GPU Mem

Speed

GroupSharded—off

30474 MB

9344 tokens/s

GroupSharded—on N2C16

30160 MB

12663 tokens/s

GroupSharded—on N4C32

30190 MB

22235 tokens/s

GroupSharded 结合 amp (O2) + recompute,可以在 8 张 40GB A100 并行的情况下支持百亿参数(16.25B)GPT 训练。

三、使用方法

首先简单总结 GroupSharded stage1、stage2、stage3 分别实现减小参数规模的原理。stage1、stage2、stage3 分别在训练过程中对模型优化器状态、梯度+优化器状态、参数+梯度+优化器状态进行切分,通过减小训练的相关 Tensor(参数、梯度、优化器状态)达到同样计算资源下能够训练更大模型的效果。

以下是分别从 GroupSharded 的三种实现阶段的实现方式:
  • 使用 group_sharded_parallel 和 save_group_sharded_model 两个 API 可以进行训练和保存。使用 group_sharded_parallel 提供 stage1 的选项,内部使用 stage2 完成优化实现。参考 group_sharded_parallelsave_group_sharded_model

  • 此处需要注意,使用 save_group_sharded_model 保存模型,再次 load 时需要在调用 group_sharded_parallel 前对 model 和 optimizer 进行 set_state_dict。

  • 目前 stage2、3 已经适配 GPT 模型,可以参考请参考 示例代码

  • 其次解决组网中共享参数训练问题,stage3 需要额外在组网中加入外置参数注册逻辑,在组网中需要注册 self.extra_parameters = [self.gpt.embeddings.word_embeddings.weight],这部分可以参考 PaddleNLP 中 GPT-3 的组网。示例代码

import paddle
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BasicBlock
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
use_pure_fp16 = True

clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
model = ResNet(BasicBlock, 18)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)

scaler = None
if use_pure_fp16:
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
    # level O2 means converting the network to FP16
    model = paddle.amp.decorate(
        models=model,
        level='O2',
        save_dtype='float32')

# wrap GroupSharded model, optimizer and scaler
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "os_g", scaler=scaler)

for step_id in range(1, 100):
    x = paddle.rand([1, 3, 224, 224])
    with paddle.amp.auto_cast(use_pure_fp16):
        out = model(x)
    loss = out.mean()

    if use_pure_fp16:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    optimizer.clear_grad()

    print("=== step_id : {}    loss : {}".format(step_id, loss.numpy()))

# save model and optimizer state_dict
save_group_sharded_model(model, output_dir, optimizer)

运行方式(需要保证当前机器有两张 GPU):

export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch run_pretrain.py # run_pretrain.py 是用户运行动态图 GroupSharded 的 python 文件

控制台输出信息如下:

launch train in GPU mode!
INFO 2022-05-18 09:34:51,803 launch_utils.py:561] Local start 2 processes. First process distributed environment info (Only For Debug):
  +=======================================================================================+
  |                        Distributed Envs                      Value                    |
  +---------------------------------------------------------------------------------------+
  |                       PADDLE_TRAINER_ID                        0                      |
  |                 PADDLE_CURRENT_ENDPOINT                 127.0.0.1:12532               |
  |                     PADDLE_TRAINERS_NUM                        2                      |
  |                PADDLE_TRAINER_ENDPOINTS         127.0.0.1:12532,127.0.0.1:58759       |
  |                     PADDLE_RANK_IN_NODE                        0                      |
  |                 PADDLE_LOCAL_DEVICE_IDS                        6                      |
  |                 PADDLE_WORLD_DEVICE_IDS                       6,7                     |
  |                     FLAGS_selected_gpus                        6                      |
  |             FLAGS_selected_accelerators                        6                      |
  +=======================================================================================+

日志信息位于 log 目录下:

[2022-05-18 09:35:15,062] [    INFO] - global step 1, epoch: 0, batch: 0, loss: 11.059432030, avg_reader_cost: 0.15902 sec, avg_batch_cost: 0.61838 sec, speed: 1.62 step/s, ips: 13247 tokens/s, learning rate: 9.37500e-08
[2022-05-18 09:35:15,274] [    INFO] - global step 2, epoch: 0, batch: 1, loss: 11.050725937, avg_reader_cost: 0.00041 sec, avg_batch_cost: 0.21061 sec, speed: 4.75 step/s, ips: 38897 tokens/s, learning rate: 1.40625e-07
[2022-05-18 09:35:15,432] [    INFO] - global step 3, epoch: 0, batch: 2, loss: 11.051848412, avg_reader_cost: 0.00022 sec, avg_batch_cost: 0.15722 sec, speed: 6.36 step/s, ips: 52105 tokens/s, learning rate: 1.87500e-07
[2022-05-18 09:35:15,566] [    INFO] - global step 4, epoch: 0, batch: 3, loss: 11.052285194, avg_reader_cost: 0.00022 sec, avg_batch_cost: 0.13303 sec, speed: 7.52 step/s, ips: 61579 tokens/s, learning rate: 2.34375e-07
[2022-05-18 09:35:15,722] [    INFO] - global step 5, epoch: 0, batch: 4, loss: 11.028432846, avg_reader_cost: 0.00036 sec, avg_batch_cost: 0.15526 sec, speed: 6.44 step/s, ips: 52764 tokens/s, learning rate: 2.81250e-07
[2022-05-18 09:35:15,880] [    INFO] - global step 6, epoch: 0, batch: 5, loss: 11.032807350, avg_reader_cost: 0.00021 sec, avg_batch_cost: 0.15763 sec, speed: 6.34 step/s, ips: 51971 tokens/s, learning rate: 3.28125e-07