argmax

paddle.argmax(x, axis=None, keepdim=False, dtype='int64', name=None)

This OP computes the indices of the maximum elements of the input tensor along the given axis.
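The core operation can be illustrated with a minimal pure-Python sketch over a 1-D sequence. It is not Paddle's implementation. The helper name `argmax_1d` is hypothetical, and it assumes ties resolve to the first occurrence, which is consistent with the flattened example on this page (the value 9 appears at flat indices 2 and 9, and the documented result is 2).

```python
def argmax_1d(values):
    """Return the index of the largest element in a non-empty sequence.

    Hypothetical helper: ties resolve to the first occurrence (an
    assumption, consistent with this page's flattened example).
    """
    if not values:
        raise ValueError("argmax of an empty sequence is undefined")
    best = 0
    for i in range(1, len(values)):
        # Strict '>' keeps the earliest index when values are equal.
        if values[i] > values[best]:
            best = i
    return best


# The 3x4 example tensor from this page, flattened in row-major order.
flat = [5, 8, 9, 5, 0, 0, 1, 7, 6, 9, 2, 4]
print(argmax_1d(flat))  # 2 (the 9 at index 9 is a later tie)
```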

Parameters
  • x (Tensor) – An input N-D Tensor with data type float32, float64, int16, int32, int64, or uint8.

  • axis (int, optional) – The axis along which to compute the indices. The valid range is [-R, R), where R is x.ndim. When axis < 0, it is equivalent to axis + R. Default is None, in which case x is flattened and the index of the maximum value in the flattened tensor is returned.

  • keepdim (bool, optional) – Whether to keep the reduced axis, with size 1, in the output. The default value is False.

  • dtype (str|np.dtype, optional) – Data type of the output tensor, either int32 or int64. The default value is ‘int64’.

  • name (str, optional) – The default value is None. Normally there is no need for user to set this property. For more information, please refer to Name.
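How axis and keepdim interact can be sketched in pure Python for a rectangular 2-D nested list. This is an illustrative reference, not Paddle's code: the function name `argmax_2d` is hypothetical, it handles only 2-D input, it ignores dtype and name, and it does not model keepdim when axis is None. It normalizes a negative axis as axis + R and flattens row-major when axis is None.

```python
def argmax_2d(x, axis=None, keepdim=False):
    """Hypothetical reference sketch of argmax for a 2-D nested list."""
    rows, cols = len(x), len(x[0])

    def first_max(seq):
        # Python's max() returns the first maximal item, so ties keep the lowest index.
        return max(range(len(seq)), key=lambda i: seq[i])

    if axis is None:
        # Flatten in row-major order and return one index (keepdim not modeled).
        return first_max([v for row in x for v in row])

    ndim = 2
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis must be in [-{ndim}, {ndim}), got {axis}")
    if axis < 0:
        axis += ndim  # a negative axis behaves like axis + R

    if axis == 0:
        # Reduce over rows: one index per column.
        out = [first_max([x[r][c] for r in range(rows)]) for c in range(cols)]
        return [out] if keepdim else out  # shape [1, cols] vs [cols]

    # axis == 1: reduce over columns, one index per row.
    out = [first_max(row) for row in x]
    return [[i] for i in out] if keepdim else out  # shape [rows, 1] vs [rows]


x = [[5, 8, 9, 5],
     [0, 0, 1, 7],
     [6, 9, 2, 4]]
print(argmax_2d(x, axis=1))                # [2, 3, 1]
print(argmax_2d(x, axis=1, keepdim=True))  # [[2], [3], [1]]
print(argmax_2d(x, axis=0))                # [2, 2, 0, 1]
```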

Returns

Tensor of indices, with data type int32 if dtype is int32, otherwise int64.

Examples

import paddle

x = paddle.to_tensor([[5, 8, 9, 5],
                      [0, 0, 1, 7],
                      [6, 9, 2, 4]])
out1 = paddle.argmax(x)
print(out1) # 2
out2 = paddle.argmax(x, axis=1)
print(out2)
# [2 3 1]
out3 = paddle.argmax(x, axis=-1)
print(out3)
# [2 3 1]
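With axis=None the result is an index into the flattened tensor, so recovering the row and column takes one extra step. A minimal sketch, assuming row-major flattening, which agrees with the documented result of 2 for out1. The helper `unravel_2d` is hypothetical and written in plain Python rather than with a Paddle API.

```python
def unravel_2d(flat_index, shape):
    """Hypothetical helper: map a row-major flat index to (row, col)."""
    rows, cols = shape
    if not 0 <= flat_index < rows * cols:
        raise IndexError("flat index out of range")
    # In row-major order, flat = row * cols + col.
    return divmod(flat_index, cols)


x = [[5, 8, 9, 5],
     [0, 0, 1, 7],
     [6, 9, 2, 4]]
row, col = unravel_2d(2, (3, 4))  # 2 is the value printed for out1
print(row, col)     # 0 2
print(x[row][col])  # 9
```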