py_func

paddle.static.py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None)

This API registers a customized Python OP in Paddle. The design principle of py_func is that Tensors and numpy arrays can easily be converted into each other, so you can use Python and the numpy API to implement an OP. The forward function of the registered OP is func and its backward function is backward_func. Paddle calls func at forward runtime and calls backward_func at backward runtime (if backward_func is not None).

x is the input of func and must be a Tensor (or a tuple/list of Tensors); out is the output of func and can be either a Tensor or a numpy array. The inputs of backward_func are x, out and the gradient of out, in that order. If out has no gradient, the corresponding input of backward_func is None. If x has no gradient, backward_func should return None for it.

The data type and shape of out must be set correctly before this API is called; the data types and shapes of the gradients of out and x are then inferred automatically. This API can also be used to debug a neural network by setting func to a function that only prints variables.
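The calling convention described above can be illustrated without Paddle. The sketch below uses a hypothetical stand-in for the runtime (`simulate_py_op` is not part of Paddle and only mimics the documented argument order): it calls func on the inputs, then calls backward_func with x, out and the gradient of out, and shows a backward function returning None for an input that has no gradient. It assumes numpy is installed.

```python
import numpy as np


def simulate_py_op(func, backward_func, xs, douts):
    """Hypothetical mock of the py_func calling convention (not Paddle code)."""
    # Forward: func receives the inputs x and returns the output(s) out.
    out = func(*xs)
    outs = list(out) if isinstance(out, (list, tuple)) else [out]
    # Backward: backward_func receives x, then out, then the gradient of out.
    grads = backward_func(*xs, *outs, *douts)
    return outs, grads


def masked_scale(x, mask):
    # Forward: element-wise product of an input and a constant mask.
    return np.asarray(x) * np.asarray(mask)


def masked_scale_grad(x, mask, out, dout):
    # d(out)/d(x) = mask; the mask is treated as a constant, so its gradient is None.
    return np.asarray(dout) * np.asarray(mask), None


x = np.array([[1.0, -2.0], [3.0, 4.0]])
mask = np.array([[1.0, 0.0], [0.0, 1.0]])
dout = np.ones_like(x)
outs, (dx, dmask) = simulate_py_op(masked_scale, masked_scale_grad, [x, mask], [dout])
print(outs[0])
print(dx, dmask)
```

In real use the inputs would be Tensors created in a static Program rather than numpy arrays, and Paddle, not this mock, would decide when each function runs.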

Parameters
  • func (callable) – The forward function of the registered OP. When the network runs, this function computes the forward output out from the forward input x. Inside func, it is recommended to explicitly convert each Tensor into a numpy array so that the Python and numpy APIs can be used freely; otherwise, some numpy operations may not work on the Tensor.

  • x (Tensor|tuple(Tensor)|list[Tensor]) – The input of the forward function func. Multiple Tensors must be passed as a tuple(Tensor) or list[Tensor].

  • out (T|tuple(T)|list[T]) – The output of the forward function func, where T can be either a Tensor or a numpy array. Paddle cannot infer the shape and data type of out automatically, so out must be created in advance. It can be None when func returns nothing, as with a debugging function that only prints its input.

  • backward_func (callable, optional) – The backward function of the registered OP. The default value is None, which means there is no backward computation. If it is not None, backward_func is called at backward runtime to compute the gradient of x.

  • skip_vars_in_backward_input (Tensor|tuple(Tensor)|list[Tensor], optional) – Tensors to exclude from the inputs of backward_func. They must belong to either x or out. The default value is None, which means no tensors are removed from x and out. If it is not None, the listed tensors are not passed to backward_func. This parameter only takes effect when backward_func is not None.
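To make the effect of skip_vars_in_backward_input concrete, the following sketch computes the list of inputs that backward_func would receive, based only on the rule stated above: x, then out, then the gradients of out, with skipped tensors removed. The helper `backward_arg_names` and the `grad_` prefix are hypothetical and are not part of Paddle.

```python
def backward_arg_names(x_names, out_names, skip=()):
    """Hypothetical helper: names of the inputs passed to backward_func."""
    skip = set(skip)
    # Only tensors from x or out may be skipped.
    unknown = skip - set(x_names) - set(out_names)
    if unknown:
        raise ValueError(f"skipped vars must belong to x or out: {sorted(unknown)}")
    kept = [name for name in list(x_names) + list(out_names) if name not in skip]
    # The gradients of out are always passed after the remaining x and out.
    return kept + ["grad_" + name for name in out_names]


# Example 1 skips the forward input, so tanh_grad receives (out, grad of out).
print(backward_arg_names(["hidden"], ["new_hidden"], skip=["hidden"]))
print(backward_arg_names(["hidden"], ["new_hidden"]))
```

This is why tanh_grad in example 1 takes two parameters (y, dy) rather than three.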

Returns

Tensor|tuple(Tensor)|list[Tensor], The output out of the forward function func.

Examples

# example 1:
import paddle
import numpy as np
paddle.enable_static()
# Create a forward function. The Tensor can be passed to numpy
# functions directly, without being converted into a numpy array first.
def tanh(x):
    return np.tanh(x)
# Backward function: x is skipped, so it receives out (y) and the
# gradient of out (dy) and returns the gradient of x. The Tensors must be
# explicitly converted to numpy arrays; otherwise, operations such as +/-
# can't be used.
def tanh_grad(y, dy):
    return np.array(dy) * (1 - np.square(np.array(y)))
# Create a forward function for debugging a running network (prints the value)
def debug_func(x):
    print(x)
def create_tmp_var(name, dtype, shape):
    return paddle.static.default_main_program().current_block().create_var(
        name=name, dtype=dtype, shape=shape)
def simple_net(img, label):
    hidden = img
    for idx in range(4):
        hidden = paddle.static.nn.fc(hidden, size=200)
        new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
            dtype=hidden.dtype, shape=hidden.shape)
        # User-defined forward and backward
        hidden = paddle.static.py_func(func=tanh, x=hidden,
            out=new_hidden, backward_func=tanh_grad,
            skip_vars_in_backward_input=hidden)
        # User-defined debug functions that print out the input Tensor
        paddle.static.py_func(func=debug_func, x=hidden, out=None)
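    # CrossEntropyLoss applies softmax internally by default (use_softmax=True),
    # so the last fc layer outputs logits instead of softmax probabilities.
    prediction = paddle.static.nn.fc(hidden, size=10)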
    ce_loss = paddle.nn.loss.CrossEntropyLoss()
    return ce_loss(prediction, label)
x = paddle.static.data(name='x', shape=[1,4], dtype='float32')
y = paddle.static.data(name='y', shape=[1], dtype='int64')
res = simple_net(x, y)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
input1 = np.random.random(size=[1,4]).astype('float32')
input2 = np.random.randint(1, 10, size=[1], dtype='int64')
out = exe.run(paddle.static.default_main_program(),
              feed={'x':input1, 'y':input2},
              fetch_list=[res.name])
print(out)
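Before wiring a backward function into py_func, it can help to check it numerically. The sketch below compares tanh_grad from example 1 with a central finite-difference estimate of the gradient of sum(tanh(x) * dy), using plain numpy arrays in place of Tensors. The helper `numeric_grad`, the step size and the tolerance are illustrative choices, not Paddle utilities.

```python
import numpy as np


def tanh(x):
    return np.tanh(x)


def tanh_grad(y, dy):
    # Same formula as example 1: dL/dx = dy * (1 - tanh(x)^2), with y = tanh(x).
    return np.array(dy) * (1 - np.square(np.array(y)))


def numeric_grad(f, x, dy, eps=1e-6):
    """Central finite differences of sum(f(x) * dy) with respect to x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_plus = x.copy()
        x_minus = x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (np.sum(f(x_plus) * dy) - np.sum(f(x_minus) * dy)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
dy = rng.standard_normal((2, 3))
analytic = tanh_grad(tanh(x), dy)
numeric = numeric_grad(tanh, x, dy)
print("max abs difference:", np.max(np.abs(analytic - numeric)))
```

Passing y = tanh(x) rather than x mirrors the skip_vars_in_backward_input setup in example 1, where the backward function only sees out and its gradient.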
# example 2:
# This example shows how to convert Tensors into numpy arrays and
# use the numpy API to register a Python OP
import paddle
import numpy as np
paddle.enable_static()
def element_wise_add(x, y):
    # Tensors must be explicitly converted to numpy arrays; otherwise,
    # numpy attributes such as .shape can't be used.
    x = np.array(x)
    y = np.array(y)
    if x.shape != y.shape:
        raise AssertionError("the shape of inputs must be the same!")
    result = np.zeros(x.shape, dtype='int32')
    for i in range(len(x)):
        for j in range(len(x[0])):
            result[i][j] = x[i][j] + y[i][j]
    return result
def create_tmp_var(name, dtype, shape):
    return paddle.static.default_main_program().current_block().create_var(
                name=name, dtype=dtype, shape=shape)
def py_func_demo():
    start_program = paddle.static.default_startup_program()
    main_program = paddle.static.default_main_program()
    # Input of the forward function
    x = paddle.static.data(name='x', shape=[2,3], dtype='int32')
    y = paddle.static.data(name='y', shape=[2,3], dtype='int32')
    # Output of the forward function; name/dtype/shape must be specified
    # and must match the int32 array of shape [2, 3] returned by element_wise_add
    output = create_tmp_var('output', 'int32', [2, 3])
    # Multiple Tensors must be passed as a tuple(Tensor) or list[Tensor]
    paddle.static.py_func(func=element_wise_add, x=[x,y], out=output)
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(start_program)
    # Feed numpy array to main_program
    input1 = np.random.randint(1, 10, size=[2,3], dtype='int32')
    input2 = np.random.randint(1, 10, size=[2,3], dtype='int32')
    out = exe.run(main_program,
                feed={'x':input1, 'y':input2},
                fetch_list=[output.name])
    print("{0} + {1} = {2}".format(input1, input2, out))
py_func_demo()
# Reference output (values differ between runs because the inputs are random):
# [[5, 9, 9]   + [[7, 8, 4]  =  [array([[12, 17, 13]
#  [7, 5, 2]]     [1, 3, 3]]            [8, 8, 5]], dtype=int32)]
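The nested loops in element_wise_add only handle 2-D inputs and mainly show that the arguments behave like ordinary numpy arrays after conversion. As a sketch, assuming the same int32 inputs, a vectorized version based on numpy's `+` operator computes the same result for inputs of any rank; the name `element_wise_add_vectorized` is hypothetical. The check below reuses the values from the reference output.

```python
import numpy as np


def element_wise_add_vectorized(x, y):
    # Convert the inputs (Tensors under py_func, or plain arrays here) to numpy.
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise ValueError("the shape of inputs must be the same!")
    # Element-wise addition, cast to match the int32 output variable.
    return (x + y).astype('int32')


input1 = np.array([[5, 9, 9], [7, 5, 2]], dtype='int32')
input2 = np.array([[7, 8, 4], [1, 3, 3]], dtype='int32')
print(element_wise_add_vectorized(input1, input2))
```

Passing func=element_wise_add_vectorized to paddle.static.py_func in place of element_wise_add should leave the rest of example 2 unchanged, provided the output variable is still declared as int32 with shape [2, 3].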