gather_nd

paddle. gather_nd ( x, index, name=None ) [源代码]

gather 的高维推广,并且支持多轴同时索引。index 是一个 K 维度的张量,它可以认为是从 x 中取 K-1 维张量,每一个元素是一个切片:

\[output[(i_0, ..., i_{K-2})] = x[index[(i_0, ..., i_{K-2})]]\]

显然,index.shape[-1] <= x.rank 并且输出张量的维度是 index.shape[:-1] + x.shape[index.shape[-1]:]

示例:

给定:
    x = [[[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]],
             [[12, 13, 14, 15],
              [16, 17, 18, 19],
              [20, 21, 22, 23]]]
    x.shape = (2, 3, 4)

- 案例 1:
    index = [[1]]

    gather_nd(x, index)
             = [x[1, :, :]]
             = [[12, 13, 14, 15],
                [16, 17, 18, 19],
                [20, 21, 22, 23]]

- 案例 2:

    index = [[0,2]]
    gather_nd(x, index)
             = [x[0, 2, :]]
             = [8, 9, 10, 11]

- 案例 3:

    index = [[1, 2, 3]]
    gather_nd(x, index)
             = [x[1, 2, 3]]
             = [23]

参数

  • x (Tensor) - 输入 Tensor,数据类型可以是 int32、int64、float32、float64、bool。

  • index (Tensor) - 输入的索引 Tensor,其数据类型 int32 或者 int64。它的维度 index.rank 必须大于 1,并且 index.shape[-1] <= x.rank

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

返回

shape 为 index.shape[:-1] + x.shape[index.shape[-1]:]的 Tensor,数据类型与 x 一致。

代码示例

import paddle

x = paddle.to_tensor([[[1, 2], [3, 4], [5, 6]],
                      [[7, 8], [9, 10], [11, 12]]])
index = paddle.to_tensor([[0, 1]])

output = paddle.gather_nd(x, index) #[[3, 4]]