masked_scatter
paddle.masked_scatter(x: Tensor, mask: Tensor, value: Tensor, name: str | None = None) -> Tensor
Copies elements from value into x at the positions where mask is True.
Elements of value are copied into x in order, starting at value[0] and advancing by one for each True position in mask (in row-major order). The shape of mask must be broadcastable with the shape of x. value must contain at least as many elements as there are True entries in mask.
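The fill order described above can be sketched in plain Python. The helper `masked_scatter_1d` is hypothetical (it is not part of Paddle): it handles only 1-D lists of equal length, skips broadcasting, and raises `ValueError` when value is too short, which is this sketch's own choice rather than a statement about Paddle's error type.

```python
from typing import List, Sequence


def masked_scatter_1d(x: Sequence[float], mask: Sequence[bool],
                      value: Sequence[float]) -> List[float]:
    """Hypothetical 1-D reference for the masked_scatter fill order.

    Each True position in mask takes the next unused element of value,
    starting from value[0]; False positions keep the element of x.
    """
    if len(x) != len(mask):
        raise ValueError("this sketch requires x and mask to have equal length")
    needed = sum(1 for m in mask if m)
    if len(value) < needed:
        raise ValueError(f"value has {len(value)} elements, mask needs {needed}")
    out = []
    next_idx = 0  # index of the next element of value to consume
    for xi, mi in zip(x, mask):
        if mi:
            out.append(value[next_idx])
            next_idx += 1
        else:
            out.append(xi)
    return out


print(masked_scatter_1d([0.0, 0.0, 0.0, 0.0], [False, True, False, True], [10.0, 20.0, 30.0]))
# [0.0, 10.0, 0.0, 20.0]
```

The two True positions consume 10.0 and 20.0 in order; the surplus element 30.0 is never read.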
The image illustrates a typical case of the masked_scatter operation.
Tensor value: contains the data to be filled into the target tensor. Its elements are consumed in order, one per True position in mask; any remaining elements are ignored.
Tensor mask: specifies which positions of the target tensor are updated with values from value. True means the corresponding position is updated.
Tensor origin: the input tensor. Only the positions selected by mask are replaced; the rest remain unchanged.
Result: after the masked_scatter operation, the positions of origin where mask is True hold the corresponding values from value, while the positions where mask is False keep their original values, forming the final updated tensor.
Parameters
x (Tensor) – An N-D Tensor. The data type is float16, float32, float64, int32, int64 or bfloat16.
mask (Tensor) – A boolean Tensor indicating the positions to be filled. The data type of mask must be bool, and its shape must be broadcastable with the shape of x.
value (Tensor) – The Tensor whose elements fill the positions selected by mask. Supported data types are the same as for x.
name (str|None, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.
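The constraints on mask (boolean, broadcastable with x) can be illustrated with a small 2-D sketch. Both helpers, `broadcast_mask_2d` and `masked_scatter_2d`, are hypothetical: they implement a simplified broadcasting rule for the 2-D case only (Paddle itself broadcasts N-D shapes), and the `TypeError`/`ValueError` checks are choices of this sketch, not Paddle's actual error behavior.

```python
from typing import List


def broadcast_mask_2d(mask: List[List[bool]], rows: int, cols: int) -> List[List[bool]]:
    """Expand a mask of shape [1 or rows, 1 or cols] to [rows, cols].

    Simplified 2-D broadcasting for illustration only.
    """
    m_rows, m_cols = len(mask), len(mask[0])
    if m_rows not in (1, rows) or m_cols not in (1, cols):
        raise ValueError("mask shape is not broadcastable to x")
    return [[mask[r if m_rows > 1 else 0][c if m_cols > 1 else 0]
             for c in range(cols)]
            for r in range(rows)]


def masked_scatter_2d(x: List[List[float]], mask: List[List[bool]],
                      value: List[float]) -> List[List[float]]:
    """Hypothetical reference: fill True positions of x in row-major order."""
    rows, cols = len(x), len(x[0])
    full = broadcast_mask_2d(mask, rows, cols)
    # Mirror the documented requirement that mask holds booleans.
    if not all(isinstance(m, bool) for row in full for m in row):
        raise TypeError("mask must be boolean")
    needed = sum(m for row in full for m in row)
    if len(value) < needed:
        raise ValueError(f"value has {len(value)} elements, mask needs {needed}")
    out, k = [], 0  # k is the index of the next element of value to consume
    for r in range(rows):
        new_row = []
        for c in range(cols):
            if full[r][c]:
                new_row.append(value[k])
                k += 1
            else:
                new_row.append(x[r][c])
        out.append(new_row)
    return out


x = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
row_mask = [[True, False, True]]  # shape [1, 3], broadcast across both rows of x
print(masked_scatter_2d(x, row_mask, [1.0, 2.0, 3.0, 4.0]))
# [[1.0, 0.0, 2.0], [3.0, 0.0, 4.0]]
```

Because the single mask row is repeated for every row of x, four True positions result, and value is consumed row by row.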
Returns
Tensor, the updated Tensor, with the same data type as x.
Examples
>>> import paddle
>>> paddle.seed(2048)
>>> x = paddle.randn([2, 2])
>>> print(x)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[-1.24725831, 0.03843464],
        [-0.31660911, 0.04793844]])
>>> mask = paddle.to_tensor([[True, True], [False, False]])
>>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32")
>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out)
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[1, 2],
        [-0.31660911, 0.04793844]])
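One loop-free way to reason about the example above: an inclusive running count of True entries over the row-major flattened mask gives, for each True position i, the index counts[i] - 1 into value. The sketch below applies this prefix-sum formulation in plain Python; the helper `masked_scatter_prefix` is hypothetical, x is the printed tensor from the example flattened by hand, and nothing here claims that Paddle implements masked_scatter this way.

```python
from itertools import accumulate
from typing import List


def masked_scatter_prefix(x_flat: List[float], mask_flat: List[bool],
                          value: List[float]) -> List[float]:
    """Hypothetical masked_scatter via a running count over the flattened mask."""
    # counts[i] is the number of True entries in mask_flat[0..i] (inclusive).
    counts = list(accumulate(int(m) for m in mask_flat))
    if counts and counts[-1] > len(value):
        raise ValueError("value has fewer elements than True entries in mask")
    # A True at position i reads value[counts[i] - 1]; a False keeps x_flat[i].
    return [value[c - 1] if m else xv for xv, m, c in zip(x_flat, mask_flat, counts)]


# Data from the example, flattened in row-major order.
x = [-1.24725831, 0.03843464, -0.31660911, 0.04793844]
mask = [True, True, False, False]
value = [1.0, 2.0, 3.0, 4.0, 5.0]
out = masked_scatter_prefix(x, mask, value)
print([out[0:2], out[2:4]])  # reshape back to [2, 2]
# [[1.0, 2.0], [-0.31660911, 0.04793844]]
```

Only value[0] and value[1] are read; the remaining three elements are ignored, which agrees with the printed output of the Paddle example.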