jvp¶
计算函数 func
在 xs
处的雅可比矩阵与向量 v
的乘积。
警告
该 API 目前为 Beta 版本,函数签名在未来版本可能发生变化。
参数¶
func (Callable) - Python 函数,输入参数为
xs
,输出为 Tensor 或 Tensor 序列。xs (Tensor|Sequence[Tensor]) - 函数
func
的输入参数,数据类型为 Tensor 或 Tensor 序列。v (Tensor|Sequence[Tensor]|None,可选) - 用于计算
jvp
的输入向量,形状要求 与xs
一致。默认值为None
,即相当于形状与xs
一致,值全为 1 的 Tensor 或 Tensor 序列。
返回¶
func_out (Tensor|tuple[Tensor]) - 函数
func(xs)
的输出。jvp (Tensor|tuple[Tensor]) -
jvp
计算结果。
代码示例¶
import paddle
def func(x):
return paddle.matmul(x, x)
x = paddle.ones(shape=[2, 2], dtype='float32')
_, jvp_result = paddle.incubate.autograd.jvp(func, x)
print(jvp_result)
# Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
# [[4., 4.],
# [4., 4.]])
v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
_, jvp_result = paddle.incubate.autograd.jvp(func, x, v)
print(jvp_result)
# Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
# [[2., 1.],
# [1., 0.]])