shard_optimizer
- paddle.distributed.shard_optimizer(optimizer, shard_fn=None) [source]
Wrap the global-view optimizer into a distributed-view one.
Note
- The shard_fn should have the following signature:
def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
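As a rough illustration (not part of the official description), a custom shard_fn might simply replicate every accumulator over a process mesh with dist.shard_tensor. The names my_shard_fn and mesh below are assumptions; mesh stands for a dist.ProcessMesh such as the one built in the example further down.
>>> # Hedged sketch of a custom shard_fn: it replicates each optimizer
>>> # accumulator (e.g. the AdamW moments) across `mesh`, a dist.ProcessMesh
>>> # assumed to be defined as in the example below.
>>> def my_shard_fn(accumulator_name, param, accumulator):
...     return dist.shard_tensor(accumulator, mesh, [dist.Replicate()])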
- Parameters
optimizer (paddle.optimizer.Optimizer) – The optimizer to be sharded.
shard_fn (Callable, optional) – The function used to shard the optimizer accumulators. If not specified, the accumulators simply inherit the dist attributes of their parameters.
- Returns
An optimizer with a distributed view.
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt)
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
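If explicit control over accumulator placement is wanted, a custom shard_fn such as the my_shard_fn sketch in the Note above could be passed in. A minimal, hypothetical variant of the call, assuming my_shard_fn, layer, and mesh from above:
>>> # Hypothetical: pass a custom shard_fn instead of relying on the default
>>> # behaviour of passing down the parameters' dist attrs.
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt, shard_fn=my_shard_fn)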