tensor_split
paddle.tensor_split(x, num_or_indices, axis=0, name=None)
Split the input tensor into multiple sub-Tensors along axis. The sub-Tensors do not have to be of equal size.

Parameters
x (Tensor) – A Tensor whose number of dimensions must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple) – If num_or_indices is an int n, x is split into n sections along axis. If x.shape[axis] is divisible by n, each section has size x.shape[axis] / n. If x.shape[axis] is not divisible by n, the first int(x.shape[axis] % n) sections have size int(x.shape[axis] / n) + 1, and the remaining sections have size int(x.shape[axis] / n). If num_or_indices is a list or tuple of integer indices, x is split along axis at each of the indices. For instance, num_or_indices=[2, 4] with axis=0 splits x into x[:2], x[2:4] and x[4:] along axis 0.
axis (int|Tensor, optional) – The axis along which to split. It can be an integer or a 0-D Tensor with shape [] and data type int32 or int64. If axis < 0, the axis to split along is rank(x) + axis. Default is 0.
name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
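The section-size rule for an integer num_or_indices can be modeled in a few lines of plain Python. This is a minimal sketch for illustration only, not Paddle's implementation; the helper name section_sizes is hypothetical.

```python
def section_sizes(length, n):
    """Sizes of the n sections produced when a dimension of size `length`
    is split with an integer num_or_indices=n (hypothetical helper)."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # base = int(length / n), extra = length % n
    base, extra = divmod(length, n)
    # The first `extra` sections get one extra element; the rest get `base`.
    return [base + 1] * extra + [base] * (n - extra)


print(section_sizes(8, 2))  # [4, 4]
print(section_sizes(8, 3))  # [3, 3, 2]
```

The sizes always add up to the length of the split dimension, which is why the uneven case front-loads the remainder onto the first sections.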
Returns
list[Tensor]: the list of split sub-Tensors.
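When num_or_indices is an index list, each returned sub-Tensor corresponds to one slice between consecutive indices. The plain-Python sketch below shows that correspondence on a 1-D list; split_at_indices is a hypothetical helper, and it assumes non-negative, increasing indices (it does not model how Paddle treats negative or out-of-range indices).

```python
def split_at_indices(seq, indices):
    """Split a sequence at `indices`, mirroring how num_or_indices=[i0, i1, ...]
    yields seq[:i0], seq[i0:i1], ..., seq[ik:] (hypothetical helper)."""
    # Boundaries: start of sequence, each split index, end of sequence.
    bounds = [0, *indices, len(seq)]
    # One slice per pair of consecutive boundaries.
    return [seq[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


pieces = split_at_indices(list(range(8)), [2, 4])
print(pieces)  # [[0, 1], [2, 3], [4, 5, 6, 7]]
```

A list of k indices therefore always produces k + 1 pieces.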
Examples
>>> import paddle
>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
>>> # not evenly split
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]
>>> # split with indices
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]
>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]
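The output shapes in the examples can be predicted without running Paddle. The sketch below combines the integer rule, the index rule and the negative-axis rule (rank(x) + axis) in plain Python. It is a hypothetical helper for checking expectations, assuming non-negative, increasing indices, and is not part of the Paddle API.

```python
def split_shapes(shape, num_or_indices, axis=0):
    """Predict output shapes of tensor_split for an input of `shape`
    (hypothetical helper; non-negative, increasing indices assumed)."""
    rank = len(shape)
    if axis < 0:
        axis += rank  # a negative axis counts from the end: rank(x) + axis
    length = shape[axis]
    if isinstance(num_or_indices, int):
        # Integer rule: the first length % n sections get one extra element.
        base, extra = divmod(length, num_or_indices)
        sizes = [base + 1] * extra + [base] * (num_or_indices - extra)
    else:
        # Index rule: sizes are the gaps between consecutive boundaries.
        bounds = [0, *num_or_indices, length]
        sizes = [stop - start for start, stop in zip(bounds[:-1], bounds[1:])]
    # Every other dimension is left unchanged.
    return [shape[:axis] + [s] + shape[axis + 1:] for s in sizes]


print(split_shapes([7, 8], 2, axis=1))        # [[7, 4], [7, 4]]
print(split_shapes([7, 8], [2, 3], axis=-1))  # [[7, 2], [7, 1], [7, 5]]
```

Using axis=-1 on a rank-2 input selects the same dimension as axis=1, so the second call reproduces the last example above.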