PyLayerContext¶
class paddle.autograd.PyLayerContext

PyLayerContext helps a PyLayer pass information (such as saved tensors) from forward to backward.
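The context object is created by the framework and passed automatically as the first argument (ctx) of both forward and backward of a PyLayer. The following minimal sketch is not part of the original reference; the Square layer is a hypothetical example illustrating the typical pattern of saving tensors in forward and retrieving them in backward.

import paddle
from paddle.autograd import PyLayer

class Square(PyLayer):  # hypothetical example layer, not from the official docs
    @staticmethod
    def forward(ctx, x):
        # ctx is a PyLayerContext instance supplied by the framework.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensor()
        return 2 * x * dy  # d(x*x)/dx = 2x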
save_for_backward(*tensors)¶

Saves the given tensors that backward needs. Use saved_tensor in backward to get the saved tensors.

Note
    This API should be called at most once, and only inside forward.

Parameters
    tensors (list of Tensors) – Tensors to be stored.

Returns
    None
Examples
import paddle
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # ctx is a context object that stores some objects for backward.
        y = paddle.tanh(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad
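As a usage sketch (not part of the original example), the layer above can be applied like any other PyLayer; backward is then invoked with the tensor saved via save_for_backward:

# Hypothetical usage of the cus_tanh layer defined above.
data = paddle.randn([2, 3], dtype="float32")
data.stop_gradient = False
out = cus_tanh.apply(data)
out.sum().backward()  # triggers cus_tanh.backward, which reads ctx.saved_tensor()
print(data.grad)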
saved_tensor()¶

Gets the tensors stored by save_for_backward.

Returns
    If the context contains tensors stored by save_for_backward, these tensors are returned; otherwise, None is returned.

Return type
    list of Tensors or None
Examples
import paddle
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # ctx is a context object that stores some objects for backward.
        y = paddle.tanh(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad
mark_not_inplace(*args)¶

Marks inputs as not inplace. This should be called at most once, only from inside the forward method, and all arguments should be Tensor inputs.
If a Tensor returned by the forward method is the same object as a Tensor input of forward, and that Tensor is marked as not-inplace, Paddle creates a new Tensor as the output. This prevents the autograd information of the input Tensor from being overwritten.
Examples
import paddle

class Exp(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Because forward returns its input unchanged, marking x as not-inplace
        # makes Paddle create a new output Tensor, so the autograd information
        # of x is not overwritten.
        ctx.mark_not_inplace(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        out = grad_output.exp()
        return out

x = paddle.randn((1, 1))
x.stop_gradient = False
attn_layers = []
for idx in range(0, 2):
    attn_layers.append(Exp())

for step in range(0, 2):
    a = x
    for j in range(0, 2):
        a = attn_layers[j].apply(x)
    a.backward()
mark_non_differentiable(*args)¶

Marks outputs as non-differentiable. This should be called at most once, only from inside the forward method, and all arguments should be tensor outputs.
This marks the outputs as not requiring gradients, which increases the efficiency of the backward computation. You still need to accept a gradient for each output in backward, but it will always be a zero tensor with the same shape as the corresponding output.
Examples
import paddle
from paddle.autograd import PyLayer
import numpy as np

class Tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        a = x + x
        b = x + x + x
        ctx.mark_non_differentiable(a)
        return a, b

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # The non-differentiable output receives a zero gradient.
        assert np.equal(grad_a.numpy(), paddle.zeros([1]).numpy())
        assert np.equal(grad_b.numpy(), paddle.ones([1], dtype="float64").numpy())
        return grad_b

x = paddle.ones([1], dtype="float64")
x.stop_gradient = False
a, b = Tanh.apply(x)
b.sum().backward()
set_materialize_grads(value: bool)¶

Sets whether to materialize output grad tensors. Default is True.
This should be called only from inside the forward method.
If True, undefined output grad tensors will be expanded to tensors full of zeros prior to calling the backward method.
If False, undefined output grad tensors will be None.
Examples
import paddle
from paddle.autograd import PyLayer
import numpy as np

class Tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x + x + x, x + x

    @staticmethod
    def backward(ctx, grad, grad2):
        # With materialized grads (the default), the unused output grad
        # arrives as a zero tensor.
        assert np.equal(grad2.numpy(), paddle.zeros([1]).numpy())
        return grad

class Tanh2(PyLayer):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x + x + x, x + x

    @staticmethod
    def backward(ctx, grad, grad2):
        # Without materialized grads, the unused output grad is None.
        assert grad2 is None
        return grad

x = paddle.ones([1], dtype="float64")
x.stop_gradient = False
Tanh.apply(x)[0].backward()

x2 = paddle.ones([1], dtype="float64")
x2.stop_gradient = False
Tanh2.apply(x2)[0].backward()