Sampler

class paddle.io.Sampler(data_source=None)

An abstract class to encapsulate methods and behaviors of samplers.

All samplers used by paddle.io.BatchSampler should be subclasses of paddle.io.Sampler. Sampler subclasses should implement the following methods:

__iter__: return an iterator over the indices of dataset elements.

__len__: return the number of samples in data_source.
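As a minimal sketch of the two required methods, the hypothetical ReverseSampler below yields dataset indices from last to first. It subclasses paddle.io.Sampler when Paddle is installed; otherwise it falls back to an empty stand-in base class (an assumption made only so the sketch runs standalone, not Paddle's own code).

```python
try:
    from paddle.io import Sampler
except ImportError:
    class Sampler:
        """Hypothetical stand-in used only when Paddle is not installed."""


class ReverseSampler(Sampler):
    """Hypothetical sampler: yields dataset indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # __iter__: iterate over indices len(data_source) - 1, ..., 1, 0
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        # __len__: number of indices this sampler yields
        return len(self.data_source)


sampler = ReverseSampler(data_source=["a", "b", "c", "d"])
print(list(sampler))  # [3, 2, 1, 0]
print(len(sampler))   # 4
```

Because __iter__ builds a new iterator on every call, the same sampler can be iterated again for each epoch.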

Parameters

data_source (Dataset, optional) – an instance of paddle.io.Dataset or any other Python object that implements __len__; the Sampler uses its length to produce indices in the range of the dataset length. Default: None.
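To illustrate that data_source only needs __len__, the sketch below passes a hypothetical LengthOnly object (which defines nothing but __len__) and a plain str to a hypothetical RangeSampler. As above, an empty stand-in base class is assumed when Paddle is not installed.

```python
try:
    from paddle.io import Sampler
except ImportError:
    class Sampler:
        """Hypothetical stand-in used only when Paddle is not installed."""


class LengthOnly:
    """Hypothetical data source: defines __len__ and nothing else."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n


class RangeSampler(Sampler):
    """Hypothetical sampler: yields 0 .. len(data_source) - 1 using only len()."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Only len(data_source) is consulted; elements are never accessed
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


print(list(RangeSampler(LengthOnly(3))))  # [0, 1, 2]
print(list(RangeSampler("xyz")))          # [0, 1, 2] (a str also has __len__)
```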

Returns

an iterable object that yields sample indices

Return type

Sampler

Examples

import numpy as np
from paddle.io import Dataset, Sampler

class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

class MySampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield indices 0, 1, ..., len(data_source) - 1 in order
        return iter(range(len(self.data_source)))

    def __len__(self):
        # Number of indices the sampler produces
        return len(self.data_source)

sampler = MySampler(data_source=RandomDataset(100))

for index in sampler:
    print(index)
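A common variation on the example above is to visit the indices in random order. The hypothetical ShuffledSampler below yields a fresh permutation of all indices on each pass, using Python's random.Random; the seed argument is an illustrative assumption, not a Paddle parameter, and the stand-in base class is again used only when Paddle is missing.

```python
import random

try:
    from paddle.io import Sampler
except ImportError:
    class Sampler:
        """Hypothetical stand-in used only when Paddle is not installed."""


class ShuffledSampler(Sampler):
    """Hypothetical sampler: every index exactly once, in random order."""

    def __init__(self, data_source, seed=None):
        self.data_source = data_source
        self.rng = random.Random(seed)  # private RNG so runs can be reproduced

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        self.rng.shuffle(indices)  # new permutation on every call to __iter__
        return iter(indices)

    def __len__(self):
        return len(self.data_source)


sampler = ShuffledSampler(list(range(10)), seed=0)
first = list(sampler)
second = list(sampler)
print(sorted(first) == list(range(10)))  # True: a permutation of all indices
```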

See also: paddle.io.BatchSampler, paddle.io.DataLoader
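To show roughly how a batch sampler consumes a Sampler, the hypothetical helper batch_indices below groups the indices from any iterable of indices into lists of batch_size, optionally dropping a final short batch. This is a simplified illustration, not paddle.io.BatchSampler's implementation; in Paddle itself the sampler would be handed to paddle.io.BatchSampler instead.

```python
def batch_indices(sampler, batch_size, drop_last=False):
    """Hypothetical helper: group a sampler's indices into lists of batch_size."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []  # start a new list; the yielded one is left untouched
    if batch and not drop_last:
        yield batch  # final, smaller batch


print(list(batch_indices(range(10), batch_size=4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(batch_indices(range(10), batch_size=4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

A plain range stands in for a sampler here; any object whose __iter__ yields indices, such as MySampler above, works the same way.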