py_func¶
- paddle.static. py_func ( func, x, out, backward_func=None, skip_vars_in_backward_input=None ) [源代码] ¶
PaddlePaddle 通过 py_func 在 Python 端注册 OP。py_func 的设计原理在于 Paddle 中的 Tensor 与 numpy 数组可以方便的互相转换,从而可使用 Python 中的 numpy API 来自定义一个 Python OP。
该自定义的 Python OP 的前向函数是 func
,反向函数是 backward_func
。 Paddle 将在前向部分调用 func
,并在反向部分调用 backward_func
(如果 backward_func
不是 None)。 x
为 func
的输入,必须为 Tensor 类型;out
为 func
的输出,既可以是 Tensor 类型,也可以是 numpy 数组。
反向函数 backward_func
的输入依次为:前向输入 x
、前向输出 out
、 out
的梯度。如果 out
的某些输出没有梯度,则 backward_func
的相关输入为 None。如果 x
的某些变量没有梯度,则用户应在 backward_func
中主动返回 None。
在调用该接口之前,还应正确设置 out
的数据类型和形状,而 out
和 x
对应梯度的数据类型和形状将自动推断而出。
此功能还可用于调试正在运行的网络,可以通过添加没有输出的 py_func
运算,并在 func
中打印输入 x
。
参数¶
func (callable) - 所注册的 Python OP 的前向函数,运行网络时,将根据该函数与前向输入
x
,计算前向输出out
。在func
建议先主动将 Tensor 转换为 numpy 数组,方便灵活的使用 numpy 相关的操作,如果未转换成 numpy,则可能某些操作无法兼容。x (Tensor|tuple(Tensor)|list[Tensor]) - 前向函数
func
的输入,多个 Tensor 以 tuple(Tensor)或 list[Tensor]的形式传入。out (T|tuple(T)|list[T]) - 前向函数
func
的输出,可以为 T|tuple(T)|list[T],其中 T 既可以为 Tensor,也可以为 numpy 数组。由于 Paddle 无法自动推断out
的形状和数据类型,必须应事先创建out
。backward_func (callable,可选) - 所注册的 Python OP 的反向函数。默认值为 None,意味着没有反向计算。若不为 None,则会在运行网络反向时调用
backward_func
计算x
的梯度。skip_vars_in_backward_input (Tensor,可选) -
backward_func
的输入中不需要的变量,可以是 Tensor|tuple(Tensor)|list[Tensor]。这些变量必须是x
和out
中的一个。默认值为 None,意味着没有变量需要从x
和out
中去除。若不为 None,则这些变量将不是backward_func
的输入。该参数仅在backward_func
不为 None 时有用。
返回¶
Tensor|tuple(Tensor)|list[Tensor],前向函数的输出 out
代码示例 1¶
# example 1:
import paddle
import numpy as np
paddle.enable_static()
# Creates a forward function, Tensor can be input directly without
# being converted into numpy array.
def tanh(x):
return np.tanh(x)
# Skip x in backward function and return the gradient of x
# Tensor must be actively converted to numpy array, otherwise,
# operations such as +/- can't be used.
def tanh_grad(y, dy):
return np.array(dy) * (1 - np.square(np.array(y)))
# Creates a forward function for debugging running networks(print value)
def debug_func(x):
print(x)
def create_tmp_var(name, dtype, shape):
return paddle.static.default_main_program().current_block().create_var(
name=name, dtype=dtype, shape=shape)
def simple_net(img, label):
hidden = img
for idx in range(4):
hidden = paddle.static.nn.fc(hidden, size=200)
new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
dtype=hidden.dtype, shape=hidden.shape)
# User-defined forward and backward
hidden = paddle.static.py_func(func=tanh, x=hidden,
out=new_hidden, backward_func=tanh_grad,
skip_vars_in_backward_input=hidden)
# User-defined debug functions that print out the input Tensor
paddle.static.py_func(func=debug_func, x=hidden, out=None)
prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax')
ce_loss = paddle.nn.loss.CrossEntropyLoss()
return ce_loss(prediction, label)
x = paddle.static.data(name='x', shape=[1,4], dtype='float32')
y = paddle.static.data(name='y', shape=[1], dtype='int64')
res = simple_net(x, y)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
input1 = np.random.random(size=[1,4]).astype('float32')
input2 = np.random.randint(1, 10, size=[1], dtype='int64')
out = exe.run(paddle.static.default_main_program(),
feed={'x':input1, 'y':input2},
fetch_list=[res.name])
print(out)
代码示例 2¶
# example 2:
# This example shows how to turn Tensor into numpy array and
# use numpy API to register an Python OP
import paddle
import numpy as np
paddle.enable_static()
def element_wise_add(x, y):
# Tensor must be actively converted to numpy array, otherwise,
# numpy.shape can't be used.
x = np.array(x)
y = np.array(y)
if x.shape != y.shape:
raise AssertionError("the shape of inputs must be the same!")
result = np.zeros(x.shape, dtype='int32')
for i in range(len(x)):
for j in range(len(x[0])):
result[i][j] = x[i][j] + y[i][j]
return result
def create_tmp_var(name, dtype, shape):
return paddle.static.default_main_program().current_block().create_var(
name=name, dtype=dtype, shape=shape)
def py_func_demo():
start_program = paddle.static.default_startup_program()
main_program = paddle.static.default_main_program()
# Input of the forward function
x = paddle.static.data(name='x', shape=[2,3], dtype='int32')
y = paddle.static.data(name='y', shape=[2,3], dtype='int32')
# Output of the forward function, name/dtype/shape must be specified
output = create_tmp_var('output','int32', [3,1])
# Multiple Tensor should be passed in the form of tuple(Tensor) or list[Tensor]
paddle.static.py_func(func=element_wise_add, x=[x,y], out=output)
exe=paddle.static.Executor(paddle.CPUPlace())
exe.run(start_program)
# Feed numpy array to main_program
input1 = np.random.randint(1, 10, size=[2,3], dtype='int32')
input2 = np.random.randint(1, 10, size=[2,3], dtype='int32')
out = exe.run(main_program,
feed={'x':input1, 'y':input2},
fetch_list=[output.name])
print("{0} + {1} = {2}".format(input1, input2, out))
py_func_demo()
# Reference output:
# [[5, 9, 9] + [[7, 8, 4] = [array([[12, 17, 13]
# [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)]