masked_scatter

paddle.masked_scatter(x: Tensor, mask: Tensor, value: Tensor, name: str | None = None) -> Tensor

Copies elements from value into x at the positions where mask is True.

Elements of value are consumed in order, starting at position 0, one for each True entry of mask. The shape of mask must be broadcastable with the shape of x, and value must have at least as many elements as there are True entries in mask.
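The fill rule described above can be sketched in plain Python. This is only an illustration of the semantics, not Paddle's implementation: the helper name masked_scatter_2d is hypothetical, it works on nested lists rather than tensors, and it assumes mask already has the same shape as x (no broadcasting).

```python
def masked_scatter_2d(x, mask, value):
    """Pure-Python sketch of the fill rule for 2-D nested lists.

    Hypothetical helper for illustration only. x and mask are lists of rows
    with identical shape; value is a flat list consumed from index 0.
    """
    needed = sum(flag for row in mask for flag in row)
    if len(value) < needed:
        raise ValueError(f"value has {len(value)} elements, mask selects {needed}")

    source = iter(value)
    # Walk positions in row-major order; take the next element of value
    # only where mask is True, otherwise keep the element of x.
    return [
        [next(source) if flag else elem for elem, flag in zip(x_row, mask_row)]
        for x_row, mask_row in zip(x, mask)
    ]


x = [[10.0, 20.0], [30.0, 40.0]]
mask = [[True, False], [False, True]]
value = [1.0, 2.0, 3.0]

out = masked_scatter_2d(x, mask, value)
print(out)  # [[1.0, 20.0], [30.0, 2.0]]
```

The third element of value is never read because mask selects only two positions, and x itself is left unmodified.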

The image illustrates a typical case of the masked_scatter operation.

  1. Tensor value: Holds the data to write into the target tensor. Its elements are consumed in order, one per True entry of the mask; any elements left over are ignored.

  2. Tensor mask: Marks which positions of the target tensor are updated. True means the position receives the next element of value.

  3. Tensor origin: The input tensor (the x argument). Only positions where the mask is True are replaced; the rest remain unchanged.

Result: After the masked_scatter operation, the positions of the origin tensor where the mask is True hold the corresponding elements of value, and the positions where the mask is False keep their original contents.
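One way to read the walkthrough is as a per-position index mapping: a True position reads value[k], where k is the number of True entries before it. The sketch below works on flattened plain-Python lists with made-up data; the variable names are illustrative and not part of the Paddle API.

```python
from itertools import accumulate

origin = [5, 6, 7, 8, 9, 10]  # flattened target tensor (illustrative data)
mask = [False, True, True, False, True, False]
value = [100, 200, 300, 400]

# Running count of True entries; a True position with running count c
# reads value[c - 1]. False positions have no source index.
running = list(accumulate(int(m) for m in mask))
source_index = [c - 1 if m else None for c, m in zip(running, mask)]

# Take from value where a source index exists, otherwise keep origin.
result = [value[i] if i is not None else o for i, o in zip(source_index, origin)]

print(source_index)  # [None, 0, 1, None, 2, None]
print(result)        # [5, 100, 200, 8, 300, 10]
```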

[Figure: legend of masked_scatter API]

Parameters
  • x (Tensor) – An N-D Tensor. The data type is float16, float32, float64, int32, int64 or bfloat16.

  • mask (Tensor) – A boolean tensor indicating the positions to be filled. The data type of mask must be bool.

  • value (Tensor) – The values used to fill the target tensor. Supported data types are the same as for x.

  • name (str|None, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.

Returns

Tensor, a Tensor with the same data type as x, in which the positions where mask is True are filled with elements of value.

Examples

>>> import paddle
>>> paddle.seed(2048)
>>> x = paddle.randn([2, 2])
>>> print(x)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
    [[-1.24725831,  0.03843464],
    [-0.31660911,  0.04793844]])

>>> mask = paddle.to_tensor([[True, True], [False, False]])
>>> value = paddle.to_tensor([1, 2, 3, 4, 5], dtype="float32")

>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
    [[1,  2],
    [-0.31660911,  0.04793844]])