tensor_split

paddle.tensor_split(x, num_or_indices, axis=0, name=None)

Split the input tensor into multiple sub-Tensors along axis. The sub-Tensors do not need to be of equal size.

Parameters
  • x (Tensor) – A Tensor whose number of dimensions must be greater than 0. Supported data types are bool, bfloat16, float16, float32, float64, uint8, int32 and int64.

  • num_or_indices (int|list|tuple) – If num_or_indices is an int n, x is split into n sections along axis. If x.shape[axis] is divisible by n, each section has size x.shape[axis] / n. Otherwise, the first int(x.shape[axis] % n) sections have size int(x.shape[axis] / n) + 1 and the remaining sections have size int(x.shape[axis] / n). If num_or_indices is a list or tuple of integer indices, x is split along axis at each of the indices. For instance, num_or_indices=[2, 4] with axis=0 splits x into x[:2], x[2:4] and x[4:] along axis 0.

  • axis (int|Tensor, optional) – The axis along which to split. It can be an integer or a 0-D Tensor with shape [] and data type int32 or int64. If axis < 0, the axis to split along is \(rank(x) + axis\). Default is 0.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
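The sizing rule for num_or_indices and the negative-axis rule can be checked by hand with a small pure-Python sketch. The helpers section_bounds and normalize_axis below are hypothetical illustrations of the rules described above, not part of Paddle's API and not its actual implementation; the sketch also assumes that any indices are non-negative and no larger than the axis length.

```python
from typing import List, Sequence, Tuple, Union


def normalize_axis(axis: int, rank: int) -> int:
    """Hypothetical helper: map a negative axis to rank + axis."""
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    return axis + rank if axis < 0 else axis


def section_bounds(
    length: int, num_or_indices: Union[int, Sequence[int]]
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, stop) slice bounds of each section along one axis.

    Assumes indices are non-negative and <= length.
    """
    if isinstance(num_or_indices, int):
        n = num_or_indices
        if n <= 0:
            raise ValueError("the number of sections must be positive")
        base, extra = divmod(length, n)
        # The first `extra` (= length % n) sections get one extra element.
        sizes = [base + 1 if i < extra else base for i in range(n)]
        bounds, start = [], 0
        for size in sizes:
            bounds.append((start, start + size))
            start += size
        return bounds
    # Index case: cut at each index, giving x[:i0], x[i0:i1], ..., x[ik:].
    cuts = [0, *num_or_indices, length]
    return list(zip(cuts[:-1], cuts[1:]))


print(section_bounds(8, 3))       # [(0, 3), (3, 6), (6, 8)]
print(section_bounds(8, [2, 3]))  # [(0, 2), (2, 3), (3, 8)]
print(normalize_axis(-1, 2))      # 1
```

The printed bounds match the section sizes 3, 3, 2 and 2, 1, 5 that appear in the examples for a length-8 axis.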

Returns

list[Tensor], the list of sub-Tensors produced by the split.

Examples

>>> import paddle

>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]

>>> # not evenly split
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]

>>> # split with indices
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]

>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]

>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]
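For a quick cross-check without Paddle, NumPy's numpy.array_split applies the same sizing rule (the first length % n sections get one extra element) and also accepts a list of split indices. The sketch below assumes NumPy is installed and uses it only as a comparison, not as a description of Paddle's implementation. NumPy reports shapes as tuples, whereas the Paddle examples above print lists.

```python
import numpy as np

# Same shapes as the Paddle examples above, reproduced with np.array_split.
x = np.random.rand(8)
print([p.shape for p in np.array_split(x, 3)])       # [(3,), (3,), (2,)]
print([p.shape for p in np.array_split(x, [2, 3])])  # [(2,), (1,), (5,)]

x2 = np.random.rand(7, 8)
parts = np.array_split(x2, [2, 3], axis=1)
print([p.shape for p in parts])                       # [(7, 2), (7, 1), (7, 5)]

# With an int or sorted indices, the pieces are contiguous and cover the
# whole axis, so concatenating them along that axis restores the input.
assert np.array_equal(np.concatenate(parts, axis=1), x2)
```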