tensor_split

paddle.tensor_split(x: Tensor, num_or_indices: int | Sequence[int], axis: int | Tensor = 0, name: str | None = None) -> list[Tensor]

Splits the input tensor into multiple sub-Tensors along axis. The sub-Tensors do not have to be of equal size.

In the following figure, the shape of Tensor x is [6]. After the transformation paddle.tensor_split(x, num_or_indices=4), we get four sub-Tensors: out0, out1, out2, and out3.

https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-1_en.png

Since the length of x along axis 0 (which is 6) is not divisible by num_or_indices = 4, the first int(6 % 4) = 2 parts have size int(6 / 4) + 1 = 2, and the remaining parts have size int(6 / 4) = 1.
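The size rule above can be sketched in plain Python with divmod. This is an illustration only: section_sizes is a hypothetical helper, not part of the Paddle API, and it computes sizes without touching any Tensor.

```python
from typing import List


def section_sizes(length: int, n: int) -> List[int]:
    """Hypothetical helper: sizes of the n sections when a dimension of
    `length` elements is split into n parts. The first length % n sections
    get one extra element; the rest get length // n elements."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    base, extra = divmod(length, n)  # base = int(length / n), extra = length % n
    return [base + 1 if i < extra else base for i in range(n)]


print(section_sizes(6, 4))  # [2, 2, 1, 1]  (the figure above)
print(section_sizes(8, 3))  # [3, 3, 2]
```

Using divmod keeps the quotient and remainder consistent for non-negative lengths, which matches the int(length / n) and int(length % n) expressions in the description.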

Parameters
  • x (Tensor) – A Tensor whose number of dimensions must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.

  • num_or_indices (int|list|tuple) – If num_or_indices is an int n, x is split into n sections along axis. If x.shape[axis] is divisible by n, each section has size x.shape[axis] / n. Otherwise, the first int(x.shape[axis] % n) sections have size int(x.shape[axis] / n) + 1, and the rest have size int(x.shape[axis] / n). If num_or_indices is a list or tuple of integer indices, x is split along axis at each of the indices. For instance, num_or_indices=[2, 4] with axis=0 would split x into x[:2], x[2:4] and x[4:] along axis 0.

  • axis (int|Tensor, optional) – The axis along which to split. It can be an integer or a 0-D Tensor with shape [] and data type int32 or int64. If axis < 0, the axis to split along is rank(x) + axis. Default is 0.

  • name (str|None, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
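The list-or-tuple form of num_or_indices and the negative-axis rule can be sketched in plain Python as follows. indices_to_bounds and normalize_axis are hypothetical helpers written for illustration, not Paddle functions, and the out-of-range check in normalize_axis is this sketch's own choice rather than documented Paddle behavior.

```python
from typing import List, Sequence, Tuple


def indices_to_bounds(length: int, indices: Sequence[int]) -> List[Tuple[int, int]]:
    """Hypothetical helper: turn split indices into (start, stop) slice bounds.

    [i0, i1, ..., ik] -> [(0, i0), (i0, i1), ..., (ik, length)]
    """
    edges = [0, *indices, length]
    return list(zip(edges[:-1], edges[1:]))


def normalize_axis(axis: int, rank: int) -> int:
    """Hypothetical helper: map a negative axis to rank + axis.

    The range check below is an assumption of this sketch.
    """
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    return axis + rank if axis < 0 else axis


data = list(range(8))
bounds = indices_to_bounds(len(data), [2, 4])
print(bounds)                          # [(0, 2), (2, 4), (4, 8)]
print([data[a:b] for a, b in bounds])  # [[0, 1], [2, 3], [4, 5, 6, 7]]
print(normalize_axis(-1, rank=2))      # 1
```

The slices printed above correspond to x[:2], x[2:4] and x[4:] from the num_or_indices description, applied to an ordinary Python list.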

Returns

list[Tensor]: the list of segmented Tensors.

Examples

>>> import paddle

>>> # evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-2.png
>>> import paddle

>>> # not evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]
https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-3_en.png
>>> import paddle

>>> # split with indices
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]
https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-4.png
>>> import paddle

>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]
https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-5.png
>>> import paddle

>>> # split along axis with indices
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]
https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-6.png
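The shapes printed in the examples above can be predicted from the parameter descriptions alone. The following sketch does that with a hypothetical pure-Python helper, output_shapes, which only manipulates shape lists; it does not call Paddle and is not guaranteed to match Paddle in edge cases (for example, indices outside [0, x.shape[axis]]).

```python
from typing import List, Sequence, Union


def output_shapes(
    shape: Sequence[int],
    num_or_indices: Union[int, Sequence[int]],
    axis: int = 0,
) -> List[List[int]]:
    """Hypothetical helper: predict the shapes of the sub-Tensors."""
    axis = axis + len(shape) if axis < 0 else axis  # rank(x) + axis for negative axis
    length = shape[axis]
    if isinstance(num_or_indices, int):
        # int n: the first length % n sections get one extra element
        base, extra = divmod(length, num_or_indices)
        sizes = [base + 1 if i < extra else base for i in range(num_or_indices)]
    else:
        # indices: section sizes are the gaps between consecutive split points
        edges = [0, *num_or_indices, length]
        sizes = [stop - start for start, stop in zip(edges[:-1], edges[1:])]
    # Every output keeps the input shape except along `axis`.
    return [list(shape[:axis]) + [size] + list(shape[axis + 1:]) for size in sizes]


print(output_shapes([8], 2))                  # [[4], [4]]
print(output_shapes([8], 3))                  # [[3], [3], [2]]
print(output_shapes([8], [2, 3]))             # [[2], [1], [5]]
print(output_shapes([7, 8], 2, axis=1))       # [[7, 4], [7, 4]]
print(output_shapes([7, 8], [2, 3], axis=1))  # [[7, 2], [7, 1], [7, 5]]
```

Passing axis=-1 for the [7, 8] input gives the same result as axis=1, following the rank(x) + axis rule.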