get_worker_info¶
- paddle.io.get_worker_info() [source]
-
Get information about the current DataLoader worker process. This function is used inside a worker process to split the copy of the dataset held by that worker when loading from an IterableDataset (see
paddle.io.IterableDataset
). The worker information contains the following fields:
num_workers
: the total number of worker processes, see paddle.io.DataLoader
id
: the worker process id, counting from 0 to num_workers - 1
dataset
: the dataset object in this worker process
- Returns
-
An instance of WorkerInfo which contains the fields listed above.
- Return type
-
WorkerInfo
Note
For more usage and examples, please see
paddle.io.IterableDataset
Example
import math

import numpy as np
import paddle
from paddle.io import DataLoader, IterableDataset, get_worker_info


class SplitedIterableDataset(IterableDataset):
    """Iterable dataset over the integers in ``[start, end)``.

    When iterated inside a DataLoader worker process, each worker yields only
    its own contiguous slice of the range, so that the workers together cover
    every value exactly once instead of each replaying the full range.

    Args:
        start: first integer of the range (inclusive).
        end: end of the range (exclusive).
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        """Yield each integer of this worker's slice as a 1-element array."""
        worker_info = get_worker_info()
        if worker_info is None:
            # No worker information is available here, so this iterator is
            # responsible for the whole range.
            iter_start = self.start
            iter_end = self.end
        else:
            # Split the range into num_workers contiguous chunks; round the
            # chunk size up so that no trailing element is lost.
            per_worker = int(
                math.ceil((self.end - self.start) / float(
                    worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            # The last chunk may be shorter than per_worker; clamp it to end.
            iter_end = min(iter_start + per_worker, self.end)

        for i in range(iter_start, iter_end):
            yield np.array([i])


place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
    dataset,
    places=place,
    num_workers=2,
    batch_size=1,
    drop_last=True)

for data in dataloader:
    print(data)
# With 2 workers, per_worker = ceil((9 - 2) / 2) = 4, so worker 0 yields
# 2, 3, 4, 5 and worker 1 yields 6, 7, 8. Every value in [2, 9) is therefore
# printed exactly once, one value per batch (batch_size=1, nothing dropped).
# NOTE(review): the order in which batches from the two workers interleave is
# decided by DataLoader; presumably alternating, i.e. 2, 6, 3, 7, 4, 8, 5 --
# confirm against the DataLoader documentation.