searchsorted

paddle.searchsorted(sorted_sequence, values, out_int32=False, right=False, name=None)

For each element of values, find the index in the innermost dimension of sorted_sequence at which that element would be inserted while keeping the dimension sorted.
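For a single 1-D row and a single value, the returned index is an insertion point: putting the value at that position keeps the row sorted. The sketch below shows this with Python's standard `bisect` module. It illustrates the default (right=False) semantics only and is not Paddle's implementation. The helper name `insertion_index` is hypothetical.

```python
from bisect import bisect_left


def insertion_index(sorted_row, value):
    """Hypothetical helper: first index i with sorted_row[i] >= value.

    Mirrors the default (right=False) result for one innermost row.
    Returns len(sorted_row) when every element is smaller than value.
    """
    return bisect_left(sorted_row, value)


row = [1, 3, 5, 7, 9, 11]
i = insertion_index(row, 6)
print(i)                        # 3
print(row[:i] + [6] + row[i:])  # [1, 3, 5, 6, 7, 9, 11], still sorted
```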

Parameters
  • sorted_sequence (Tensor) – An N-D or 1-D input tensor with data type int32, int64, float32 or float64. Its values must increase monotonically along the innermost dimension.

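  • values (Tensor) – An N-D input tensor holding the values to search for, with data type int32, int64, float32 or float64. If sorted_sequence is not 1-D, it must have the same number of dimensions as values, and all dimensions except the innermost must match.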

  • out_int32 (bool, optional) – Whether the output tensor has data type int32. If False (the default), the output data type is int64.

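  • right (bool, optional) – Whether to return the lower or the upper bound within the innermost dimension of sorted_sequence for each value. If False (the default), the lower bound is returned: the first index i such that sorted_sequence[..., i] >= value. If True, the upper bound is returned: the first index i such that sorted_sequence[..., i] > value. If a value of sorted_sequence is nan or inf, the size of the innermost dimension is returned.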

  • name (str, optional) – For details, please refer to Name. Generally, no setting is required. Default: None.
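The right flag only changes the result when a value is equal to an element of the row. The sketch below shows this with `bisect_left` (lower bound) and `bisect_right` (upper bound) from the Python standard library. It models right for one innermost row and ignores out_int32, which only affects the output dtype. The helper `search_row` is hypothetical and is not part of Paddle.

```python
from bisect import bisect_left, bisect_right


def search_row(sorted_row, value, right=False):
    """Hypothetical helper sketching the right flag for one row.

    right=False: first index i with sorted_row[i] >= value (lower bound).
    right=True:  first index i with sorted_row[i] >  value (upper bound).
    """
    return bisect_right(sorted_row, value) if right else bisect_left(sorted_row, value)


row = [2, 4, 6, 8, 10, 12]
print(search_row(row, 6), search_row(row, 6, right=True))  # 2 3  (6 is present)
print(search_row(row, 9), search_row(row, 9, right=True))  # 4 4  (9 is absent)
```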

Returns

Tensor with the same shape as values, holding the found indices. Its data type is int32 if out_int32 is True, otherwise int64.
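The shape rules can be checked with a pure-Python reference sketch on nested lists. A 1-D sorted_sequence is shared by every value. An N-D sorted_sequence is paired row by row with values along the leading dimensions, and the result has the shape of values. This is an illustration under those assumptions, not Paddle's kernel. The function `searchsorted_ref` is hypothetical, and it does not fully validate that the two inputs have the same rank.

```python
from bisect import bisect_left, bisect_right


def searchsorted_ref(sorted_sequence, values, right=False):
    """Hypothetical pure-Python reference for the shape rules.

    sorted_sequence: a non-empty 1-D list, or a nested list whose leading
        dimensions match those of values (rank agreement is not fully checked).
    values: a nested list of numbers.
    Returns a nested list of indices with the same shape as values.
    """
    search = bisect_right if right else bisect_left

    if not isinstance(sorted_sequence[0], list):
        # 1-D sorted_sequence: every element of values is searched in it.
        def search_all(v):
            if isinstance(v, list):
                return [search_all(x) for x in v]
            return search(sorted_sequence, v)
        return search_all(values)

    # N-D sorted_sequence: pair it with values along the leading dimension
    # and recurse until the innermost rows are reached.
    if not isinstance(values, list) or len(values) != len(sorted_sequence):
        raise ValueError("leading dimensions of sorted_sequence and values must match")
    return [searchsorted_ref(s, v, right) for s, v in zip(sorted_sequence, values)]


seq = [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]]
vals = [[3, 6, 9, 10], [3, 6, 9, 10]]
print(searchsorted_ref(seq, vals))                      # [[1, 3, 4, 5], [1, 2, 4, 4]]
print(searchsorted_ref(seq, vals, right=True))          # [[2, 3, 5, 5], [1, 3, 4, 5]]
print(searchsorted_ref([1, 3, 5, 7, 9, 11, 13], vals))  # [[1, 3, 4, 5], [1, 3, 4, 5]]
```

These results match the Paddle outputs in the example on this page.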

Examples

import paddle

sorted_sequence = paddle.to_tensor([[1, 3, 5, 7, 9, 11],
                                    [2, 4, 6, 8, 10, 12]], dtype='int32')
values = paddle.to_tensor([[3, 6, 9, 10], [3, 6, 9, 10]], dtype='int32')
# Default (right=False): lower-bound index of each value, row by row
out1 = paddle.searchsorted(sorted_sequence, values)
print(out1)
# Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
#        [[1, 3, 4, 5],
#         [1, 2, 4, 4]])
# right=True: upper-bound index of each value, row by row
out2 = paddle.searchsorted(sorted_sequence, values, right=True)
print(out2)
# Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
#        [[2, 3, 5, 5],
#         [1, 3, 4, 5]])
# A 1-D sorted_sequence is searched by every row of values
sorted_sequence_1d = paddle.to_tensor([1, 3, 5, 7, 9, 11, 13])
out3 = paddle.searchsorted(sorted_sequence_1d, values)
print(out3)
# Tensor(shape=[2, 4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
#        [[1, 3, 4, 5],
#         [1, 3, 4, 5]])