BatchSampler
class paddle.io.BatchSampler(dataset=None, sampler=None, shuffle=False, batch_size=1, drop_last=False)
A base implementation of the batch sampler used by paddle.io.DataLoader. It yields mini-batch indices iterably, one mini-batch at a time; each mini-batch is a list/tuple whose length equals the mini-batch size and which holds sample indices.
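The core batching idea can be sketched in plain Python: walk an iterable of sample indices and emit a list every time batch_size indices have been collected, optionally discarding a short final batch. This is a minimal illustration that assumes the indices come from any Python iterable; the helper name batch_indices is hypothetical and this is not Paddle's own implementation.

```python
from typing import Iterable, Iterator, List


def batch_indices(indices: Iterable[int], batch_size: int = 1,
                  drop_last: bool = False) -> Iterator[List[int]]:
    """Hypothetical helper: group sample indices into mini-batches."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    batch: List[int] = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch  # a full mini-batch is ready
            batch = []   # start a fresh list for the next mini-batch
    # Anything left over forms an incomplete final batch.
    if batch and not drop_last:
        yield batch


print(list(batch_indices(range(10), batch_size=4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(batch_indices(range(10), batch_size=4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Because a new list is created after each yield, callers can keep or mutate a yielded batch without affecting later ones.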
A batch sampler used by paddle.io.DataLoader should be a subclass of paddle.io.BatchSampler, and BatchSampler subclasses should implement the following methods:
__iter__: return mini-batch indices iterably, one mini-batch at a time.
__len__: return the number of mini-batches in an epoch.
- Parameters
dataset (Dataset, optional) – an instance of a subclass of Dataset or IterableDataset, or any other Python object that implements __len__; BatchSampler uses its length to generate indices in range(len(dataset)). Default None, disabled.
sampler (Sampler, optional) – a paddle.io.Sampler instance that implements __iter__ to generate sample indices. sampler and dataset cannot be set at the same time: if sampler is set, dataset should not be set. Default None, disabled.
shuffle (bool, optional) – whether to shuffle the order of indices before generating batch indices. Default False, indices are not shuffled.
batch_size (int, optional) – number of sample indices in each mini-batch. Default 1, i.e. each mini-batch includes 1 sample.
drop_last (bool, optional) – whether to drop the last incomplete batch, i.e. one holding fewer than batch_size indices. Default False, keep it.
See also paddle.io.DataLoader.
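The number of mini-batches per epoch, which __len__ reports, follows from batch_size and drop_last: with N samples and batch size B it should be ceil(N / B) when the last incomplete batch is kept and floor(N / B) when it is dropped. The sketch below uses a hypothetical helper, num_batches, and checks the formula against the two configurations used in the examples (100 samples with batch_size=16, drop_last=False, and with batch_size=8, drop_last=True).

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: mini-batches per epoch for N samples and batch size B."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    full, remainder = divmod(num_samples, batch_size)
    # A partial batch only counts when it exists and is kept.
    return full + (1 if remainder and not drop_last else 0)


print(num_batches(100, 16, drop_last=False))  # 7: six full batches plus [96, 97, 98, 99]
print(num_batches(100, 8, drop_last=True))    # 12: the 4 leftover indices are dropped
```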
- Returns
an iterable object that yields mini-batch indices
- Return type
BatchSampler
Examples
>>> import numpy as np
>>> from paddle.io import RandomSampler, BatchSampler, Dataset

>>> np.random.seed(2023)
>>> # init with dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([784]).astype('float32')
...         label = np.random.randint(0, 9, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> bs = BatchSampler(dataset=RandomDataset(100),
...                   shuffle=False,
...                   batch_size=16,
...                   drop_last=False)
...
>>> for batch_indices in bs:
...     print(batch_indices)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
...
[96, 97, 98, 99]

>>> # init with sampler
>>> sampler = RandomSampler(RandomDataset(100))
>>> bs = BatchSampler(sampler=sampler,
...                   batch_size=8,
...                   drop_last=True)
...
>>> for batch_indices in bs:
...     print(batch_indices)
[56, 12, 68, 0, 82, 66, 91, 44]
...
[53, 17, 22, 86, 52, 3, 92, 33]
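To illustrate the __iter__/__len__ contract that BatchSampler subclasses must satisfy, here is a sketch of a hypothetical SortedLengthBatchSampler that groups samples of similar length into the same mini-batch (a common reason to write a custom batch sampler). The class name, the lengths argument and the sorting strategy are assumptions for illustration only. It subclasses paddle.io.BatchSampler when PaddlePaddle is importable and falls back to object otherwise, so the sketch also runs without Paddle; the parent __init__ is not called because this class manages its own indices.

```python
from typing import Iterator, List, Sequence

try:
    # Use the real base class when PaddlePaddle is installed ...
    from paddle.io import BatchSampler
except ImportError:
    # ... otherwise fall back to object so the sketch still runs.
    BatchSampler = object


class SortedLengthBatchSampler(BatchSampler):
    """Hypothetical batch sampler: batches samples in ascending length order."""

    def __init__(self, lengths: Sequence[int], batch_size: int = 1,
                 drop_last: bool = False):
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        # Sample indices ordered by ascending length (ties keep index order).
        self.order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Yield consecutive slices of the sorted order as mini-batches.
        for start in range(0, len(self.order), self.batch_size):
            batch = self.order[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                break  # skip the incomplete final batch
            yield batch

    def __len__(self) -> int:
        # Number of mini-batches per epoch, consistent with __iter__.
        full, remainder = divmod(len(self.order), self.batch_size)
        return full + (1 if remainder and not self.drop_last else 0)


sampler = SortedLengthBatchSampler([5, 1, 3, 2, 4], batch_size=2)
print(list(sampler))  # [[1, 3], [2, 4], [0]]
print(len(sampler))   # 3
```

With PaddlePaddle installed, an instance like this would be passed to the loader as paddle.io.DataLoader(dataset, batch_sampler=sampler), leaving the loader's own batch_size, shuffle and drop_last at their defaults, since batching is then the sampler's job.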