shard_scaler
paddle.distributed.shard_scaler(scaler)
Wrap the global-view GradScaler into a distributed view.
Parameters
scaler (paddle.amp.GradScaler) – The GradScaler to be sharded.
Returns
A GradScaler with a distributed view.
Examples
>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> layer, opt = paddle.amp.decorate(layer, opt, level='O2')
>>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
>>> scaler = dist.shard_scaler(scaler)
>>> opt = dist.shard_optimizer(opt)

>>> for _ in range(5):
...     with paddle.amp.auto_cast(True):
...         loss = layer(batch)
...         scaled = scaler.scale(loss)
...     scaled.backward()
...     scaler.step(opt)
...     scaler.update()
...     opt.clear_grad()

>>> # This example needs to be executed in a multi-card environment, e.g.
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
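The reason the scaler needs a distributed view is that dynamic loss scaling checks gradients for inf/nan on each rank, and all ranks must reach the same decision about skipping the optimizer step and adjusting the loss scale. The sketch below only illustrates this synchronization idea; it is not Paddle's internal implementation of shard_scaler, and the helper name sync_found_inf is hypothetical, not a Paddle API.

>>> # Conceptual sketch only -- NOT the internals of dist.shard_scaler.
>>> # Run with: python -m paddle.distributed.launch --gpus=0,1 sketch.py
>>> import paddle
>>> import paddle.distributed as dist

>>> def sync_found_inf(local_found_inf):  # hypothetical helper name
...     # Combine each rank's local inf/nan flag so every rank makes
...     # the same decision about skipping the optimizer step.
...     flag = paddle.to_tensor([1.0 if local_found_inf else 0.0])
...     # MAX reduction: if any rank saw inf/nan, all ranks see 1.0.
...     dist.all_reduce(flag, op=dist.ReduceOp.MAX)
...     return bool(flag.item() > 0)

>>> dist.init_parallel_env()
>>> # Pretend only rank 0 observed an overflow in its local gradients.
>>> local_overflow = dist.get_rank() == 0
>>> skip_step = sync_found_inf(local_overflow)
>>> # Every rank gets the same answer, so the step is skipped consistently.
>>> print(f"rank {dist.get_rank()} skip_step={skip_step}")

Without this kind of cross-rank agreement, one rank could skip the step while another applies it, leaving the sharded parameters and loss scale out of sync.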