BatchSampler¶
- class paddle.io. BatchSampler ( dataset=None, sampler=None, shuffle=False, batch_size=1, drop_last=False ) [源代码] ¶
批采样器的基础实现,用于 paddle.io.DataLoader
中迭代式获取 mini-batch 的样本下标数组,数组长度与 batch_size
一致。
所有用于 paddle.io.DataLoader
中的批采样器都必须是 paddle.io.BatchSampler
的子类并实现以下方法:
__iter__
:迭代式返回批样本下标数组。
__len__
:每 epoch 中 mini-batch 数。
参数¶
dataset (Dataset,可选) - 此参数必须是 Dataset 或 IterableDataset 的一个子类实例或实现了
__len__
的 Python 对象,用于生成样本下标。默认值为 None, 表示不使用此参数。sampler (Sampler,可选) - 此参数必须是 Sampler 的子类实例,用于迭代式获取样本下标。
dataset
和sampler
参数只能设置一个。默认值为 None, 表示不使用此参数。shuffle (bool,可选) - 是否需要在生成样本下标时打乱顺序。默认值为 False ,表示不打乱顺序。
batch_size (int,可选) - 每 mini-batch 中包含的样本数。默认值为 1 ,表示每 mini-batch 中包含 1 个样本数。
drop_last (bool,可选) - 是否需要丢弃最后无法凑整一个 mini-batch 的样本。默认值为 False ,表示不丢弃最后无法凑整一个 mini-batch 的样本。
见 DataLoader 。
返回¶
BatchSampler, 返回样本下标数组的迭代器。
代码示例¶
>>> import numpy as np
>>> from paddle.io import RandomSampler, BatchSampler, Dataset
>>> np.random.seed(2023)
>>> # init with dataset
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([784]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> bs = BatchSampler(dataset=RandomDataset(100),
... shuffle=False,
... batch_size=16,
... drop_last=False)
...
>>> for batch_indices in bs:
... print(batch_indices)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
...
[96, 97, 98, 99]
>>> # init with sampler
>>> sampler = RandomSampler(RandomDataset(100))
>>> bs = BatchSampler(sampler=sampler,
... batch_size=8,
... drop_last=True)
...
>>> for batch_indices in bs:
... print(batch_indices)
[56, 12, 68, 0, 82, 66, 91, 44]
...
[53, 17, 22, 86, 52, 3, 92, 33]