ShardingStage2

class paddle.distributed.ShardingStage2(mesh=None)

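A built-in shard_fn for the shard_optimizer interface. Pass it to shard_optimizer to apply stage-2 sharding optimization, which (in ZeRO terminology) partitions both the optimizer states and the gradients across the participating ranks.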
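To make "stage 2" concrete, here is a single-process Python simulation of ZeRO-style stage-2 sharding. It is a conceptual sketch, not Paddle's implementation. The helper names (`partition`, `Stage2Rank`, `train_step`), the contiguous partitioning scheme, and the use of SGD with momentum in place of AdamW are all assumptions made for illustration. Each simulated rank keeps optimizer state and reduced gradients only for the parameters it owns. Plain list operations stand in for the reduce-scatter and all-gather collectives.

```python
from typing import Dict, List

# Conceptual, single-process simulation of ZeRO-style stage-2 sharding.
# All names here are hypothetical; this is NOT Paddle's implementation.


def partition(num_params: int, world_size: int) -> List[range]:
    """Split parameter indices into contiguous, near-equal shards (one per rank)."""
    base, extra = divmod(num_params, world_size)
    shards, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        shards.append(range(start, start + size))
        start += size
    return shards


class Stage2Rank:
    """What one simulated rank stores under stage 2."""

    def __init__(self, shard: range, momentum: float = 0.9):
        self.shard = shard
        self.momentum = momentum
        # Optimizer state (a momentum buffer standing in for AdamW moments)
        # is allocated only for the parameters this rank owns.
        self.velocity: Dict[int, float] = {i: 0.0 for i in shard}
        # Stage 2 additionally keeps reduced gradients only for the owned shard.
        self.grad: Dict[int, float] = {}

    def receive_reduced_grads(self, averaged: List[float]) -> None:
        self.grad = {i: averaged[i] for i in self.shard}

    def update(self, params: List[float], lr: float) -> Dict[int, float]:
        updated = {}
        for i in self.shard:
            self.velocity[i] = self.momentum * self.velocity[i] + self.grad[i]
            updated[i] = params[i] - lr * self.velocity[i]
        return updated


def train_step(params: List[float], per_rank_grads: List[List[float]],
               ranks: List[Stage2Rank], lr: float = 0.1) -> List[float]:
    world_size = len(ranks)
    # Mimic reduce-scatter: average the local gradients across ranks, then
    # hand each rank only the slice it owns.
    averaged = [sum(g[i] for g in per_rank_grads) / world_size
                for i in range(len(params))]
    for rank in ranks:
        rank.receive_reduced_grads(averaged)
    # Each rank updates only its own parameter shard.
    updated: Dict[int, float] = {}
    for rank in ranks:
        updated.update(rank.update(params, lr))
    # Mimic all-gather: reassemble the full parameter list for every rank.
    return [updated[i] for i in range(len(params))]


# Usage: 5 parameters sharded over 2 simulated ranks.
params = [1.0, 2.0, 3.0, 4.0, 5.0]
ranks = [Stage2Rank(shard) for shard in partition(len(params), 2)]
local_grads = [[1.0] * 5, [3.0] * 5]  # each rank's local gradients
params = train_step(params, local_grads, ranks)
print([len(r.velocity) for r in ranks])  # [3, 2]
print(params)
```

The per-rank counts show the point of sharding. No single rank holds optimizer state or reduced gradients for all five parameters, yet after the simulated all-gather every rank sees the same updated parameters.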

Parameters

mesh (None|paddle.distributed.ProcessMesh) – If mesh is not None, this ProcessMesh object describes the Cartesian topology of the processes used for dense parameters. Default: None. Note: currently, only a single mesh configuration is supported for all dense parameters. If you need multiple mesh configurations, configure them yourself in the upper-level network code.
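"Cartesian topology" means the ranks are laid out on an N-dimensional grid whose axes are named by `dim_names`. The pure-Python sketch below illustrates this with a hypothetical `groups_along` helper (not a `ProcessMesh` method; the mesh is modeled as a nested list of rank ids). It lists the groups of ranks that differ only along one named axis. For the one-dimensional mesh used in the Examples section (`[0, 1]` with `dim_names=["x"]`), ranks 0 and 1 form a single group. The axis names `"dp"` and `"mp"` in the 2x2 case are illustrative only.

```python
import itertools
from typing import List, Sequence

# Hypothetical helpers for illustration only; they are not part of
# paddle.distributed. A mesh is given as a nested list of rank ids.


def mesh_shape(mesh) -> List[int]:
    """Return the shape of a (regular) nested-list mesh."""
    shape = []
    while isinstance(mesh, list):
        shape.append(len(mesh))
        mesh = mesh[0] if mesh else None
    return shape


def rank_at(mesh, index: Sequence[int]) -> int:
    """Look up the rank stored at a multi-dimensional index."""
    for i in index:
        mesh = mesh[i]
    return mesh


def groups_along(mesh, dim_names: List[str], dim: str) -> List[List[int]]:
    """Group ranks that differ only in their coordinate along `dim`."""
    axis = dim_names.index(dim)
    shape = mesh_shape(mesh)
    other_axes = [range(n) for k, n in enumerate(shape) if k != axis]
    groups = []
    # Fix every coordinate except `axis`, then sweep along `axis`.
    for fixed in itertools.product(*other_axes):
        group = []
        for j in range(shape[axis]):
            index = list(fixed)
            index.insert(axis, j)
            group.append(rank_at(mesh, index))
        groups.append(group)
    return groups


# The 1-D mesh from the example: both ranks form one group along "x".
print(groups_along([0, 1], ["x"], "x"))  # [[0, 1]]
# A 2x2 mesh: groups along each named axis.
mesh_2d = [[0, 1], [2, 3]]
print(groups_along(mesh_2d, ["dp", "mp"], "dp"))  # [[0, 2], [1, 3]]
print(groups_along(mesh_2d, ["dp", "mp"], "mp"))  # [[0, 1], [2, 3]]
```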

Examples

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> # Wrap AdamW with the stage-2 shard_fn defined on the 2-rank mesh.
>>> opt = dist.shard_optimizer(opt, dist.ShardingStage2(mesh))
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This example needs to be executed in a multi-card environment, e.g.:
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py