random_split

paddle.io.random_split(dataset, lengths, generator=None)

Randomly split a dataset into non-overlapping new datasets of the given lengths. Optionally, fix the generator to make the split reproducible.
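Conceptually, such a split amounts to shuffling all indices once and then cutting the shuffled list into consecutive blocks of the requested lengths. Because each index appears exactly once in the shuffled list, the blocks cannot overlap. The pure-Python sketch below illustrates this with the standard `random` module. It is an illustration under that assumption, not Paddle's implementation, and `random_split_indices` is a hypothetical helper. Seeding the RNG reproduces the same split, which is analogous to fixing the generator.

```python
import random


def random_split_indices(n, lengths, seed=None):
    """Hypothetical helper: split range(n) into random, non-overlapping index lists."""
    rng = random.Random(seed)  # fixing the seed fixes the permutation
    perm = list(range(n))
    rng.shuffle(perm)          # one shuffle shared by all splits
    splits, start = [], 0
    for length in lengths:     # take consecutive blocks of the permutation
        splits.append(perm[start:start + length])
        start += length
    return splits


a = random_split_indices(10, [3, 7], seed=42)
b = random_split_indices(10, [3, 7], seed=42)
print([len(s) for s in a])  # [3, 7]
print(a == b)               # True: same seed, same split
```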

Parameters
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths of the splits to produce; they should sum to the length of dataset

  • generator (Generator, optional) – Generator used for the random permutation. If None (the default), the default generator is used; its state can be set with manual_seed().
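Each entry of lengths claims the next block of the shuffled indices, so [3, 7] over 10 items takes positions 0 to 2 for the first split and 3 to 9 for the second. The sketch below shows that bookkeeping with running totals, together with a length check that rejects lengths not summing to the dataset size. `split_bounds` is a hypothetical helper written for illustration; the exact validation Paddle performs may differ between versions.

```python
from itertools import accumulate


def split_bounds(dataset_len, lengths):
    """Hypothetical helper: (start, stop) positions each split takes from the shuffled indices."""
    if sum(lengths) != dataset_len:
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    stops = list(accumulate(lengths))  # running totals: [3, 10] for [3, 7]
    return [(stop - length, stop) for stop, length in zip(stops, lengths)]


print(split_bounds(10, [3, 7]))  # [(0, 3), (3, 10)]
```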

Returns

A list of subset Datasets, which are the non-overlapping subsets of the original Dataset.

Return type

list of Subset
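Each returned subset can be thought of as a view: it keeps a reference to the original dataset plus a list of indices, and looks items up on access instead of copying them. The minimal class below mimics that behaviour for illustration only; `IndexView` is a hypothetical name, not Paddle's own subset class, and the index lists are taken from the sample output in the example.

```python
class IndexView:
    """Hypothetical stand-in for a subset: wraps a dataset and a list of indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Map a position in the subset to a position in the original dataset
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


data = list("abcdefghij")
first = IndexView(data, [1, 3, 9])
second = IndexView(data, [5, 7, 8, 6, 0, 2, 4])
print([first[i] for i in range(len(first))])  # ['b', 'd', 'j']
print(len(first) + len(second) == len(data))  # True
```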

Examples

import paddle
from paddle.io import random_split

# Split the 10 items of range(10) into subsets of 3 and 7 items
a_list = random_split(range(10), [3, 7])
print(len(a_list))
# 2

# The split is random, so the values printed below differ between runs
for idx, v in enumerate(a_list[0]):
    print(idx, v)
# output of the first subset
# 0 1
# 1 3
# 2 9

for idx, v in enumerate(a_list[1]):
    print(idx, v)
# output of the second subset
# 0 5
# 1 7
# 2 8
# 3 6
# 4 0
# 5 2
# 6 4