PrepareLayerOutput

class paddle.distributed.PrepareLayerOutput(fn: Callable[[ProcessMesh], Callable[[Layer, tuple[Tensor], tuple[Tensor]], [tuple[Tensor]]]] | None = None)

Prepares the output of a specific layer. The user should provide one callable function.

Parameters

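fn (callable) – A function that prepares the layer output. The function should take exactly one parameter named process_mesh and return a post hook. The post hook is called as hook(layer, input, output) and returns the output to use in place of the layer's original output.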
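The contract for fn is a hook factory. It receives the process mesh and returns a hook that can rewrite the layer's output. The following sketch shows this pattern in plain Python so that it runs without Paddle or a distributed setup. ToyMesh, ToyLayer and apply_prepare_output are hypothetical stand-ins, not Paddle APIs. The sketch makes two assumptions. First, the mesh is passed to fn by keyword, which is why the parameter must be named process_mesh. Second, it follows the forward-post-hook convention that a non-None return value replaces the layer output.

```python
class ToyMesh:
    """Hypothetical stand-in for a process mesh; only dim_names is modeled."""

    def __init__(self, dim_names):
        self.dim_names = dim_names


class ToyLayer:
    """Hypothetical stand-in for a layer that supports forward post hooks."""

    def __init__(self, forward_fn):
        self._forward_fn = forward_fn
        self._post_hooks = []

    def register_forward_post_hook(self, hook):
        self._post_hooks.append(hook)

    def __call__(self, *inputs):
        output = self._forward_fn(*inputs)
        for hook in self._post_hooks:
            # Assumed convention: a non-None hook result replaces the output.
            result = hook(self, inputs, output)
            if result is not None:
                output = result
        return output


def apply_prepare_output(layer, fn, process_mesh):
    # Hypothetical helper: builds the hook from the mesh (passed by keyword)
    # and registers it as a post hook on the layer.
    layer.register_forward_post_hook(fn(process_mesh=process_mesh))


def scale_output_hook(process_mesh):
    # The outer function sees the mesh once; the inner hook closes over it.
    factor = len(process_mesh.dim_names)

    def hook(layer, input, output):
        return output * factor

    return hook


mesh = ToyMesh(["dp", "mp"])
layer = ToyLayer(lambda x: x + 1)
apply_prepare_output(layer, scale_output_hook, mesh)
print(layer(3))  # (3 + 1) * 2 -> 8
```

With this design, the mesh-dependent setup runs once when the hook is built, and only the lightweight hook runs on every forward pass.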

Examples

>>> import paddle
>>> import paddle.distributed as dist

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> def layer_output_hook(process_mesh):
...     def hook(layer, input, output):
...         return output
...     return hook

>>> layer = MLP()
>>> mp_config = {
...     'fc1': dist.PrepareLayerOutput(layer_output_hook)
... }