shard_dataloader
- paddle.distributed.shard_dataloader(dataloader: paddle.io.reader.DataLoader, meshes: Union[paddle.distributed.auto_parallel.process_mesh.ProcessMesh, List[paddle.distributed.auto_parallel.process_mesh.ProcessMesh], Tuple[paddle.distributed.auto_parallel.process_mesh.ProcessMesh]], input_keys: Optional[Union[List[str], Tuple[str]]] = None, shard_dims: Optional[Union[list, tuple, str, int]] = None, is_dataset_splitted: bool = False) → paddle.distributed.auto_parallel.api.ShardDataloader [source]
Convert the dataloader to a ShardDataloader, which provides two capabilities: 1. Split the dataloader along shard_dims to do data parallelism if shard_dims is not None. 2. Reshard the output of the dataloader to distributed tensors. If is_dataset_splitted is True, the dataset has already been split by the user and only the reshard step is performed; the split is performed only when is_dataset_splitted is False and shard_dims is not None.
- Parameters
dataloader (paddle.io.DataLoader) – The dataloader to be sharded. The output of the dataloader must be a list or dict of paddle.Tensor with 2 elements, i.e. [input_data, label] or {"input_data": input_data, "label": label}; input_data and label can each be a list to support multiple inputs.
meshes (ProcessMesh|list[ProcessMesh]|tuple[ProcessMesh]) – The mesh list of the dataloader, identifying which mesh each input is on. If len(meshes) == 1 or type(meshes) == ProcessMesh, all inputs are on the same mesh.
input_keys (list[str]|tuple[str]) – If the iteration result of the dataloader is a dict of tensors, input_keys is the keys of this dict and identifies which tensor is located on which mesh, in one-to-one correspondence with meshes, i.e. dict[input_keys[i]] is on meshes[i]. Default: None, which means the output is a list and the i-th input is on meshes[i].
shard_dims (list[str]|tuple[str]|list[int]|tuple[int]|str|int) – The mesh dimension along which to shard the dataloader. Users can specify the shard_dim of each mesh or a single shard_dim for all meshes. Default: None, which means the dataloader will not be split, i.e. pure model parallelism (mp); see the short sketch after this parameter list.
is_dataset_splitted (bool) – Whether the dataset has already been split. Default: False.
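A minimal sketch (not from the official examples) of the common argument combinations, assuming an existing paddle.io.DataLoader named dataloader whose batches look like [input_data, label]:

>>> import paddle.distributed as dist

>>> # One 2-D mesh: 'dp' is the data-parallel dimension, 'mp' the model-parallel one.
>>> mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['dp', 'mp'])

>>> # Split each batch along 'dp' and reshard every output tensor onto `mesh`.
>>> dist_loader = dist.shard_dataloader(dataloader, meshes=mesh, shard_dims='dp')

>>> # The dataset was already split per rank by the user: only reshard, no splitting.
>>> dist_loader = dist.shard_dataloader(
...     dataloader, meshes=mesh, is_dataset_splitted=True
... )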
- Returns
The sharded dataloader.
- Return type
ShardDataloader
Examples
>>> import os
>>> import numpy as np
>>> import paddle
>>> import paddle.distributed as dist
>>> from paddle.io import BatchSampler, DataLoader, Dataset

>>> mesh0 = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
>>> mesh1 = dist.ProcessMesh([[4, 5], [6, 7]], dim_names=['x', 'y'])
>>> paddle.seed(1024)
>>> np.random.seed(1024)

>>> SEQ_LEN = 4
>>> HIDDEN_SIZE = 8

>>> class RandomDataset(Dataset):
...     def __init__(self, seq_len, hidden, num_samples=8):
...         super().__init__()
...         self.seq_len = seq_len
...         self.hidden = hidden
...         self.num_samples = num_samples
...         self.inputs = [
...             np.random.uniform(size=[self.seq_len, self.hidden]).astype("float32")
...             for _ in range(num_samples)
...         ]
...         self.labels = [
...             np.array(index, dtype="float32") for index in range(num_samples)
...         ]
...     def __getitem__(self, index):
...         return self.inputs[index], self.labels[index]
...     def __len__(self):
...         return self.num_samples

>>> class MlpModel(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         # w0 lives on mesh0 (sharded along its columns), w1 on mesh1 (sharded along its rows).
...         self.w0 = dist.shard_tensor(
...             self.create_parameter(shape=[HIDDEN_SIZE, HIDDEN_SIZE]),
...             mesh0, [dist.Replicate(), dist.Shard(1)])
...         self.w1 = dist.shard_tensor(
...             self.create_parameter(shape=[HIDDEN_SIZE, HIDDEN_SIZE]),
...             mesh1, [dist.Replicate(), dist.Shard(0)])
...     def forward(self, x):
...         y = paddle.matmul(x, self.w0)
...         y = dist.reshard(y, mesh1, [dist.Shard(0), dist.Shard(2)])
...         z = paddle.matmul(y, self.w1)
...         return z

>>> model = MlpModel()
>>> dataset = RandomDataset(SEQ_LEN, HIDDEN_SIZE)
>>> sampler = BatchSampler(
...     dataset,
...     batch_size=2,
... )
>>> dataloader = DataLoader(
...     dataset,
...     batch_sampler=sampler,
... )
>>> dist_dataloader = dist.shard_dataloader(
...     dataloader=dataloader,
...     meshes=[mesh0, mesh1],
...     shard_dims="x"
... )
>>> opt = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())
>>> dist_opt = dist.shard_optimizer(opt)

>>> def loss_fn(logits, label):
...     # logits: [bs, seq_len, hidden], label: [bs]
...     loss = paddle.nn.MSELoss(reduction="sum")
...     logits = paddle.sum(logits, axis=[1, 2])
...     return loss(logits, label)

>>> RUN_STATIC = int(os.environ["RUN_STATIC"])

>>> def run_dynamic():
...     for step, (input, label) in enumerate(dist_dataloader()):
...         logits = model(input)
...         loss = loss_fn(logits, label)
...         print("step:{}, loss:{}".format(step, loss))
...         loss.backward()
...         dist_opt.step()
...         dist_opt.clear_grad()

>>> def run_static():
...     dist_model = dist.to_static(
...         model, dist_dataloader, loss_fn, opt
...     )
...     dist_model.train()
...     for step, (input, label) in enumerate(dist_dataloader()):
...         print("label:", label)
...         loss = dist_model(input, label)
...         print("step:{}, loss:{}".format(step, loss))

>>> if RUN_STATIC == 0:
...     run_dynamic()
... else:
...     run_static()

>>> # This case needs to be executed in a multi-card environment:
>>> # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
>>> # RUN_STATIC=1 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" {test_case}.py
>>> # RUN_STATIC=0 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" {test_case}.py
>>> import paddle
>>> import paddle.distributed as dist
>>> from paddle.io import BatchSampler, DataLoader, Dataset
>>> import numpy as np

>>> mesh0 = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['dp', 'mp'])
>>> mesh1 = dist.ProcessMesh([[4, 5], [6, 7]], dim_names=['dp', 'mp'])

>>> class RandomDataset(Dataset):
...     def __init__(self, seq_len, hidden, num_samples=8):
...         super().__init__()
...         self.seq_len = seq_len
...         self.hidden = hidden
...         self.num_samples = num_samples
...         self.inputs1 = [
...             np.random.uniform(size=[self.seq_len, self.hidden]).astype("float32")
...             for _ in range(num_samples)
...         ]
...         self.inputs2 = [
...             np.random.uniform(size=[self.seq_len, self.hidden]).astype("float32")
...             for _ in range(num_samples)
...         ]
...         self.labels = [
...             np.array(index, dtype="float32") for index in range(num_samples)
...         ]
...     def __getitem__(self, index):
...         return {
...             "inputs": [self.inputs1[index], self.inputs2[index]],
...             "label": self.labels[index],
...         }
...     def __len__(self):
...         return self.num_samples

>>> dataset = RandomDataset(4, 8)
>>> sampler = BatchSampler(
...     dataset,
...     batch_size=2,
... )
>>> dataloader = DataLoader(
...     dataset,
...     batch_sampler=sampler,
... )
>>> dist_dataloader = dist.shard_dataloader(
...     dataloader=dataloader,
...     meshes=[mesh0, mesh1],  # or [[mesh0, mesh0], mesh1]
...     shard_dims="dp",
...     input_keys=["inputs", "label"],
... )
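As a follow-up (not part of the original example), a hedged sketch of consuming the dict-style loader built above, assuming each yielded batch keeps the dataset's dict layout with "inputs" placed on mesh0 and "label" on mesh1 as specified by input_keys and meshes:

>>> # Assumption: each batch is a dict with the same keys the dataset returns.
>>> for step, batch in enumerate(dist_dataloader()):
...     inputs1, inputs2 = batch["inputs"]  # on mesh0 (split along 'dp')
...     label = batch["label"]              # on mesh1
...     # feed inputs1/inputs2/label into a model whose weights live on mesh0/mesh1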