BatchSampler

class paddle.io.BatchSampler(dataset=None, sampler=None, shuffle=False, batch_size=1, drop_last=False)

A base implementation of the batch sampler used by paddle.io.DataLoader. It yields mini-batch indices iteratively; each mini-batch is a list or tuple of sample indices whose length is the mini-batch size.

A batch sampler used by paddle.io.DataLoader should be a subclass of paddle.io.BatchSampler. BatchSampler subclasses should implement the following methods:

__iter__: returns an iterator over mini-batch indices.

__len__: returns the number of mini-batches in an epoch.
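The two methods above can be illustrated with a small custom sampler. This is a minimal sketch, not part of Paddle: the class name LengthGroupedBatchSampler and its lengths argument are hypothetical, and the import falls back to object only so that the snippet also runs where Paddle is not installed.

```python
try:
    from paddle.io import BatchSampler
except ImportError:
    # Assumption for illustration: fall back to a plain base class
    # so the sketch runs without Paddle installed.
    BatchSampler = object


class LengthGroupedBatchSampler(BatchSampler):
    """Hypothetical sampler that batches samples of similar length together."""

    def __init__(self, lengths, batch_size):
        self.batch_size = batch_size
        # Sort sample indices by sequence length (ascending).
        self.sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])

    def __iter__(self):
        # Yield one list of sample indices per mini-batch.
        for start in range(0, len(self.sorted_indices), self.batch_size):
            yield self.sorted_indices[start:start + self.batch_size]

    def __len__(self):
        # Ceiling division: the last, possibly shorter, batch is kept.
        return -(-len(self.sorted_indices) // self.batch_size)


lengths = [5, 2, 9, 3, 7]
sampler = LengthGroupedBatchSampler(lengths, batch_size=2)
batches = list(sampler)
print(batches)       # [[1, 3], [0, 4], [2]]
print(len(sampler))  # 3
```

Keeping __len__ consistent with the number of batches that __iter__ actually yields matters, since callers may rely on it to know how many mini-batches an epoch contains.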

Parameters
  • dataset (Dataset, optional) – an instance of a subclass of Dataset or IterableDataset, or any other Python object that implements __len__. BatchSampler uses the range of the dataset length as the sample indices. Default: None (disabled).

  • sampler (Sampler, optional) – a paddle.io.Sampler instance that implements __iter__ to generate sample indices. sampler and dataset cannot be set at the same time: if sampler is set, dataset should not be set. Default: None (disabled).

  • shuffle (bool, optional) – whether to shuffle the order of the indices before generating batch indices. Default: False (no shuffling).

  • batch_size (int, optional) – the number of sample indices in each mini-batch. Default: 1 (each mini-batch contains one sample).

  • drop_last (bool, optional) – whether to drop the last batch when it is incomplete (it holds fewer than batch_size indices). Default: False (keep it).

See DataLoader.
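The combined effect of batch_size and drop_last on the number of mini-batches per epoch can be worked out by hand. The helper below is a pure-Python sketch derived from the parameter descriptions above; the function name num_batches is hypothetical and is not a Paddle API.

```python
def num_batches(num_samples, batch_size, drop_last=False):
    """Number of mini-batches produced for num_samples indices (illustrative helper)."""
    full, remainder = divmod(num_samples, batch_size)
    # An incomplete trailing batch is counted only when it is kept.
    if remainder and not drop_last:
        return full + 1
    return full


print(num_batches(100, 16))                  # 7 (six full batches plus one of 4)
print(num_batches(100, 16, drop_last=True))  # 6
print(num_batches(100, 8, drop_last=True))   # 12 (the trailing 4 indices are dropped)
```

With 100 samples and batch_size=16, the final batch holds the four remaining indices, which matches the last batch printed in the first example below.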

Returns

An iterable object that yields mini-batch indices.

Return type

BatchSampler

Examples

>>> import numpy as np
>>> from paddle.io import RandomSampler, BatchSampler, Dataset

>>> np.random.seed(2023)
>>> # init with dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([784]).astype('float32')
...         label = np.random.randint(0, 9, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> bs = BatchSampler(dataset=RandomDataset(100),
...                     shuffle=False,
...                     batch_size=16,
...                     drop_last=False)
...
>>> for batch_indices in bs:
...     print(batch_indices)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
...
[96, 97, 98, 99]
>>> # init with sampler
>>> sampler = RandomSampler(RandomDataset(100))
>>> bs = BatchSampler(sampler=sampler,
...                     batch_size=8,
...                     drop_last=True)
...
>>> for batch_indices in bs:
...     print(batch_indices)
[56, 12, 68, 0, 82, 66, 91, 44]
...
[53, 17, 22, 86, 52, 3, 92, 33]