shard_scaler

paddle.distributed.shard_scaler(scaler) [source]

Wrap the global-view GradScaler into a distributed view.

Parameters

scaler (paddle.amp.GradScaler) – The GradScaler to be sharded.

Returns

A GradScaler with a distributed view.
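
In practice this is a one-line wrap of an existing GradScaler; a minimal sketch of just the wrapping step (the full AMP training loop is shown in Examples below):

>>> import paddle
>>> import paddle.distributed as dist
>>> # Build an ordinary (global-view) GradScaler, then wrap it to obtain
>>> # its distributed view.
>>> scaler = dist.shard_scaler(paddle.amp.GradScaler(init_loss_scaling=1024))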

Examples

>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> layer, opt = paddle.amp.decorate(layer, opt, level='O2')
>>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
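>>> # Convert the scaler and the optimizer to their distributed views.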
>>> scaler = dist.shard_scaler(scaler)
>>> opt = dist.shard_optimizer(opt)
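>>> # Standard AMP training loop: scale, backward, step, update, clear gradients.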
>>> for _ in range(5):
...     with paddle.amp.auto_cast(True):
...         loss = layer(batch)
...     scaled = scaler.scale(loss)
...     scaled.backward()
...     scaler.step(opt)
...     scaler.update()
...     opt.clear_grad()
>>> # This example needs to be run in a multi-card environment, e.g.:
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py