take_along_axis
- paddle.take_along_axis(arr, indices, axis, broadcast=True)
Take values from the input tensor arr at the positions given by the index tensor indices, along the designated axis.
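The following minimal pure-Python sketch for 2-D nested lists makes the indexing rule concrete. The helper `take_along_axis_2d` is hypothetical and is not part of Paddle. It does no broadcasting and assumes `indices` already has the shape of the desired output. Along axis 0, each output element is `arr[indices[i][j]][j]`. Along axis 1, it is `arr[i][indices[i][j]]`.

```python
def take_along_axis_2d(arr, indices, axis):
    """Hypothetical reference for 2-D nested lists (no broadcasting).

    axis=0: out[i][j] = arr[indices[i][j]][j]
    axis=1: out[i][j] = arr[i][indices[i][j]]
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2-D input")
    out = []
    for i, row in enumerate(indices):
        out_row = []
        for j, k in enumerate(row):
            # k is the position to read along `axis`; the other coordinate is kept.
            out_row.append(arr[k][j] if axis == 0 else arr[i][k])
        out.append(out_row)
    return out


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# One column index per row, taken along axis 1.
print(take_along_axis_2d(x, [[2], [0], [1]], axis=1))  # [[3], [4], [8]]
# One row index per column, taken along axis 0.
print(take_along_axis_2d(x, [[0, 2, 1]], axis=0))  # [[1, 8, 6]]
```

With axis=1, each row picks its own column. This is the usual pattern for selecting one value per sample, such as the score at a target label.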
- Parameters
arr (Tensor) – The input Tensor. Supported data types are float32 and float64.
indices (Tensor) – Indices to take along each 1-D slice of arr. It must have the same number of dimensions as arr and must be broadcastable against arr. Supported data types are int32 and int64.
axis (int) – The axis along which to take 1-D slices.
broadcast (bool, optional) – Whether to broadcast indices against arr before taking values. Default: True.
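When `broadcast=True`, indices is expanded before values are taken. The sketch below illustrates the shape rule as we understand it, and it reproduces the documented example. Indices keeps its own size along `axis` and takes the size of arr on every other dimension. Both helpers, `broadcast_indices_shape` and `broadcast_2d`, are hypothetical pure-Python illustrations and not Paddle APIs. The expansion step handles only 2-D input, and neither helper covers the case where indices is larger than arr.

```python
def broadcast_indices_shape(arr_shape, indices_shape, axis):
    """Hypothetical helper: assumed target shape of indices when broadcast=True."""
    if len(arr_shape) != len(indices_shape):
        raise ValueError("indices must have the same number of dimensions as arr")
    axis = axis % len(arr_shape)  # accept negative axis values
    target = list(arr_shape)
    target[axis] = indices_shape[axis]  # keep indices' own size along axis
    return target


def broadcast_2d(indices, shape):
    """Hypothetical helper: expand size-1 dims of a 2-D nested list to `shape`."""
    rows, cols = shape
    if len(indices) not in (1, rows) or len(indices[0]) not in (1, cols):
        raise ValueError("indices cannot be broadcast to %s" % (shape,))
    return [[indices[i if len(indices) > 1 else 0][j if len(indices[0]) > 1 else 0]
             for j in range(cols)]
            for i in range(rows)]


# Documented example: arr has shape [3, 3], indices [[0]] has shape [1, 1], axis=0.
x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
shape = broadcast_indices_shape([3, 3], [1, 1], axis=0)
idx = broadcast_2d([[0]], shape)
# Gather along axis 0: out[i][j] = x[idx[i][j]][j]
result = [[x[idx[i][j]][j] for j in range(shape[1])] for i in range(shape[0])]
print(shape, idx, result)  # [1, 3] [[0, 0, 0]] [[1, 2, 3]]
```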
- Returns
Tensor, the values taken from arr at the positions in indices, with the same dtype as arr.
Examples
>>> import paddle
>>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = paddle.to_tensor([[0]])
>>> axis = 0
>>> result = paddle.take_along_axis(x, index, axis)
>>> print(result)
Tensor(shape=[1, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[1, 2, 3]])