PyLayer

class paddle.autograd.PyLayer

Build a custom Layer by creating a subclass of PyLayer. Subclasses must follow these rules:

1. Subclasses implement both forward and backward, and both must be decorated with @staticmethod. The first argument of each is a context object, and None cannot be included in the returned result.
2. backward takes the context as its first argument. The remaining arguments are the gradients of forward's output tensors, so the number of backward's input tensors equals the number of forward's output tensors. If backward needs forward's inputs or outputs, store the required tensors with ctx.save_for_backward in forward and retrieve them with ctx.saved_tensor() in backward.
3. backward can return only a Tensor or a tuple/list of Tensors. The returned tensors are the gradients of forward's input tensors, so the number of backward's output tensors equals the number of forward's input tensors.

After building the custom Layer, run it through the apply method.

Examples

import paddle
from paddle.autograd import PyLayer

# Inherit from PyLayer
class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x, func1, func2=paddle.square):
        # ctx is a context object that stores objects needed by backward.
        ctx.func = func2
        y = func1(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    # forward has only one output, so there is only one gradient in the input of backward.
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - ctx.func(y))
        # forward has only one input, so only one gradient tensor is returned.
        return grad


data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
z = cus_tanh.apply(data, func1=paddle.tanh)
z.mean().backward()

print(data.grad)

static forward(ctx, *args, **kwargs)

It is to be overridden by subclasses. It must accept an object of PyLayerContext as the first argument, followed by any number of arguments (tensors or other types). None cannot be included in the returned result.

Parameters
  • *args (tuple) – Positional inputs of the PyLayer.

  • **kwargs (dict) – Keyword inputs of the PyLayer.

Returns

The output of the PyLayer.

Return type

Tensor(s) or other types

Examples

import paddle
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad

static backward(ctx, *args, **kwargs)

This function calculates the gradient and is to be overridden by subclasses. It must accept an object of PyLayerContext as the first argument, and the remaining arguments are the gradients of forward's output tensors. The output tensors of backward are the gradients of forward's input tensors.

Parameters
  • *args (tuple) – The gradient of forward’s output tensor(s).

  • **kwargs (dict) – The gradient of forward’s output tensor(s).

Returns

The gradient of forward’s input tensor(s).

Return type

Tensor or tuple/list of Tensors

Examples

import paddle
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad

classmethod apply(*args, **kwargs)

After building the custom PyLayer, run it by calling apply with the inputs that forward expects (without ctx).

Parameters
  • *args (tuple) – Positional inputs of the PyLayer.

  • **kwargs (dict) – Keyword inputs of the PyLayer.

Returns

The output of the PyLayer.

Return type

Tensor(s) or other types

Examples

import paddle
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x, func1, func2=paddle.square):
        ctx.func = func2
        y = func1(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - ctx.func(y))
        return grad


data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
# run custom Layer.
z = cus_tanh.apply(data, func1=paddle.tanh)