py_func
paddle.static.py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None)
Registers a customized Python OP with Paddle. The design principle of py_func is that a Tensor and a numpy array can be converted to each other easily, so you can use Python and numpy APIs to register a Python OP.

The forward function of the registered OP is func and its backward function is backward_func. Paddle calls func at forward runtime and calls backward_func at backward runtime (if backward_func is not None). x is the input of func, whose type must be Tensor; out is the output of func, whose type can be either Tensor or numpy array.

The inputs of backward_func are x, out, and the gradient of out. If out has no gradient, the corresponding input of backward_func is None. If x does not have a gradient, backward_func should return None.

The data type and shape of out must be set correctly before this API is called; the data type and shape of the gradients of out and x are inferred automatically.

This API can also be used to debug a neural network by setting func to a function that only prints variables.

Parameters
func (callable) – The forward function of the registered OP. When the network is running, the forward output out is calculated from this function and the forward input x. Inside func, it is recommended to explicitly convert each Tensor into a numpy array so that Python and numpy APIs can be used freely; without the conversion, some numpy operations may not be compatible.

x (Tensor|tuple(Tensor)|list[Tensor]) – The input of the forward function func. Multiple Tensors should be passed as tuple(Tensor) or list[Tensor].

out (T|tuple(T)|list[T]) – The output of the forward function func, where T can be either Tensor or numpy array. Since Paddle cannot automatically infer the shape and type of out, you must create out in advance.

backward_func (callable, optional) – The backward function of the registered OP. The default value is None, which means there is no backward calculation. If it is not None, backward_func is called to calculate the gradient of x when the network is at backward runtime.

skip_vars_in_backward_input (Tensor|tuple(Tensor)|list[Tensor], optional) – Limits the input list of backward_func. Each tensor given here must belong to either x or out. The default value is None, which means that no tensors are removed from x and out. If it is not None, these tensors will not be passed to backward_func. This parameter is only useful when backward_func is not None.
Returns

Tensor|tuple(Tensor)|list[Tensor]: the output out of the forward function func.
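The rule for which values reach backward_func can be modeled in a few lines of plain Python. This is only a sketch of the documented behavior: backward_inputs is a hypothetical helper, not part of Paddle, and the strings stand in for tensors. It assumes the order stated above (x, then out, then the gradient of out), with anything listed in skip_vars_in_backward_input dropped.

```python
# Hypothetical helper, NOT a Paddle API: it only models the documented rule
# for which values are handed to backward_func.
def backward_inputs(x, out, out_grad, skip=()):
    """Return the names passed to backward_func, in the documented order."""
    skipped = set(skip)
    # Forward inputs and outputs, minus anything the caller asked to skip.
    forward_vars = [name for name in list(x) + list(out) if name not in skipped]
    # The gradient of out always follows the forward variables.
    return forward_vars + list(out_grad)


# Without skipping, backward_func would be called as backward_func(x, out, dout).
print(backward_inputs(["x"], ["out"], ["dout"]))

# Skipping x leaves backward_func(out, dout), which matches the two-argument
# tanh_grad(y, dy) used in the first example.
print(backward_inputs(["x"], ["out"], ["dout"], skip=["x"]))
```

Skipping an input that the gradient formula does not need, such as x for tanh, keeps the backward signature short.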
Examples
>>> import paddle
>>> import numpy as np

>>> np.random.seed(1107)
>>> paddle.seed(1107)

>>> paddle.enable_static()

>>> # Creates a forward function, Tensor can be input directly without
>>> # being converted into numpy array.
>>> def tanh(x):
...     return np.tanh(x)

>>> # Skip x in backward function and return the gradient of x
>>> # Tensor must be actively converted to numpy array, otherwise,
>>> # operations such as +/- can't be used.
>>> def tanh_grad(y, dy):
...     return np.array(dy) * (1 - np.square(np.array(y)))

>>> # Creates a forward function for debugging running networks(print value)
>>> def debug_func(x):
...     # print(x)
...     pass

>>> def create_tmp_var(name, dtype, shape):
...     return paddle.static.default_main_program().current_block().create_var(
...         name=name, dtype=dtype, shape=shape)

>>> def simple_net(img, label):
...     hidden = img
...     for idx in range(4):
...         hidden = paddle.static.nn.fc(hidden, size=200)
...         new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
...             dtype=hidden.dtype, shape=hidden.shape)
...         # User-defined forward and backward
...         hidden = paddle.static.py_func(func=tanh, x=hidden,
...             out=new_hidden, backward_func=tanh_grad,
...             skip_vars_in_backward_input=hidden)
...         # User-defined debug functions that print out the input Tensor
...         paddle.static.py_func(func=debug_func, x=hidden, out=None)
...     prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
...     ce_loss = paddle.nn.loss.CrossEntropyLoss()
...     return ce_loss(prediction, label)

>>> x = paddle.static.data(name='x', shape=[1,4], dtype='float32')
>>> y = paddle.static.data(name='y', shape=[1], dtype='int64')
>>> res = simple_net(x, y)

>>> exe = paddle.static.Executor(paddle.CPUPlace())
>>> exe.run(paddle.static.default_startup_program())
>>> input1 = np.random.random(size=[1,4]).astype('float32')
>>> input2 = np.random.randint(1, 10, size=[1], dtype='int64')
>>> out = exe.run(paddle.static.default_main_program(),
...     feed={'x':input1, 'y':input2},
...     fetch_list=[res.name])
>>> print(out[0].shape)
()
>>> # This example shows how to turn Tensor into numpy array and
>>> # use numpy API to register a Python OP
>>> import paddle
>>> import numpy as np

>>> np.random.seed(1107)
>>> paddle.seed(1107)

>>> paddle.enable_static()

>>> def element_wise_add(x, y):
...     # Tensor must be actively converted to numpy array, otherwise,
...     # numpy.shape can't be used.
...     x = np.array(x)
...     y = np.array(y)
...     if x.shape != y.shape:
...         raise AssertionError("the shape of inputs must be the same!")
...     result = np.zeros(x.shape, dtype='int32')
...     for i in range(len(x)):
...         for j in range(len(x[0])):
...             result[i][j] = x[i][j] + y[i][j]
...     return result

>>> def create_tmp_var(name, dtype, shape):
...     return paddle.static.default_main_program().current_block().create_var(
...         name=name, dtype=dtype, shape=shape)

>>> def py_func_demo():
...     start_program = paddle.static.default_startup_program()
...     main_program = paddle.static.default_main_program()
...     # Input of the forward function
...     x = paddle.static.data(name='x', shape=[2, 3], dtype='int32')
...     y = paddle.static.data(name='y', shape=[2, 3], dtype='int32')
...     # Output of the forward function, name/dtype/shape must be specified.
...     # The shape matches the [2, 3] result returned by element_wise_add.
...     output = create_tmp_var('output', 'int32', [2, 3])
...     # Multiple Tensor should be passed in the form of tuple(Tensor) or list[Tensor]
...     paddle.static.py_func(func=element_wise_add, x=[x, y], out=output)
...     exe = paddle.static.Executor(paddle.CPUPlace())
...     exe.run(start_program)
...     # Feed numpy array to main_program
...     input1 = np.random.randint(1, 10, size=[2, 3], dtype='int32')
...     input2 = np.random.randint(1, 10, size=[2, 3], dtype='int32')
...     out = exe.run(main_program,
...         feed={'x':input1, 'y':input2},
...         fetch_list=[output.name])
...     print("{0} + {1} = {2}".format(input1, input2, out))

>>> py_func_demo()
>>> # [[1 5 4]  + [[3 7 7]  = [array([[ 4, 12, 11]
>>> #  [9 4 8]]    [2 3 9]]           [11, 7, 17]], dtype=int32)]
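Because func and backward_func are ordinary Python callables, a backward function can be sanity-checked with numpy alone before it is registered. The sketch below is not part of the Paddle API. It reuses the tanh and tanh_grad definitions from the first example and compares the analytic gradient with a central finite difference; numeric_grad, the sample points, and the 1e-6 tolerance are illustrative assumptions.

```python
import numpy as np


def tanh(x):
    return np.tanh(x)


def tanh_grad(y, dy):
    # Same formula as in the first example: d tanh(x)/dx = 1 - tanh(x)^2,
    # expressed through the forward output y and the upstream gradient dy.
    return np.array(dy) * (1 - np.square(np.array(y)))


def numeric_grad(f, x, eps=1e-6):
    # Hypothetical checker: central finite difference, applied elementwise.
    # This is valid here only because tanh acts on each element independently.
    return (f(x + eps) - f(x - eps)) / (2 * eps)


x = np.linspace(-2.0, 2.0, 9)   # float64 sample points
y = tanh(x)                     # forward output
dy = np.ones_like(y)            # upstream gradient of ones isolates dy/dx

analytic = tanh_grad(y, dy)
numeric = numeric_grad(tanh, x)
max_err = np.max(np.abs(analytic - numeric))

# The two gradients should agree up to finite-difference error.
print(max_err < 1e-6)
```

A mismatch here usually points to a wrong formula or a wrong argument order in the backward function, which is easier to find outside a running static graph.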