BatchSampler

class paddle.io.BatchSampler(dataset=None, sampler=None, shuffle=False, batch_size=1, drop_last=False)

A base implementation of the batch sampler used by paddle.io.DataLoader. It iteratively yields mini-batch indices, where each mini-batch is a list/tuple of sample indices whose length equals the mini-batch size.
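To illustrate what the loader does with those indices, here is a heavily simplified, single-process approximation in plain Python. The helper name iterate_batches is hypothetical; the real paddle.io.DataLoader also handles worker processes and converts the collated fields to tensors, which this sketch omits.

```python
def iterate_batches(dataset, batch_sampler):
    """Hypothetical, single-process stand-in for a DataLoader loop."""
    for indices in batch_sampler:
        # Fetch each sample by index (map-style dataset: dataset[i]).
        samples = [dataset[i] for i in indices]
        # Regroup [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
        yield tuple(list(field) for field in zip(*samples))


dataset = [(i, i * 10) for i in range(5)]  # toy (feature, label) pairs
batch_sampler = [[0, 1], [2, 3], [4]]      # any iterable of index lists
for features, labels in iterate_batches(dataset, batch_sampler):
    print(features, labels)
```

Each list of indices produced by the batch sampler becomes one mini-batch; the sampler itself never touches the data.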

A batch sampler used by paddle.io.DataLoader should be a subclass of paddle.io.BatchSampler. Subclasses should implement the following methods:

__iter__: yield mini-batch indices, one batch at a time.

__len__: return the number of mini-batches in an epoch.
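As a minimal sketch of this contract, the hypothetical ReverseBatchSampler below walks the sample indices from last to first. It subclasses paddle.io.BatchSampler when Paddle is importable and falls back to object otherwise; the fallback is an assumption made only so the sketch runs standalone. It does not call the base-class constructor, because only __iter__ and __len__ are needed here.

```python
import math

try:
    # Subclass the real base class when Paddle is installed.
    from paddle.io import BatchSampler as _BatchSamplerBase
except ImportError:
    # Assumption: fall back to `object` so the sketch runs without Paddle.
    _BatchSamplerBase = object


class ReverseBatchSampler(_BatchSamplerBase):
    """Hypothetical batch sampler that yields indices from last to first."""

    def __init__(self, num_samples, batch_size, drop_last=False):
        # The base-class constructor (which expects a dataset or sampler)
        # is deliberately not called; only the two methods below are used.
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Yield mini-batch indices one batch at a time.
        batch = []
        for idx in reversed(range(self.num_samples)):
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the trailing, incomplete batch unless drop_last is set.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        # Number of mini-batches per epoch.
        if self.drop_last:
            return self.num_samples // self.batch_size
        return math.ceil(self.num_samples / self.batch_size)
```

For example, list(ReverseBatchSampler(5, batch_size=2)) gives [[4, 3], [2, 1], [0]] and its len() is 3; with drop_last=True the final [0] is discarded and len() is 2. Keeping __len__ consistent with what __iter__ actually yields is the main thing to get right in a custom subclass.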

Parameters
  • dataset (Dataset) – a paddle.io.Dataset instance, or any other Python object that implements __len__; BatchSampler generates indices over range(len(dataset)). Default None.

  • sampler (Sampler) – a paddle.io.Sampler instance that implements __iter__ to yield sample indices. sampler and dataset cannot be set at the same time. If sampler is set, shuffle must be False. Default None.

  • shuffle (bool) – whether to shuffle the order of indices before generating batch indices. Default False.

  • batch_size (int) – the number of sample indices in each mini-batch. Default 1.

  • drop_last (bool) – whether to drop the last, incomplete batch when the dataset size is not divisible by the batch size. Default False.
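The parameters above combine roughly as follows. As an illustration only, this plain-Python helper mimics the documented rules: indices come either from range(len(dataset)) or from the sampler, are optionally shuffled, are grouped into chunks of batch_size, and the short final chunk is dropped when drop_last=True. The name batch_indices and its seed argument are hypothetical and not part of paddle.io, and the helper returns a list rather than a lazy iterator for simplicity.

```python
import random


def batch_indices(dataset=None, sampler=None, shuffle=False,
                  batch_size=1, drop_last=False, seed=None):
    """Hypothetical helper approximating the documented BatchSampler rules."""
    # Exactly one index source is allowed.
    if dataset is None and sampler is None:
        raise ValueError("either dataset or sampler should be set")
    if dataset is not None and sampler is not None:
        raise ValueError("dataset and sampler cannot both be set")
    if sampler is not None:
        if shuffle:
            raise ValueError("shuffle must be False when sampler is set")
        indices = list(sampler)  # the sampler decides the order
    else:
        indices = list(range(len(dataset)))  # only __len__ is needed
        if shuffle:
            random.Random(seed).shuffle(indices)
    # Split into consecutive chunks of batch_size.
    batches = [indices[i:i + batch_size]
               for i in range(0, len(indices), batch_size)]
    # Drop the trailing, incomplete batch if requested.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches
```

For instance, batch_indices(dataset=range(100), batch_size=16) yields 7 batches, the last holding indices 96 to 99; with drop_last=True it yields only the 6 full batches.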

Returns

an iterable object that yields mini-batch indices

Return type

BatchSampler

Examples

import numpy as np
from paddle.io import RandomSampler, BatchSampler, Dataset

# A map-style dataset that returns random (image, label) pairs
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

# init with dataset: 100 samples in batches of 16 give 7 batches,
# the last of which holds only 4 indices (drop_last=False)
bs = BatchSampler(dataset=RandomDataset(100),
                  shuffle=False,
                  batch_size=16,
                  drop_last=False)

for batch_indices in bs:
    print(batch_indices)

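# init with sampler: RandomSampler yields a shuffled permutation of 0..99;
# with drop_last=True this gives 12 full batches of 8 and the last 4 indices are dropped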
sampler = RandomSampler(RandomDataset(100))
bs = BatchSampler(sampler=sampler,
                  batch_size=8,
                  drop_last=True)

for batch_indices in bs:
    print(batch_indices)

See also paddle.io.DataLoader.