multiplex

paddle.multiplex(inputs, index, name=None)

Based on the given index parameter, the OP builds the output Tensor row by row: for each output row, index selects the input Tensor that the row is copied from.

Suppose the input of this OP contains \(m\) Tensors, where \(I_{k}\) denotes the k-th input Tensor and \(k \in [0, m)\).

Let \(O\) denote the output and \(O[i]\) its i-th row. Then the output satisfies \(O[i] = I_{index[i]}[i]\): row i of the output is row i of the input Tensor selected by index[i], so every value of index must lie in \([0, m)\).
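The definition is a per-row lookup. As an illustration only, the following pure-Python sketch uses nested lists in place of Tensors and implements \(O[i] = I_{index[i]}[i]\) directly; `multiplex_reference` is a hypothetical helper, not part of Paddle.

```python
def multiplex_reference(inputs, index):
    """Pure-Python sketch of O[i] = I_{index[i]}[i] (not Paddle code).

    inputs: list of m equally shaped nested lists (stand-ins for Tensors).
    index:  list of single-element lists, i.e. shape [M, 1].
    """
    out = []
    for i, (k,) in enumerate(index):
        # Row i of the output is row i of the input chosen by index[i][0].
        out.append(inputs[k][i])
    return out


img1 = [[1, 2], [3, 4]]
img2 = [[5, 6], [7, 8]]
print(multiplex_reference([img1, img2], [[1], [0]]))  # [[5, 6], [3, 4]]
```

This mirrors the Paddle example further below: row 0 comes from img2 and row 1 from img1, and no row ever changes position.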

For example:

Given:

inputs = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]],
          [[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]],
          [[2,0,3,4], [2,1,7,8], [2,2,4,2], [2,3,3,4]],
          [[3,0,3,4], [3,1,7,8], [3,2,4,2], [3,3,3,4]]]

index = [[3],[0],[1],[2]]

out = [[3,0,3,4],    # out[0] = inputs[index[0]][0] = inputs[3][0] = [3,0,3,4]
       [0,1,3,4],    # out[1] = inputs[index[1]][1] = inputs[0][1] = [0,1,3,4]
       [1,2,4,2],    # out[2] = inputs[index[2]][2] = inputs[1][2] = [1,2,4,2]
       [2,3,3,4]]    # out[3] = inputs[index[3]][3] = inputs[2][3] = [2,3,3,4]
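To check the worked example, the plain-Python snippet below (nested lists stand in for Tensors; this is not Paddle code) applies the row rule to the same data. In this data, column 0 of every row holds the number of its source Tensor and column 1 holds its row number, so each output row records where it came from.

```python
inputs = [[[0, 0, 3, 4], [0, 1, 3, 4], [0, 2, 4, 4], [0, 3, 3, 4]],
          [[1, 0, 3, 4], [1, 1, 7, 8], [1, 2, 4, 2], [1, 3, 3, 4]],
          [[2, 0, 3, 4], [2, 1, 7, 8], [2, 2, 4, 2], [2, 3, 3, 4]],
          [[3, 0, 3, 4], [3, 1, 7, 8], [3, 2, 4, 2], [3, 3, 3, 4]]]
index = [[3], [0], [1], [2]]

# Row i of the output is row i of inputs[index[i][0]].
out = [inputs[k][i] for i, (k,) in enumerate(index)]
print(out)  # [[3, 0, 3, 4], [0, 1, 3, 4], [1, 2, 4, 2], [2, 3, 3, 4]]

# Columns 0 and 1 encode (source Tensor, row), so print the provenance.
for i, row in enumerate(out):
    print(f"out[{i}] <- inputs[{row[0]}][{row[1]}]")
```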

Parameters
  • inputs (list) – The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64, complex64, complex128. All input Tensors must have the same shape, and their rank must be at least 2.

  • index (Tensor) – For each output row, selects the input Tensor that the row is taken from. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of output rows (index[i] selects the source of row i). Each value must lie in [0, m), where m is the number of input Tensors.

  • name (str, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.
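The constraints in the parameter list above, together with the formula \(O[i] = I_{index[i]}[i]\), can be collected into a pre-check. The sketch below is hypothetical: `check_multiplex_args` and `shape_of` are illustrative names, not Paddle APIs, and they work on nested lists standing in for Tensors. It does not reproduce Paddle's own validation, whose exact conditions and error messages may differ. Because row i of the output is row i of some input, it assumes index may have at most as many rows as each input Tensor, and it requires every index value to name one of the m inputs.

```python
def shape_of(x):
    """Shape of a regular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2).

    Only the first element at each level is inspected, so ragged lists
    are not detected.
    """
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def check_multiplex_args(inputs, index):
    """Hypothetical pre-check of the documented constraints (not Paddle code)."""
    if not isinstance(inputs, list) or not inputs:
        raise ValueError("inputs must be a non-empty list")
    in_shape = shape_of(inputs[0])
    if len(in_shape) < 2:
        raise ValueError(f"input rank must be at least 2, got {len(in_shape)}")
    if any(shape_of(t) != in_shape for t in inputs):
        raise ValueError("all input Tensors must have the same shape")
    idx_shape = shape_of(index)
    if len(idx_shape) != 2 or idx_shape[1] != 1:
        raise ValueError(f"index must have shape [M, 1], got {list(idx_shape)}")
    if idx_shape[0] > in_shape[0]:
        # O[i] = I_{index[i]}[i] needs row i to exist in every input.
        raise ValueError("index has more rows than the input Tensors")
    for (k,) in index:
        # Each value must name one of the m input Tensors.
        if not 0 <= k < len(inputs):
            raise ValueError(f"index value {k} is outside [0, {len(inputs)})")


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
check_multiplex_args([a, b], [[1], [0]])  # valid: returns None
try:
    check_multiplex_args([a, b], [[2], [0]])  # 2 does not name an input
except ValueError as e:
    print(e)  # index value 2 is outside [0, 2)
```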

Returns

Output of multiplex OP, with the same data type as the input Tensors.

Return type

Tensor

Examples

>>> import paddle

>>> img1 = paddle.to_tensor([[1, 2], [3, 4]], dtype=paddle.float32)
>>> img2 = paddle.to_tensor([[5, 6], [7, 8]], dtype=paddle.float32)
>>> inputs = [img1, img2]
>>> index = paddle.to_tensor([[1], [0]], dtype=paddle.int32)
>>> res = paddle.multiplex(inputs, index)
>>> print(res)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
[[5., 6.],
 [3., 4.]])