ShardingStage2
class paddle.distributed.ShardingStage2(mesh=None) [source]

A built-in shard_fn for the shard_optimizer interface. Users can pass it to shard_optimizer to apply stage 2 sharding optimization.
Parameters

mesh (None|paddle.distributed.ProcessMesh) – If mesh is not None, the ProcessMesh object describes the Cartesian topology of the processes used for dense-type parameters. Note: currently, only one mesh configuration is supported for all dense parameters; if multiple mesh configurations are needed, configure them yourself in the upper-layer networking code (see the sketch after this parameter list).
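The note above defers multi-mesh setups to the model code. Below is a minimal sketch of what that might look like, assuming parameters can be placed on their own ProcessMesh with dist.shard_tensor and intermediate activations moved between meshes with dist.reshard. The TwoMeshMLP class, the four-card topology, and leaving mesh unset in ShardingStage2 are illustrative assumptions, not part of this API's contract.

>>> import paddle
>>> import paddle.distributed as dist
>>> # Hypothetical four-card topology: one mesh per pair of devices.
>>> mesh0 = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> mesh1 = dist.ProcessMesh([2, 3], dim_names=["x"])
>>> class TwoMeshMLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         # Place each weight on its own mesh directly in the networking
...         # code instead of through the mesh argument of ShardingStage2.
...         self.w0 = dist.shard_tensor(
...             self.create_parameter(shape=[8, 8]), mesh0, [dist.Replicate()])
...         self.w1 = dist.shard_tensor(
...             self.create_parameter(shape=[8, 8]), mesh1, [dist.Replicate()])
...
...     def forward(self, x):
...         y = paddle.matmul(x, self.w0)
...         # Move the intermediate activation onto mesh1 before the second matmul.
...         y = dist.reshard(y, mesh1, [dist.Replicate()])
...         return paddle.matmul(y, self.w1)
>>>
>>> layer = TwoMeshMLP()
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> # mesh is left unset here because every parameter already carries its own
>>> # ProcessMesh (an assumption about how the default is resolved).
>>> opt = dist.shard_optimizer(opt, dist.ShardingStage2())
>>> # Needs a multi-card environment, e.g.
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3 {your_script}.py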
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>>
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt, dist.ShardingStage2(mesh))
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py