ChainDataset

class paddle.io. ChainDataset ( datasets ) [源代码]

将多个流式数据集级联的数据集。

用于级联的数据集须都是 IterableDataset 数据集,将各流式数据集按顺序级联为一个数据集。

参数

  • datasets (list of IterableDatasets) - 待级联的多个数据集。

返回

paddle.io.IterableDataset,级联后的流式数据集

代码示例

>>> import numpy as np
>>> import paddle
>>> from paddle.io import IterableDataset, ChainDataset


>>> # define a random dataset
>>> class RandomDataset(IterableDataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __iter__(self):
...         for i in range(10):
...             image = np.random.random([32]).astype('float32')
...             label = np.random.randint(0, 9, (1, )).astype('int64')
...             yield image, label
...
>>> dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
>>> for image, label in iter(dataset):
...     # do something
...     ...