RowWiseParallel

class paddle.distributed.RowWiseParallel(is_input_parallel: bool = True)

Row-wise parallel plan for the model-parallel (mp) config. It tries to split the layer's weight along the first dimension. This API is designed for paddle.nn.Linear and paddle.nn.Embedding. If any other paddle.nn.Layer instance is passed, the plan tries to split layer.weight, if the layer has one.
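To make "split the weight along the first dimension" concrete for paddle.nn.Linear, the sketch below simulates row-wise parallelism in plain Python instead of calling Paddle. It assumes Linear's weight layout of [in_features, out_features] and an even split across ranks. The loop over ranks stands in for separate devices, and the element-wise sum of partial outputs stands in for the all-reduce a real distributed run typically performs. The helper names (matmul, split_rows, split_cols, row_parallel_linear) are illustrative, not Paddle APIs.

```python
# Plain-Python simulation of a row-wise parallel linear layer: y = x @ W + b.
# Assumptions: W has shape [in_features, out_features] (paddle.nn.Linear's layout)
# and every split dimension is divisible by num_ranks.

def matmul(a, b):
    """Multiply an [m, k] matrix by a [k, n] matrix (lists of lists)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def split_rows(matrix, num_shards):
    """Split a matrix along dim 0 into equally sized contiguous shards."""
    size = len(matrix) // num_shards
    return [matrix[r * size:(r + 1) * size] for r in range(num_shards)]

def split_cols(matrix, num_shards):
    """Split a matrix along its last dim into equally sized contiguous shards."""
    size = len(matrix[0]) // num_shards
    return [[row[r * size:(r + 1) * size] for row in matrix]
            for r in range(num_shards)]

def row_parallel_linear(x, weight, bias, num_ranks):
    w_shards = split_rows(weight, num_ranks)  # rank r holds a block of W's rows
    x_shards = split_cols(x, num_ranks)       # rank r holds the matching input features
    # Each rank computes a partial output of the full [batch, out_features] shape.
    partials = [matmul(x_shards[r], w_shards[r]) for r in range(num_ranks)]
    # Simulated all-reduce: element-wise sum of the partial outputs of all ranks.
    out = [[sum(p[i][j] for p in partials) for j in range(len(weight[0]))]
           for i in range(len(x))]
    # The bias is added once, after the reduction, not on every rank.
    return [[v + bias[j] for j, v in enumerate(row)] for row in out]

x = [[1.0, 2.0, 3.0, 4.0],
     [0.5, -1.0, 2.0, 0.0]]      # [batch=2, in_features=4]
weight = [[1.0, 0.0, 2.0],
          [0.0, 1.0, -1.0],
          [3.0, 1.0, 0.0],
          [1.0, 1.0, 1.0]]       # [in_features=4, out_features=3]
bias = [0.1, 0.2, 0.3]

# Reference: the same layer computed on a single device.
reference = [[v + bias[j] for j, v in enumerate(row)] for row in matmul(x, weight)]
assert row_parallel_linear(x, weight, bias, num_ranks=2) == reference
```

Because each rank sees only a slice of the input features, its output is a partial sum. That is why the partial results must be reduced, and why the bias is applied after the reduction rather than once per rank.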

Note

layer.weight must be two-dimensional.
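The shape requirement can be illustrated with a hypothetical helper (row_shard_shapes is not part of Paddle) that computes the local shard each rank would hold when a 2-D weight is split along dim 0. The even-split check is an assumption made for simplicity. For paddle.nn.Embedding, whose weight is [num_embeddings, embedding_dim], a row split shards the vocabulary.

```python
# Hypothetical helper (not a Paddle API): local shard shape per rank when a
# 2-D weight is split evenly along its first dimension.
def row_shard_shapes(weight_shape, num_ranks):
    if len(weight_shape) != 2:
        raise ValueError(f"expected a 2-D weight, got shape {tuple(weight_shape)}")
    rows, cols = weight_shape
    # Assumption for this sketch: dim 0 must divide evenly across ranks.
    if rows % num_ranks != 0:
        raise ValueError(f"dim 0 ({rows}) is not divisible by {num_ranks} ranks")
    return [(rows // num_ranks, cols)] * num_ranks

# paddle.nn.Linear(8, 8) stores its weight as [in_features, out_features] = [8, 8].
print(row_shard_shapes((8, 8), num_ranks=2))      # [(4, 8), (4, 8)]
# paddle.nn.Embedding(1000, 64) stores its weight as [num_embeddings, embedding_dim].
print(row_shard_shapes((1000, 64), num_ranks=4))  # [(250, 64), (250, 64), (250, 64), (250, 64)]
```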

Parameters

is_input_parallel (bool, optional) – Whether the input is a local tensor rather than a global tensor. If False, the input is treated as a global tensor and an extra split is performed on it. Default: True, meaning the input is a local tensor.
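A minimal sketch of what is_input_parallel controls, assuming the extra split for a global input is taken along the input's last dimension (the dimension a row-split Linear weight consumes). prepare_inputs and split_last_dim are hypothetical names; in Paddle the split would act on distributed tensors, not nested lists.

```python
# Hypothetical illustration (not Paddle's implementation) of is_input_parallel.

def split_last_dim(x, num_ranks):
    """Split a [batch, features] matrix into num_ranks equal slices of the last dim."""
    size = len(x[0]) // num_ranks
    return [[row[r * size:(r + 1) * size] for row in x] for r in range(num_ranks)]

def prepare_inputs(inputs, num_ranks, is_input_parallel=True):
    """Return the local input shard for every rank."""
    if is_input_parallel:
        # Local inputs: each rank already holds its own shard, so nothing is done.
        return inputs
    # Global input: every rank sees the full tensor, so an extra split is needed.
    return split_last_dim(inputs, num_ranks)

global_x = [[1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0]]
from_global = prepare_inputs(global_x, num_ranks=2, is_input_parallel=False)
already_local = prepare_inputs([[[1.0, 2.0], [5.0, 6.0]],
                                [[3.0, 4.0], [7.0, 8.0]]], num_ranks=2)

# Both paths hand each rank the same local shard.
assert from_global == already_local
print(from_global[0])  # [[1.0, 2.0], [5.0, 6.0]]
```

Either way, each rank ends up with the slice of input features that matches its rows of the weight; the flag only decides whether that slicing has already happened before the layer runs.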

Examples

>>> import paddle
>>> import paddle.distributed as dist

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> layer = MLP()
>>> # Keys are sublayer names; here only fc1 gets the row-wise plan.
>>> mp_config = {
...     'fc1': dist.RowWiseParallel()
... }