random_split

paddle.io.random_split(dataset, lengths, generator=None)

Randomly split a dataset into non-overlapping new datasets of the given lengths. For reproducible results, fix the random state before splitting, for example by calling paddle.seed() as in the example below.

Parameters
  • dataset (Dataset) – The dataset to be split.

  • lengths (sequence) – Lengths of the splits to be produced.

  • generator (Generator, optional) – Generator used for the random permutation. Default is None, in which case the DefaultGenerator (the one seeded by manual_seed()) is used.
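The splitting procedure can be sketched in plain Python: draw one random permutation of all indices, then hand out consecutive slices of it, one slice per requested length. This is a minimal illustration under stated assumptions, not Paddle's implementation: `random_split_sketch` is a hypothetical name, a seeded `random.Random` stands in for the `generator` argument, and integer lengths that sum to `len(dataset)` are assumed.

```python
import random
from itertools import accumulate


def random_split_sketch(dataset, lengths, seed=None):
    """Hypothetical helper: split `dataset` into non-overlapping lists of items.

    Assumes integer `lengths` summing to len(dataset); `seed` plays the role
    of fixing the generator. Not Paddle's actual implementation.
    """
    if sum(lengths) != len(dataset):
        raise ValueError("sum(lengths) must equal len(dataset)")
    # A private generator, so the global random state is left untouched.
    rng = random.Random(seed)
    indices = list(range(len(dataset)))
    rng.shuffle(indices)  # one random permutation of all indices
    # accumulate(lengths) yields each split's end offset in the permutation.
    return [
        [dataset[i] for i in indices[end - length:end]]
        for end, length in zip(accumulate(lengths), lengths)
    ]


parts = random_split_sketch(range(10), [3, 7], seed=2023)
print([len(p) for p in parts])  # [3, 7]
```

Because every split is a slice of the same permutation, the splits cannot overlap, and calling the sketch twice with the same seed yields the same split.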

Returns

A list of subset Datasets, which are the non-overlapping subsets of the original Dataset.

Return type

list of Subset
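Each returned subset behaves like a view over the original dataset: it stores a list of indices and maps its own local index to an index in the parent, so no data is copied. The class below is a hypothetical stand-in written to illustrate that idea; it is an assumption about the general behavior, not the source of paddle.io.Subset.

```python
class SubsetView:
    """Hypothetical stand-in for a subset dataset: a lazy view over
    `dataset` restricted to `indices`; items are never copied."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __getitem__(self, idx):
        # Translate the subset-local index into an index of the parent dataset.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


data = ["a", "b", "c", "d", "e"]
view = SubsetView(data, [4, 0, 2])
# Iteration falls back to __getitem__ and stops at the first IndexError.
print(len(view), view[0], list(view))  # 3 e ['e', 'a', 'c']
```

This is why the example below can iterate over a_list[0] with enumerate(): the local index runs from 0, while the values come from the shuffled positions in the original range.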

Examples

>>> import paddle

>>> paddle.seed(2023)
>>> a_list = paddle.io.random_split(range(10), [3, 7])
>>> print(len(a_list))
2

>>> # output of the first subset
>>> for idx, v in enumerate(a_list[0]):
...     print(idx, v) 
0 7
1 6
2 5

>>> # output of the second subset
>>> for idx, v in enumerate(a_list[1]):
...     print(idx, v) 
0 1
1 9
2 4
3 2
4 0
5 3
6 8
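The printed values can be checked against the properties the function promises: each split has the requested length, the splits do not overlap, and together they cover every element of range(10). The lists below are copied from the seed-2023 output above; other seeds or Paddle versions may produce different values.

```python
# Values copied from the seed-2023 output above (seed- and version-dependent).
first = [7, 6, 5]
second = [1, 9, 4, 2, 0, 3, 8]

assert len(first) == 3 and len(second) == 7      # the requested lengths
assert set(first).isdisjoint(second)              # no element appears in both splits
assert sorted(first + second) == list(range(10))  # together they cover range(10)
print("splits are disjoint and cover the dataset")
```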