DistributedBatchSampler¶
- class paddle.io. DistributedBatchSampler ( dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False ) [源代码] ¶
分布式批采样器加载数据的一个子集。每个进程可以传递给 DataLoader 一个 DistributedBatchSampler 的实例,每个进程加载原始数据的一个子集。
注解
假定 Dataset 的大小是固定的。
参数¶
dataset (Dataset) - 此参数必须是 Dataset 的一个子类实例或实现了
__len__
的 Python 对象,用于生成样本下标。batch_size (int) - 每 mini-batch 中包含的样本数。
num_replicas (int,可选) - 分布式训练时的进程个数。如果是 None,会依据 ParallelEnv 获取值。默认是 None。
rank (int,可选) - num_replicas 个进程中的进程序号。如果是 None,会依据 ParallelEnv 获取值。默认是 None。
shuffle (bool,可选) - 是否需要在生成样本下标时打乱顺序。默认值为 False。
drop_last (bool,可选) - 是否需要丢弃最后无法凑整一个 mini-batch 的样本。默认值为 False。
返回¶
DistributedBatchSampler,返回样本下标数组的迭代器。
代码示例¶
>>> import numpy as np
>>> from paddle.io import Dataset, DistributedBatchSampler
>>> # init with dataset
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([784]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> dataset = RandomDataset(100)
>>> sampler = DistributedBatchSampler(dataset, batch_size=64)
>>> for data in sampler:
... # do something
... break
方法¶
set_epoch(epoch)¶
设置 epoch 数。当设置``shuffle=True``时,此 epoch 被用作随机种子。默认情况下,用户可以不用此接口设置,每个 epoch 时,所有的进程(workers)使用不同的顺序。如果每个 epoch 设置相同的数字,每个 epoch 数据的读取顺序将会相同。
参数
epoch (int) - epoch 数。
代码示例
>>> import numpy as np
>>> from paddle.io import Dataset, DistributedBatchSampler
>>> # init with dataset
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([784]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> dataset = RandomDataset(100)
>>> sampler = DistributedBatchSampler(dataset, batch_size=64)
>>> for epoch in range(10):
... sampler.set_epoch(epoch)