IterableDataset

class paddle.io. IterableDataset [源代码]

概述迭代式数据集的方法和行为的抽象类。

迭代式(iterable style)数据集需要继承这个基类,迭代式数据集为只能依次迭代式获取样本的数据集,类似 Python 中的迭代器,所有迭代式数据集须实现以下方法:

__iter__:依次返回数据赝本。

注解

迭代式数据集不需要实现 __getitem____len__,也不可以调用迭代式数据集的这两个方法。

paddle.io.DataLoader

代码示例 1

import numpy as np
from paddle.io import IterableDataset

# define a random dataset
class RandomDataset(IterableDataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        for i in range(self.num_samples):
            image = np.random.random([784]).astype('float32')
            label = np.random.randint(0, 9, (1, )).astype('int64')
            yield image, label

dataset = RandomDataset(10)
for img, lbl in dataset:
    print(img, lbl)

paddle.io.DataLoadernum_workers > 0 时,每个子进程都会遍历全量的数据集返回全量样本,所以数据集会重复 num_workers 次,如果需要数据集样本不会重复返回,可通过如下两种方法避免样本重复,两种方法中都需要通过 paddle.io.get_worker_info 获取各子进程的信息。

代码示例 2

通过 __iter__ 函数划分各子进程的数据

import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info

class SplitedIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            iter_start = self.start
            iter_end = self.end
        else:
            per_worker = int(
                math.ceil((self.end - self.start) / float(
                    worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        for i in range(iter_start, iter_end):
            yield np.array([i])

dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
    dataset,
    num_workers=2,
    batch_size=1,
    drop_last=True)

for data in dataloader:
    print(data)
# outputs: [2, 5, 3, 6, 4, 7]

代码示例 3

通过各子进程初始化函数 worker_inif_fn 划分子进程数据

import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info

class RangeIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        for i in range(self.start, self.end):
            yield np.array([i])

dataset = RangeIterableDataset(start=2, end=9)

def worker_init_fn(worker_id):
    worker_info = get_worker_info()

    dataset = worker_info.dataset
    start = dataset.start
    end = dataset.end
    num_per_worker = int(
        math.ceil((end - start) / float(worker_info.num_workers)))

    worker_id = worker_info.id
    dataset.start = start + worker_id * num_per_worker
    dataset.end = min(dataset.start + num_per_worker, end)

dataloader = DataLoader(
    dataset,
    num_workers=2,
    batch_size=1,
    drop_last=True,
    worker_init_fn=worker_init_fn)

for data in dataloader:
    print(data)
# outputs: [2, 5, 3, 6, 4, 7]