masked_matmul¶
- paddle.sparse.masked_matmul(x, y, mask, name=None) [source]
-
Note
This API is only supported from CUDA 11.3.
Applies matrix multiplication of two DenseTensors, keeping only the outputs at the coordinates specified by a sparse mask.
The supported input/output Tensor layouts are as follows:
Note
x[DenseTensor] @ y[DenseTensor] * mask[SparseCooTensor] -> out[SparseCooTensor]
x[DenseTensor] @ y[DenseTensor] * mask[SparseCsrTensor] -> out[SparseCsrTensor]
It supports backward propagation.
x and y must be at least 2-D. Automatic broadcasting of Tensors is not supported. The shape of x should be [*, M, K], the shape of y should be [*, K, N], and the shape of mask should be [*, M, N], where * is zero or more batch dimensions.
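The masking semantics can be illustrated with a plain NumPy sketch (this only models the math, not the Paddle API; a 0/1 dense array stands in for the sparsity pattern of the sparse mask): the output equals the dense product x @ y, kept only at the mask's stored coordinates.

```python
import numpy as np

# Dense inputs: x is [M, K], y is [K, N]
x = np.arange(6, dtype=np.float32).reshape(2, 3)   # M=2, K=3
y = np.ones((3, 4), dtype=np.float32)              # K=3, N=4

# 0/1 array standing in for the sparsity pattern of a [M, N] sparse mask
mask = np.array([[1, 0, 0, 1],
                 [0, 1, 0, 0]], dtype=np.float32)

# Keep the dense product only at the mask's nonzero coordinates
out = (x @ y) * mask
print(out)
```

In the real API the result is returned as a sparse tensor holding only those M*N-pattern entries, which avoids materializing the full dense product on the GPU.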
- Parameters
-
x (DenseTensor) – The input DenseTensor. The data type can be float32 or float64.
y (DenseTensor) – The input DenseTensor. The data type can be float32 or float64.
mask (SparseTensor) – The mask tensor, which can be a SparseCooTensor or SparseCsrTensor. It specifies the sparse coordinates of the output. The data type can be float32 or float64.
name (str, optional) – Name for the operation (optional, default is None). For more information, please refer to Name.
- Returns
-
A SparseCooTensor or SparseCsrTensor with the same sparse format as mask.
- Return type
-
SparseTensor
Examples
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> paddle.seed(100)

>>> # dense @ dense * csr_mask -> csr
>>> crows = [0, 2, 3, 5]
>>> cols = [1, 3, 2, 0, 1]
>>> values = [1., 2., 3., 4., 5.]
>>> dense_shape = [3, 4]
>>> mask = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
>>> print(mask)
Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
       crows=[0, 2, 3, 5],
       cols=[1, 3, 2, 0, 1],
       values=[1., 2., 3., 4., 5.])

>>> x = paddle.rand([3, 5])
>>> y = paddle.rand([5, 4])
>>> out = paddle.sparse.masked_matmul(x, y, mask)
>>> print(out)
Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
       crows=[0, 2, 3, 5],
       cols=[1, 3, 2, 0, 1],
       values=[0.98986477, 0.97800624, 1.14591956, 0.68561077, 0.94714981])