Sampler¶
概括数据集采样器行为和方法的基类。
所有数据集采样器必须继承这个基类,并实现以下方法:
__iter__
:迭代返回数据样本下标
__len__
: data_source
中的样本数
参数¶
data_source (Dataset) - 此参数必须是
paddle.io.Dataset
或paddle.io.IterableDataset
的一个子类实例或实现了__len__
的 Python 对象,用于生成样本下标。默认值为 None。
可见 paddle.io.BatchSampler
和 paddle.io.DataLoader
返回¶
Sampler,返回样本下标的迭代器。
代码示例¶
from paddle.io import Dataset, Sampler
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
class MySampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
sampler = MySampler(data_source=RandomDataset(100))
for index in sampler:
print(index)