DistributedBatchSampler
class paddle.io.DistributedBatchSampler(dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False)
Sampler that restricts data loading to a subset of the dataset.
In distributed training, each process can pass a DistributedBatchSampler instance to a DataLoader as its batch sampler and load a subset of the original dataset that is exclusive to that process.
Note
Dataset is assumed to be of constant size.
Parameters
dataset (Dataset) – an instance of a Dataset subclass, or any other Python object that implements __len__, which BatchSampler uses to get the indices of samples.
batch_size (int) – number of samples in each mini-batch.
num_replicas (int, optional) – number of processes in distributed training. If num_replicas is None, it is retrieved from ParallelEnv. Default None.
rank (int, optional) – the rank of the current process among num_replicas processes. If rank is None, it is retrieved from ParallelEnv. Default None.
shuffle (bool, optional) – whether to shuffle the order of indices before generating batch indices. Default False.
drop_last (bool, optional) – whether to drop the last batch when it is incomplete (fewer samples than batch_size). Default False.
Returns
DistributedBatchSampler, an iterable object that yields batches of sample indices.
Examples
>>> import numpy as np
>>> from paddle.io import Dataset, DistributedBatchSampler
>>> # init with dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([784]).astype('float32')
...         label = np.random.randint(0, 9, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> dataset = RandomDataset(100)
>>> sampler = DistributedBatchSampler(dataset, batch_size=64)
>>> for data in sampler:
...     # do something
...     break
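The way a dataset is divided among processes can be sketched in plain Python. The snippet below is a simplified illustration, not Paddle's implementation: the helper name shard_batches, the strided assignment of indices to ranks, and the wrap-around padding are all assumptions made for the example, and the order in which Paddle actually assigns indices may differ.

```python
import math


def shard_batches(dataset_len, batch_size, num_replicas, rank, drop_last=False):
    """Hypothetical helper: index batches seen by one rank (illustration only)."""
    # Each rank gets the same number of samples, rounding up.
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    # Assumption: pad by wrapping around to the first indices so that
    # `total` divides evenly among the ranks.
    indices = [i % dataset_len for i in range(total)]
    # Assumption: strided assignment, rank r takes positions r, r + n, r + 2n, ...
    local = indices[rank:total:num_replicas]
    # Group this rank's indices into mini-batches.
    batches = [local[i:i + batch_size] for i in range(0, len(local), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


# 10 samples, 4 processes, batch size 2: each rank sees 3 indices.
for rank in range(4):
    print(rank, shard_batches(10, 2, 4, rank))
# 0 [[0, 4], [8]]
# 1 [[1, 5], [9]]
# 2 [[2, 6], [0]]
# 3 [[3, 7], [1]]
```

In this sketch, ranks 2 and 3 receive the padded indices 0 and 1, so a few samples are seen twice per epoch in exchange for every rank running the same number of steps. With drop_last=True, the trailing one-element batch on each rank would be discarded.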
set_epoch(epoch)
Sets the epoch number. When shuffle=True, this number is used as the seed for random number generation. By default, users do not need to set it, and all replicas (workers) use a different random ordering for each epoch. If the same number is set at every epoch, this sampler yields the same ordering in all epochs.
Parameters
epoch (int) – Epoch number.
Examples
>>> import numpy as np
>>> from paddle.io import Dataset, DistributedBatchSampler
>>> # init with dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([784]).astype('float32')
...         label = np.random.randint(0, 9, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> dataset = RandomDataset(100)
>>> sampler = DistributedBatchSampler(dataset, batch_size=64)
>>> for epoch in range(10):
...     sampler.set_epoch(epoch)
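The effect of using the epoch number as a seed can be shown without Paddle. This is a minimal sketch under stated assumptions: the helper epoch_order is hypothetical, and Python's random module stands in for whatever generator the library uses internally. Only the idea of seeding the shuffle with the epoch is illustrated.

```python
import random


def epoch_order(num_samples, epoch):
    """Hypothetical helper: a permutation of indices seeded by the epoch."""
    indices = list(range(num_samples))
    # Assumption: a generator seeded with the epoch number; the library's
    # actual random source may be different.
    random.Random(epoch).shuffle(indices)
    return indices


# Processes that seed with the same epoch compute the same permutation,
# so each can take its own slice of one shared ordering.
rank0_view = epoch_order(100, epoch=3)
rank1_view = epoch_order(100, epoch=3)
same_across_ranks = rank0_view == rank1_view

# Passing the same epoch again reproduces the ordering; a different
# epoch number reshuffles the indices.
repeated = epoch_order(100, epoch=3) == rank0_view
changed = epoch_order(100, epoch=4) != rank0_view
```

Passing a fresh epoch number on every pass over the data, as in the loop above, therefore gives a new ordering each time while keeping all processes in agreement about it.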