DataLoader
class paddle.io.DataLoader(dataset, feed_list=None, places=None, return_list=True, batch_sampler=None, batch_size=1, shuffle=False, drop_last=False, collate_fn=None, num_workers=0, use_buffer_reader=True, use_shared_memory=True, timeout=0, worker_init_fn=None, persistent_workers=False)
DataLoader provides an iterator that iterates over the given dataset once, using batch_sampler.
DataLoader supports both single-process and multi-process data loading. If num_workers is set to a positive number, multiple worker processes load data asynchronously.
DataLoader supports both map-style and iterable-style datasets. For map-style datasets (a sample can be fetched from the dataset by index), see paddle.io.Dataset. For iterable-style datasets (samples are fetched iteratively, like a Python iterator), see paddle.io.IterableDataset. For batch_sampler, see paddle.io.BatchSampler.
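The two dataset styles differ only in how samples are reached. As a rough illustration, here is a plain-Python sketch (no Paddle required) of the duck-typed interfaces involved: a map-style object exposes `__getitem__` and `__len__`, so a sampler can choose which indices to read, while an iterable-style object exposes only `__iter__` and decides the order itself. The class names and the `read_all` helper are hypothetical; a real dataset should subclass paddle.io.Dataset or paddle.io.IterableDataset.

```python
class SquaresMapDataset:
    """Map-style (hypothetical): random access by index, known length."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx

    def __len__(self):
        return self.n


class SquaresIterableDataset:
    """Iterable-style (hypothetical): samples come out in order, no indexing."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield i * i


def read_all(dataset, indices=None):
    """Hypothetical helper: fetch samples the way each style allows."""
    if hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__"):
        # Map-style: a sampler decides which indices are read, and in what order.
        order = indices if indices is not None else range(len(dataset))
        return [dataset[i] for i in order]
    # Iterable-style: the dataset itself decides the order.
    return list(iter(dataset))


print(read_all(SquaresMapDataset(5), indices=[4, 0, 2]))  # [16, 0, 4]
print(read_all(SquaresIterableDataset(5)))                # [0, 1, 4, 9, 16]
```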
Note
GPU tensor operations are currently not supported in subprocesses. Do not use GPU tensor operations in any part of the pipeline that runs in a subprocess, such as dataset transforms or collate_fn. NumPy array and CPU tensor operations are supported.
Disable automatic batching
In certain cases, such as some NLP tasks, users need to handle batching manually in the dataset instead of relying on automatic batching. Automatic batching is disabled when both batch_size and batch_sampler are set to None. In that case, each item obtained from dataset should already be batched data, and it is processed by the function defined by collate_fn or by default_collate_fn.
Note
When automatic batching is disabled, default_collate_fn does nothing to the data from dataset.
Parameters
- dataset (Dataset) – the dataset to load data from; should be an instance of a subclass of paddle.io.Dataset or paddle.io.IterableDataset.
- feed_list (list(Tensor)|tuple(Tensor)) – list of feed Tensors. The Tensors should be created by paddle.static.data(). feed_list must be set if return_list is False. Default None.
- places (list(Place)|tuple(Place)|list(str), optional) – a list of Places to put data onto. If places is None, the default place (CPUPlace or CUDAPlace(0)) is used. If places is a list of strings, each string can be cpu, gpu:x or gpu_pinned, where x is the index of the GPU. Default None.
- return_list (bool) – whether the return value on each device is presented as a list. If return_list=False, the return value on each device is a dict of str -> Tensor, where each key is the name of a fed Tensor. If return_list=True, the return value on each device is a list(Tensor). return_list can only be True in dynamic graph mode. Default True.
- batch_sampler (BatchSampler) – an instance of paddle.io.BatchSampler that generates the batch indices used to draw samples from dataset and combine them into a batch. Default None.
- batch_size (int|None) – number of samples in a mini-batch; a substitute parameter for batch_sampler. If batch_sampler is not set, a default paddle.io.BatchSampler initialized with batch_size, shuffle and drop_last is used. Default 1.
- shuffle (bool) – whether to shuffle the index order before generating batch indices; a substitute parameter for batch_sampler, see batch_size. Default False.
- drop_last (bool) – whether to drop the last incomplete batch when the dataset size is not divisible by the batch size; a substitute parameter for batch_sampler, see batch_size. Default False.
- collate_fn (callable) – function that generates mini-batch data by merging the sample list. If None, each field of the samples is stacked along axis 0 (the same as np.stack(..., axis=0)). Default None.
- num_workers (int) – number of subprocesses used to load data; 0 means no subprocess is used and data is loaded in the main process. Default 0.
- use_buffer_reader (bool) – whether to use a buffered reader. If use_buffer_reader=True, DataLoader prefetches the next batch asynchronously, which speeds up data feeding at the cost of a little more CPU or GPU memory (the memory of one batch of input data). Default True.
- use_shared_memory (bool) – whether to use shared memory to speed up putting data into the inter-process queue. Set use_shared_memory to True only when the shared memory space on your machine (e.g. the space of /dev/shm on Linux) is large enough. Shared memory is only enabled in multi-process mode (num_workers > 0). Default True.
- timeout (int) – the timeout value for getting data from the output queue of subprocesses. Default 0.
- worker_init_fn (callable) – if not None, an init function that is called with the worker id when each subprocess starts. Default None.
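To make the interaction of batch_size, shuffle, drop_last and collate_fn concrete, the following is a simplified, single-process sketch in plain Python. It is not Paddle's implementation: `batch_indices`, `stack_collate` and `iterate` are hypothetical helpers, plain lists stand in for the np.stack-based default collation, and the `batch_size=None` branch models the "disable automatic batching" case described above, where each item is passed through as already-batched data.

```python
import random


def batch_indices(n, batch_size, shuffle=False, drop_last=False, seed=None):
    """Sketch of the indices a default batch sampler yields for n samples."""
    indices = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, n, batch_size):
        batch = indices[start:start + batch_size]
        if len(batch) < batch_size and drop_last:
            break  # discard the final incomplete batch
        yield batch


def stack_collate(samples):
    """Sketch of default collation: group each field of the samples.

    Lists stand in for np.stack(..., axis=0): samples of the form
    (image, label) become ([image_0, image_1, ...], [label_0, label_1, ...]).
    """
    return tuple(list(field) for field in zip(*samples))


def iterate(dataset, batch_size=1, shuffle=False, drop_last=False,
            collate_fn=None, seed=None):
    """Hypothetical single-process loading loop over a map-style dataset."""
    if batch_size is None:
        # Automatic batching disabled: each item is already a batch and is
        # only transformed if a collate_fn is given.
        for i in range(len(dataset)):
            item = dataset[i]
            yield collate_fn(item) if collate_fn else item
        return
    collate = collate_fn or stack_collate
    for idx in batch_indices(len(dataset), batch_size, shuffle, drop_last, seed):
        yield collate([dataset[i] for i in idx])


samples = [(f"img{i}", i % 2) for i in range(5)]
print(list(batch_indices(5, 2)))                  # [[0, 1], [2, 3], [4]]
print(list(batch_indices(5, 2, drop_last=True)))  # [[0, 1], [2, 3]]
print(next(iterate(samples, batch_size=2)))       # (['img0', 'img1'], [0, 1])
```

Shuffling only permutes indices before they are cut into batches, which is why shuffle, batch_size and drop_last can all be folded into a single default sampler.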
Returns
An iterable object for data iterating; each element of the generated data is a Tensor.
Return type
DataLoader
Examples
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, BatchSampler, DataLoader

BATCH_NUM = 20
BATCH_SIZE = 16
EPOCH_NUM = 4

IMAGE_SIZE = 784
CLASS_NUM = 10

# define a random dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([IMAGE_SIZE]).astype('float32')
        label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)

class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(IMAGE_SIZE, CLASS_NUM)

    def forward(self, image, label=None):
        return self.fc(image)

simple_net = SimpleNet()
opt = paddle.optimizer.SGD(learning_rate=1e-3,
                           parameters=simple_net.parameters())

loader = DataLoader(dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=True,
                    drop_last=True,
                    num_workers=2)

for e in range(EPOCH_NUM):
    for i, (image, label) in enumerate(loader()):
        out = simple_net(image)
        loss = F.cross_entropy(out, label)
        avg_loss = paddle.mean(loss)
        avg_loss.backward()
        opt.minimize(avg_loss)
        simple_net.clear_gradients()
        print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
Note
For reading an iterable dataset with a multi-process DataLoader, see paddle.io.IterableDataset.
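When an iterable-style dataset is read with num_workers > 0, each worker process iterates its own copy of the dataset, so unless the dataset splits its stream, every sample is produced once per worker. The plain-Python sketch below simulates this with a hypothetical `ShardedRange` class that keeps only the samples assigned to its worker id (round-robin). In Paddle, the worker id and worker count are expected to come from paddle.io.get_worker_info() inside `__iter__` (see the paddle.io.IterableDataset documentation); here they are passed in explicitly so that the sketch runs without Paddle.

```python
class ShardedRange:
    """Hypothetical iterable-style dataset yielding 0, 1, ..., n - 1."""

    def __init__(self, n, worker_id=0, num_workers=1):
        self.n = n
        self.worker_id = worker_id
        self.num_workers = num_workers

    def __iter__(self):
        for i in range(self.n):
            # Keep only the samples assigned to this worker (round-robin).
            if i % self.num_workers == self.worker_id:
                yield i


def run_workers(n, num_workers, shard=True):
    """Simulate num_workers workers, each iterating its own dataset copy."""
    collected = []
    for worker_id in range(num_workers):
        if shard:
            ds = ShardedRange(n, worker_id, num_workers)
        else:
            ds = ShardedRange(n)  # every copy yields the full range
        collected.extend(ds)
    return sorted(collected)


print(run_workers(4, 2, shard=False))  # [0, 0, 1, 1, 2, 2, 3, 3]
print(run_workers(4, 2))               # [0, 1, 2, 3]
```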
static from_generator(feed_list=None, capacity=None, use_double_buffer=True, iterable=True, return_list=False, use_multiprocess=False, drop_last=True)
Warning
This API will be deprecated in the future. It is recommended to use paddle.io.DataLoader, which supports multi-process acceleration.
Note
The framework ensures that the data loading order of DataLoader is exactly the same as that of the user-defined data source.
Create a DataLoader object for loading data from a Python generator. Data is prefetched by a Python thread and pushed into a queue asynchronously.
The created DataLoader object provides 3 methods to set the data source: set_sample_generator, set_sample_list_generator and set_batch_generator. See the example code below for their usage.
If iterable = True, the created DataLoader object is a Python generator object, which can be iterated with a for loop.
If iterable = False, the created DataLoader object provides start() and reset() methods to control the data reading process.
Parameters
- feed_list (list(Tensor)|tuple(Tensor)) – list of feed Tensors. The Tensors should be created by fluid.data().
- capacity (int) – capacity of the queue maintained in DataLoader, in number of batches. Set a larger capacity if your reader is fast.
- use_double_buffer (bool) – whether to use double_buffer_reader. If use_double_buffer=True, DataLoader prefetches the next batch asynchronously, which speeds up data feeding at the cost of a little more CPU or GPU memory (the memory of one batch of input data).
- iterable (bool) – whether the created DataLoader is iterable.
- return_list (bool) – whether the return value on each device is presented as a list. It is only valid when iterable=True. If return_list=False, the return value on each device is a dict of str -> LoDTensor, where each key is the name of a fed Tensor. If return_list=True, the return value on each device is a list(LoDTensor). It is recommended to use return_list=False in static graph mode and return_list=True in dygraph mode.
- use_multiprocess (bool) – whether to use multiple processes to speed up data loading in dygraph mode. This parameter only takes effect in dygraph mode; in static graph mode it has no effect whether it is set or not. Default False.
- drop_last (bool) – whether to drop the last batches whose number is less than the number of CPU cores/GPU cards. Default True. In the training phase, users should not set drop_last=False, because every CPU core/GPU card must read data from DataLoader. In the inference phase, users can set drop_last=False so that the last batches, whose number is less than the number of CPU cores/GPU cards, can also be tested.
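from_generator prefetches data with a Python thread that pushes batches into a queue, and capacity bounds that queue in units of batches. A minimal stdlib sketch of this producer/consumer pattern follows, assuming a single producer thread and a FIFO queue, which also keeps batches in the same order as the data source. `prefetching_reader` is a hypothetical name, not a Paddle API, and error propagation from the producer thread is deliberately omitted.

```python
import queue
import threading

_END = object()  # sentinel marking the end of the data source


def prefetching_reader(batch_generator, capacity):
    """Hypothetical sketch: read batches ahead of the consumer in a thread.

    The bounded queue holds at most `capacity` batches, so a fast producer
    blocks instead of buffering the whole data source in memory.
    """
    q = queue.Queue(maxsize=capacity)

    def producer():
        for batch in batch_generator():
            q.put(batch)  # blocks while `capacity` batches are already queued
        q.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        batch = q.get()
        if batch is _END:
            return
        yield batch


def batches():
    for i in range(4):
        yield [i, i + 1]


print(list(prefetching_reader(batches, capacity=2)))  # [[0, 1], [1, 2], [2, 3], [3, 4]]
```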
Returns
the created DataLoader object.
Return type
loader (DataLoader)
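The three data-source setters accept generators of different granularity: one sample, a list of samples, or a whole batch per step. The plain-Python sketch below shows how a sample generator can be grouped into a sample-list generator (using batch_size and drop_last, the same arguments Example 1 passes to set_sample_generator) and then into a batch generator by grouping each field. The converter functions are hypothetical and use lists instead of NumPy arrays; they only illustrate the data shapes, not how Paddle performs the conversion internally.

```python
def sample_generator():
    """Yields one (features, label) sample per step."""
    for i in range(5):
        yield [float(i)] * 3, i % 2


def to_sample_list_generator(sample_gen, batch_size, drop_last=True):
    """Hypothetical: group single samples into lists of batch_size samples."""
    def reader():
        buffer = []
        for sample in sample_gen():
            buffer.append(sample)
            if len(buffer) == batch_size:
                yield buffer
                buffer = []
        if buffer and not drop_last:
            yield buffer  # final, smaller list
    return reader


def to_batch_generator(sample_list_gen):
    """Hypothetical: turn each sample list into one batch per field."""
    def reader():
        for sample_list in sample_list_gen():
            # zip(*...) regroups [(x0, y0), (x1, y1)] into (x0, x1), (y0, y1)
            yield tuple(list(field) for field in zip(*sample_list))
    return reader


list_gen = to_sample_list_generator(sample_generator, batch_size=2)
batch_gen = to_batch_generator(list_gen)
for features, labels in batch_gen():
    print(len(features), labels)  # prints "2 [0, 1]" twice; sample 4 is dropped
```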
Example 1:
''' Example in static graph mode '''
import numpy as np

import paddle
import paddle.static as static
import paddle.nn.functional as F

BATCH_NUM = 10
BATCH_SIZE = 16
EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True # whether the created DataLoader object is iterable
USE_GPU = False # whether to use GPU

DATA_FORMAT = 'batch_generator' # data format of data source user provides

paddle.enable_static()

def simple_net(image, label):
    fc_tmp = static.nn.fc(image, size=CLASS_NUM)
    # compute the loss on the fc output (the logits), not on the raw image
    cross_entropy = F.softmax_with_cross_entropy(fc_tmp, label)
    loss = paddle.mean(cross_entropy)
    sgd = paddle.optimizer.SGD(learning_rate=1e-3)
    sgd.minimize(loss)
    return loss

def get_random_images_and_labels(image_shape, label_shape):
    image = np.random.random(size=image_shape).astype('float32')
    # random class ids in [0, CLASS_NUM)
    label = np.random.randint(0, CLASS_NUM, size=label_shape).astype('int64')
    return image, label

# If the data generator yields one sample each time,
# use DataLoader.set_sample_generator to set the data source.
def sample_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM * BATCH_SIZE):
            image, label = get_random_images_and_labels([784], [1])
            yield image, label

    return __reader__

# If the data generator yields a list of samples each time,
# use DataLoader.set_sample_list_generator to set the data source.
def sample_list_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            sample_list = []
            for _ in range(BATCH_SIZE):
                image, label = get_random_images_and_labels([784], [1])
                sample_list.append([image, label])

            yield sample_list

    return __reader__

# If the data generator yields a batch each time,
# use DataLoader.set_batch_generator to set the data source.
def batch_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])
            yield batch_image, batch_label

    return __reader__

# If DataLoader is iterable, use a for loop to train the network
def train_iterable(exe, prog, loss, loader):
    for _ in range(EPOCH_NUM):
        for data in loader():
            exe.run(prog, feed=data, fetch_list=[loss])

# If DataLoader is not iterable, use the start() and reset() methods to control the process
def train_non_iterable(exe, prog, loss, loader):
    for _ in range(EPOCH_NUM):
        loader.start() # call DataLoader.start() before each epoch starts
        try:
            while True:
                exe.run(prog, fetch_list=[loss])
        except paddle.core.EOFException:
            loader.reset() # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):
    if DATA_FORMAT == 'sample_generator':
        loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)
    elif DATA_FORMAT == 'sample_list_generator':
        loader.set_sample_list_generator(sample_list_generator_creator(), places=places)
    elif DATA_FORMAT == 'batch_generator':
        loader.set_batch_generator(batch_generator_creator(), places=places)
    else:
        raise ValueError('Unsupported data format')

image = static.data(name='image', shape=[None, 784], dtype='float32')
label = static.data(name='label', shape=[None, 1], dtype='int64')

# Define DataLoader
loader = paddle.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network
loss = simple_net(image, label)

# Set data source of DataLoader
#
# If DataLoader is iterable, places must be given and the number of places must be the same as the number of devices.
#  - If you are using GPU, call `paddle.static.cuda_places()` to get all GPU places.
#  - If you are using CPU, call `paddle.static.cpu_places()` to get all CPU places.
#
# If DataLoader is not iterable, places can be None.
places = static.cuda_places() if USE_GPU else static.cpu_places()
set_data_source(loader, places)

exe = static.Executor(places[0])
exe.run(static.default_startup_program())

prog = static.CompiledProgram(static.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:
    train_iterable(exe, prog, loss, loader)
else:
    train_non_iterable(exe, prog, loss, loader)
Example 2:
''' Example in dynamic graph mode. '''
import numpy as np

import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

BATCH_SIZE = 16
BATCH_NUM = 4
EPOCH_NUM = 4

IMAGE_SIZE = 784
CLASS_NUM = 10

USE_GPU = False # whether to use GPU

def _get_random_images_and_labels(image_shape, label_shape):
    image = np.random.random(size=image_shape).astype('float32')
    # random class ids in [0, CLASS_NUM)
    label = np.random.randint(0, CLASS_NUM, size=label_shape).astype('int64')
    return image, label

def __reader__():
    for _ in range(BATCH_NUM):
        # CrossEntropyLoss with hard labels expects one class id per sample,
        # so the label shape is [BATCH_SIZE, 1]
        batch_image, batch_label = _get_random_images_and_labels(
            [BATCH_SIZE, IMAGE_SIZE], [BATCH_SIZE, 1])
        yield batch_image, batch_label

def random_batch_reader():
    return __reader__

class LinearNet(nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)

    @paddle.jit.to_static
    def forward(self, x):
        return self._linear(x)

# set device
paddle.set_device('gpu' if USE_GPU else 'cpu')

# create network
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.CrossEntropyLoss()
adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())

# create data loader
loader = paddle.io.DataLoader.from_generator(capacity=5)
loader.set_batch_generator(random_batch_reader())

for epoch_id in range(EPOCH_NUM):
    for batch_id, (image, label) in enumerate(loader()):
        out = layer(image)
        loss = loss_fn(out, label)

        loss.backward()

        adam.step()
        adam.clear_grad()
        print("Epoch {} batch {}: loss = {}".format(
            epoch_id, batch_id, np.mean(loss.numpy())))
Example 3:
''' Example of using `drop_last` in static graph multi-card mode '''
import paddle
import paddle.static as static
import numpy as np
import os

# We use 2 CPU cores to run the inference network
os.environ['CPU_NUM'] = '2'

paddle.enable_static()

# The data source has only 3 batches, which cannot be
# divided evenly among the CPU cores
def batch_generator():
    for i in range(3):
        yield np.array([i+1]).astype('float32'),

x = static.data(name='x', shape=[None], dtype='float32')
y = x * x

def run_inference(drop_last):
    loader = paddle.io.DataLoader.from_generator(feed_list=[x],
            capacity=8, drop_last=drop_last)
    loader.set_batch_generator(batch_generator, static.cpu_places())

    exe = static.Executor(paddle.CPUPlace())
    prog = static.CompiledProgram(static.default_main_program())
    prog = prog.with_data_parallel()

    result = []
    for data in loader():
        each_ret, = exe.run(prog, feed=data, fetch_list=[y])
        result.extend(each_ret)
    return result

# Set drop_last to True, so that the last batches whose
# number is less than the CPU core number are discarded.
print(run_inference(drop_last=True)) # [1.0, 4.0]

# Set drop_last to False, so that the last batches whose
# number is less than the CPU core number can be tested.
print(run_inference(drop_last=False)) # [1.0, 4.0, 9.0]
static from_dataset(dataset, places, drop_last=True)
Warning
This API will be deprecated in the future. It is recommended to use paddle.io.DataLoader, which supports multi-process acceleration.
Create an iterable DataLoader object for loading data from Dataset. Dataset is currently only supported on Linux.
Parameters
- dataset (InMemoryDataset|QueueDataset) – the dataset object.
- places (list(CUDAPlace)|list(CPUPlace)|list(str)) – places where the result data should be converted. If places is a list of strings, each string can be cpu, gpu:x or gpu_pinned, where x is the index of the GPU.
- drop_last (bool) – whether to drop the last batch whose sample number is less than the batch size. If drop_last = True, such batches are dropped; if drop_last = False, they are kept.
Returns
the created DataLoader object, which can be treated as a Python generator.
Return type
loader (DataLoader)
Examples
import paddle
import paddle.static as static

paddle.enable_static()

image = static.data(name='image', shape=[None, 784], dtype='float32')
label = static.data(name='label', shape=[None, 1], dtype='int64')

dataset = paddle.distributed.QueueDataset()
dataset.init(
    batch_size=32,
    pipe_command='cat',
    use_var=[image, label])
dataset.set_filelist(['a.txt', 'b.txt', 'c.txt'])

loader = paddle.io.DataLoader.from_dataset(dataset, static.cpu_places())