shard_layer
paddle.distributed.shard_layer(layer: paddle.nn.layer.layers.Layer, process_mesh: paddle.distributed.auto_parallel.process_mesh.ProcessMesh, shard_fn: Optional[Callable] = None, input_fn: Optional[Callable] = None, output_fn: Optional[Callable] = None) -> paddle.nn.layer.layers.Layer
Converts all of the layer's parameters to DistTensor parameters according to the specified shard_fn. It can also control how the layer's input and output are converted, through input_fn and output_fn (for example, converting the input to a paddle.Tensor with DistTensor, or converting the output back to a paddle.Tensor with DenseTensor).
The shard_fn should have the following signature:
def shard_fn(layer_name, layer, process_mesh) -> None
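A minimal, framework-free sketch of how a shard_fn is typically written: it dispatches on layer_name and replaces only the parameters it wants to shard. It runs without Paddle or multiple devices, so `FakeParam`, `FakeLinear`, `FakeMLP` and `apply_shard_fn` are hypothetical stand-ins, and the string "Shard(0)" only mimics `dist.Shard(0)`. The visiting order, in particular that the root layer is passed first under the empty name "", is an assumption modeled on `Layer.named_sublayers(include_self=True)`, not something stated on this page.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple


@dataclass
class FakeParam:
    # Hypothetical parameter stand-in; placements=None means "still dense".
    shape: Tuple[int, ...]
    placements: Optional[List[str]] = None


class FakeLinear:
    # Hypothetical stand-in for paddle.nn.Linear: just a weight and a bias.
    def __init__(self, in_features: int, out_features: int) -> None:
        self.weight = FakeParam((in_features, out_features))
        self.bias = FakeParam((out_features,))


class FakeMLP:
    def __init__(self) -> None:
        self.fc1 = FakeLinear(8, 8)
        self.fc2 = FakeLinear(8, 8)

    def named_sublayers(self) -> Iterator[Tuple[str, object]]:
        # Assumption: the root layer comes first, under the empty name "".
        yield "", self
        yield "fc1", self.fc1
        yield "fc2", self.fc2


def shard_fn(layer_name: str, layer: object, process_mesh: List[int]) -> None:
    # Shard only fc1's weight along dim 0; every other parameter is untouched.
    if layer_name == "fc1":
        layer.weight = FakeParam(layer.weight.shape, ["Shard(0)"])


def apply_shard_fn(root: FakeMLP, process_mesh: List[int], fn) -> List[str]:
    # Hypothetical driver: call fn once per (name, sublayer) pair, in order.
    visited = []
    for name, sub in root.named_sublayers():
        fn(name, sub, process_mesh)
        visited.append(name)
    return visited


model = FakeMLP()
visited = apply_shard_fn(model, [0, 1], shard_fn)
print(visited)                      # ['', 'fc1', 'fc2']
print(model.fc1.weight.placements)  # ['Shard(0)']
print(model.fc2.weight.placements)  # None
```

Because shard_fn returns None, it has to mutate the layer it receives in place; matching on layer_name is what keeps the sharding decision local to the sublayers you name.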
The input_fn should have the following signature:
def input_fn(inputs, process_mesh) -> list(paddle.Tensor)
In general, input_fn returns paddle.Tensor objects with DistTensor.
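The forward pre-hook wiring for input_fn can be modeled without Paddle. In this sketch `FakeTensor` (with `placements=None` standing in for a DenseTensor) and `MiniLayer` are hypothetical. `MiniLayer.__call__` assumes the pre-hook convention that the hook receives `(layer, inputs)` with `inputs` as a tuple, and that a non-tuple return value is wrapped into a one-element tuple before being passed to forward.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class FakeTensor:
    # Hypothetical tensor stand-in; placements=None models a DenseTensor.
    value: float
    placements: Optional[List[str]] = None


class MiniLayer:
    """Toy layer mimicking forward pre-hook dispatch (assumed semantics)."""

    def __init__(self) -> None:
        self._pre_hooks: List[Callable] = []

    def register_forward_pre_hook(self, hook: Callable) -> None:
        self._pre_hooks.append(hook)

    def forward(self, x: FakeTensor) -> FakeTensor:
        # Identity forward, so the test can observe what the hook produced.
        return x

    def __call__(self, *inputs: Any) -> Any:
        for hook in self._pre_hooks:
            result = hook(self, inputs)
            if result is not None:
                # Assumed convention: a non-tuple result is wrapped in a tuple.
                inputs = result if isinstance(result, tuple) else (result,)
        return self.forward(*inputs)


def input_fn(inputs: Tuple[FakeTensor, ...], process_mesh: List[int]) -> FakeTensor:
    # Mark the single input as sharded along dim 0 on the mesh.
    x = inputs[0]
    return FakeTensor(x.value, ["Shard(0)"])


mesh = [0, 1]
layer = MiniLayer()
# Adapt input_fn(inputs, process_mesh) to the pre-hook shape (layer, inputs).
layer.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, mesh))

out = layer(FakeTensor(1.0))
print(out.placements)  # ['Shard(0)']
```

If that convention holds, a single-input layer can return one tensor from input_fn, while a layer with several inputs should return them together as a tuple.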
The output_fn should have the following signature:
def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
In general, output_fn returns paddle.Tensor objects with DenseTensor.
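A similar sketch of the forward post-hook path for output_fn, again using hypothetical `FakeTensor` and `MiniLayer` stand-ins rather than Paddle. It assumes a post-hook is called as `hook(layer, inputs, outputs)` and that a non-None return value replaces the layer's outputs; here output_fn "converts" a sharded result back to dense by clearing its placements.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class FakeTensor:
    value: float
    placements: Optional[List[str]] = None  # None models a DenseTensor


class MiniLayer:
    """Toy layer mimicking forward post-hook dispatch (assumed semantics)."""

    def __init__(self) -> None:
        self._post_hooks: List[Callable] = []

    def register_forward_post_hook(self, hook: Callable) -> None:
        self._post_hooks.append(hook)

    def forward(self, x: FakeTensor) -> FakeTensor:
        # Pretend the computation produced a sharded (DistTensor-like) output.
        return FakeTensor(x.value * 2, ["Shard(0)"])

    def __call__(self, *inputs: Any) -> Any:
        outputs = self.forward(*inputs)
        for hook in self._post_hooks:
            result = hook(self, inputs, outputs)
            if result is not None:
                # Assumed convention: a non-None result replaces the outputs.
                outputs = result
        return outputs


def output_fn(outputs: FakeTensor, process_mesh: List[int]) -> FakeTensor:
    # Convert the output back to a dense tensor by dropping its placements.
    return FakeTensor(outputs.value, None)


mesh = [0, 1]
layer = MiniLayer()
# Adapt output_fn(outputs, process_mesh) to the post-hook shape (layer, inputs, outputs).
layer.register_forward_post_hook(lambda _, inputs, outputs: output_fn(outputs, mesh))

out = layer(FakeTensor(3.0))
print(out.value, out.placements)  # 6.0 None
```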
- Parameters
layer (paddle.nn.Layer) – The Layer object to be sharded.
process_mesh (paddle.distributed.ProcessMesh) – The ProcessMesh on which to place the input layer.
shard_fn (Callable) – The function used to shard the layer's parameters across the process_mesh. If not specified, all parameters of the layer are replicated across the process_mesh by default.
input_fn (Callable) – Specifies how the input of the layer is sharded. input_fn is registered on the Layer as a forward pre-hook. By default, the input is not sharded.
output_fn (Callable) – Specifies how the output of the layer is sharded, or converts it back to paddle.Tensor with DenseTensor. output_fn is registered on the Layer as a forward post-hook. By default, the output is neither sharded nor converted.
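The default behavior of shard_fn can be summarized in a small model. `toy_shard_layer`, the dict-based layers, and the string placements are hypothetical and are not Paddle's implementation. Because the return value is documented to contain only DistTensor parameters, the sketch assumes that any parameter a custom shard_fn leaves dense is replicated as well, not only when shard_fn is omitted.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical model: a layer is a dict of parameter name -> placements,
# where None means the parameter is still dense.
Layer = Dict[str, Optional[List[str]]]


def toy_shard_layer(
    sublayers: List[Tuple[str, Layer]],
    process_mesh: List[int],
    shard_fn: Optional[Callable[[str, Layer, List[int]], None]] = None,
) -> List[Tuple[str, Layer]]:
    for name, params in sublayers:
        if shard_fn is not None:
            shard_fn(name, params, process_mesh)
        # Assumption: anything still dense is replicated, so every parameter
        # ends up distributed, matching the documented return value.
        for key, placements in list(params.items()):
            if placements is None:
                params[key] = ["Replicate()"]
    return sublayers


def shard_fc1(name: str, params: Layer, mesh: List[int]) -> None:
    # Custom rule: shard only fc1's weight along dim 0.
    if name == "fc1":
        params["weight"] = ["Shard(0)"]


default = toy_shard_layer([("fc1", {"weight": None}), ("fc2", {"weight": None})], [0, 1])
custom = toy_shard_layer([("fc1", {"weight": None}), ("fc2", {"weight": None})], [0, 1], shard_fc1)
print(default)  # [('fc1', {'weight': ['Replicate()']}), ('fc2', {'weight': ['Replicate()']})]
print(custom)   # [('fc1', {'weight': ['Shard(0)']}), ('fc2', {'weight': ['Replicate()']})]
```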
- Returns
A layer that contains parameters/buffers that are all paddle.Tensor with DistTensor.
- Return type
Layer
Examples
>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> def shard_fn(layer_name, layer, process_mesh):
...     # Shard fc1's weight along dim 0; other parameters are left to the default
...     if layer_name == 'fc1':
...         layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
>>> print(layer)

>>> # This case needs to be executed in a multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py