to_distributed

paddle.distributed.to_distributed(model: paddle.nn.Layer, optimizer: paddle.optimizer.Optimizer, dataloader: paddle.io.DataLoader, device_num: int, node_num: int | None = 1, config: ToDistributedConfig | None = None) → tuple[paddle.nn.Layer, paddle.optimizer.Optimizer, paddle.distributed.auto_parallel.ShardDataloader] [source]

to_distributed automatically converts a neural network, optimizer, and dataloader that contain no distributed code into counterparts suitable for distributed training, while ensuring correctness. During the conversion, the optimal distributed strategy is selected automatically based on node_num and device_num to maximize performance.
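
A minimal sketch of the calling pattern, assuming model, opt, and loader are an ordinary single-card paddle.nn.Layer, optimizer, and paddle.io.DataLoader as described under Parameters (the full runnable example appears under Examples):

>>> # Minimal sketch: `model`, `opt` and `loader` are assumed to be an ordinary
>>> # single-card nn.Layer, optimizer and dataloader without distributed code.
>>> from paddle.distributed import to_distributed
>>> from paddle.distributed.auto_parallel.high_level_api import ToDistributedConfig
>>> dist_model, dist_opt, dist_loader = to_distributed(
...     model, opt, loader, device_num=8, node_num=1, config=ToDistributedConfig()
... )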

Parameters
  • model (paddle.nn.Layer) – The model in dygraph mode, whose parameters are ordinary tensors and which contains no distributed code. If a single device has sufficient memory, it can be trained directly.

  • optimizer (paddle.optimizer.Optimizer) – The optimizer for training, an instance of a regular optimizer, e.g. paddle.optimizer.Adam.

  • dataloader (paddle.io.DataLoader) – The dataloader used in dygraph mode. It is instantiated with a regular paddle.io.Dataset and paddle.io.Sampler, not paddle.io.DistributedBatchSampler.

  • device_num (int) – The number of devices on each node or machine.

  • node_num (int|None, optional) – The number of nodes or machines. Default: 1.

  • config (ToDistributedConfig|None, optional) – Configuration for input_spec and sequence_parallel. The custom input specs describe the most likely shape, dtype, and name of each model input and should be a list of paddle.static.InputSpec. If they are provided, the input specs are inferred from them; if None, a default spec with shape [BATCH_SIZE=4, SEQ_LENGTH=1024] is used. sequence_parallel indicates whether to use sequence parallelism; it defaults to False. Default: None. A short configuration sketch follows this list.
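
A configuration sketch based on the field names described above; the values are illustrative only and mirror the InputSpec construction used in the Examples section:

>>> # Sketch of building a ToDistributedConfig; field names follow the
>>> # parameter description above, values are illustrative only.
>>> import paddle
>>> from paddle.distributed.auto_parallel.high_level_api import ToDistributedConfig
>>> spec = paddle.static.InputSpec([4, 1024], 'float32', 'input_seq', True)
>>> cfg = ToDistributedConfig()
>>> cfg.input_spec = [spec]          # optional; default spec shape is [4, 1024]
>>> cfg.sequence_parallel = True     # enable sequence parallelism (default False)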

Returns

model. The model in dygraph mode, now carrying distributed attributes.

optimizer. The optimizer for training; its states may be sharded.

dataloader. A dataloader that can be used in distributed training.

Examples

>>> import math
>>> import numpy as np
>>> import paddle
>>> import paddle.nn.functional as F
>>> from paddle import nn
>>> from paddle.distributed import to_distributed
>>> from paddle.distributed.auto_parallel.high_level_api import ToDistributedConfig

>>> EPOCHS = 1
>>> VOCAB_SIZE = 8000
>>> BATCH_NUM = 2
>>> BATCH_SIZE = 4
>>> HIDDEN_SIZE = 2048
>>> INTERMEDIATE_SIZE = 4096
>>> SEQ_LENGTH = 1024
>>> N_HEAD = 32
>>> NUM_HIDDEN_LAYERS = 4
>>> class RandomDataset(paddle.io.Dataset): # type: ignore[type-arg]
...     def __init__(self, inputs, labels, num_samples):
...         self.inputs = inputs
...         self.labels = labels
...         self.num_samples = num_samples
...     def __getitem__(self, idx):
...         return self.inputs[idx], self.labels[idx]
...     def __len__(self):
...         return self.num_samples

>>> class RotaryEmbedding(nn.Layer):
...     def __init__(self, dim, max_position_embeddings=2048, base=10000):
...         super().__init__()
...         self.dim = dim
...         self.max_position_embeddings = max_position_embeddings
...         self.base = base
...         self.inv_freq = 1.0 / (
...             self.base ** (
...                 paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32")
...                 / self.dim
...             )
...         )
...         self._set_cos_sin_cache(seq_len=max_position_embeddings)

...     def _set_cos_sin_cache(self, seq_len):
...         self.max_seq_len_cached = seq_len
...         t = paddle.arange(seq_len, dtype="float32")
...         freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
...         emb = paddle.concat([freqs, freqs], axis=-1)
...         self.cos_cached = emb.cos()[None, :, None, :]
...         self.sin_cached = emb.sin()[None, :, None, :]

...     def forward(self, x, seq_len=None):
...         cos = self.cos_cached[:, :seq_len, :, :]
...         sin = self.sin_cached[:, :seq_len, :, :]
...         return (
...             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
...             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
...         )

>>> def rotate_half(x):
...     x1 = x[..., : x.shape[-1] // 2]
...     x2 = x[..., x.shape[-1] // 2 :]
...     return paddle.concat([-x2, x1], axis=-1)

>>> def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
...     if position_ids is None:
...         cos = cos[:, : q.shape[1], :, :]
...         sin = sin[:, : q.shape[1], :, :]
...     else:
...         cos = cos.squeeze(axis=[0, 2])
...         sin = sin.squeeze(axis=[0, 2])
...         cos = cos[position_ids].unsqueeze(2)
...         sin = sin[position_ids].unsqueeze(2)
...     q_embed = (q * cos) + (rotate_half(q) * sin)
...     k_embed = (k * cos) + (rotate_half(k) * sin)
...     return q_embed, k_embed

>>> def scaled_dot_product_attention(
...     query_states,
...     key_states,
...     value_states,
...     attention_mask,
... ):
...     bsz, q_len, num_heads, head_dim = query_states.shape
...     _, kv_seq_len, _, _ = value_states.shape
...     query_states = paddle.transpose(query_states, [0, 2, 1, 3])
...     key_states = paddle.transpose(key_states, [0, 2, 1, 3])
...     value_states = paddle.transpose(value_states, [0, 2, 1, 3])
...     attn_weights = paddle.matmul(
...         query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])
...     )
...     attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
...     attn_weights = attn_weights + attention_mask
...     if not paddle.in_dynamic_mode():
...         attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(
...             query_states.dtype
...         )
...     else:
...         with paddle.amp.auto_cast(False):
...             attn_weights = F.softmax(
...                 attn_weights, axis=-1, dtype="float32"
...             ).astype(query_states.dtype)
...     attn_output = paddle.matmul(attn_weights, value_states)
...     attn_output = attn_output.transpose([0, 2, 1, 3])
...     attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
...     return attn_output

>>> class Attention(nn.Layer):
...     def __init__(self, hidden_size=HIDDEN_SIZE, n_head=N_HEAD):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.num_heads = n_head
...         self.head_dim = hidden_size // n_head
...         self.q_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.k_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.v_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.o_proj = nn.Linear(
...             hidden_size, hidden_size, bias_attr=False
...         )
...         self.rotary_emb = RotaryEmbedding(
...             self.head_dim, max_position_embeddings=SEQ_LENGTH, base=10000
...         )

...     def forward(
...         self,
...         hidden_states,
...         position_ids=None,
...         attention_mask=None,
...     ):
...         query_states = self.q_proj(hidden_states)
...         key_states = self.k_proj(hidden_states)
...         value_states = self.v_proj(hidden_states)
...         target_query_shape = [0, 0, self.num_heads, self.head_dim]
...         target_key_value_shape = [0, 0, self.num_heads, self.head_dim]
...         query_states = query_states.reshape(shape=target_query_shape)
...         key_states = key_states.reshape(shape=target_key_value_shape)
...         value_states = value_states.reshape(shape=target_key_value_shape)
...         kv_seq_len = key_states.shape[-3]
...         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
...         query_states, key_states = apply_rotary_pos_emb(
...             query_states, key_states, cos, sin, position_ids
...         )
...         output = scaled_dot_product_attention(
...             query_states,
...             key_states,
...             value_states,
...             attention_mask,
...         )
...         attn_output = output
...         attn_output = self.o_proj(attn_output)
...         return attn_output

>>> class Mlp(nn.Layer):
...     def __init__(
...         self,
...         hidden_size=HIDDEN_SIZE,
...         intermediate_size=INTERMEDIATE_SIZE,
...     ):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.intermediate_size = intermediate_size
...         self.gate_proj = nn.Linear(
...             hidden_size, intermediate_size, bias_attr=False
...         )
...         self.up_proj = nn.Linear(
...             hidden_size, intermediate_size, bias_attr=False
...         )
...         self.down_proj = nn.Linear(
...             intermediate_size, hidden_size, bias_attr=False
...         )

...     def forward(self, x):
...         x = paddle.incubate.nn.functional.swiglu(
...             self.gate_proj(x), self.up_proj(x)
...         )
...         out = self.down_proj(x)
...         return out

>>> class RMSNorm(nn.Layer):
...     def __init__(self, hidden_size=HIDDEN_SIZE):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.weight = paddle.create_parameter(
...             shape=[self.hidden_size],
...             dtype=paddle.get_default_dtype(),
...             default_initializer=nn.initializer.Constant(1.0),
...         )
...         self.variance_epsilon = 1.0

...     def forward(self, hidden_states):
...         with paddle.amp.auto_cast(False):
...             hidden_states = hidden_states.astype("float32")
...             variance = hidden_states.pow(2).mean(-1, keepdim=True)
...             hidden_states = (
...                 paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
...             )
...         if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
...             hidden_states = paddle.cast(hidden_states, self.weight.dtype)
...         return hidden_states * self.weight

>>> class DecoderLayer(nn.Layer):
...     def __init__(
...         self,
...         hidden_size=HIDDEN_SIZE,
...         intermediate_size=INTERMEDIATE_SIZE,
...     ):
...         super().__init__()
...         self.hidden_size = hidden_size
...         self.intermediate_size = intermediate_size
...         self.self_attn = Attention(hidden_size)
...         self.mlp = Mlp()
...         self.input_layernorm = RMSNorm(hidden_size)
...         self.post_attn_layernorm = RMSNorm(hidden_size)

...     def forward(
...         self,
...         hidden_states,
...         position_ids=None,
...         attention_mask=None,
...     ):
...         residual = hidden_states
...         hidden_states = self.input_layernorm(hidden_states)
...         hidden_states = self.self_attn(
...             hidden_states, position_ids, attention_mask
...         )
...         hidden_states = residual + hidden_states
...         residual = hidden_states
...         hidden_states = self.post_attn_layernorm(hidden_states)
...         hidden_states = self.mlp(hidden_states)
...         hidden_states = residual + hidden_states
...         return hidden_states

>>> def _prepare_decoder_attention_mask(
...     attention_mask, input_shape, dtype
... ):
...     batch_size, src_length = attention_mask.shape[0], attention_mask.shape[-1]
...     batch_size, target_length = input_shape
...     attention_mask = attention_mask[:, None, None, :].astype("bool")
...     attention_mask.stop_gradient = True
...     expanded_attn_mask = attention_mask.expand([batch_size, 1, target_length, src_length])
...     mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
...     combined_attention_mask = mask[None, None, :, :].expand(
...         [batch_size, 1, target_length, target_length]
...     )
...     expanded_attn_mask = (expanded_attn_mask & combined_attention_mask)
...     expanded_attn_mask = paddle.where(
...         expanded_attn_mask, 0.0, paddle.finfo(dtype).min
...     ).astype(dtype)
...     return expanded_attn_mask

>>> class Model(nn.Layer):
...     def __init__(
...         self,
...         vocab_size=VOCAB_SIZE,
...         hidden_size=HIDDEN_SIZE,
...         intermediate_size=INTERMEDIATE_SIZE,
...     ):
...         super().__init__()
...         self.vocab_size = vocab_size
...         self.hidden_size = hidden_size
...         self.intermediate_size = intermediate_size
...         self.embed_tokens = nn.Embedding(
...             vocab_size,
...             hidden_size,
...         )
...         self.layers = nn.LayerList(
...             [
...                 DecoderLayer()
...                 for i in range(NUM_HIDDEN_LAYERS)
...             ]
...         )
...         self.norm = RMSNorm(hidden_size)
...         self.weight = self.create_parameter(
...             shape=[hidden_size, vocab_size],
...             dtype=paddle.get_default_dtype(),
...         )
...         self.ignore_index = -100
...         self.loss_func = paddle.nn.CrossEntropyLoss(
...             reduction="none", ignore_index=self.ignore_index
...         )

...     def forward(
...         self,
...         input_ids=None,
...         position_ids=None,
...         attention_mask=None,
...         labels=None,
...     ):
...         batch_size, seq_length = input_ids.shape
...         inputs_embeds = self.embed_tokens(input_ids)
...         attention_mask = paddle.ones(
...             (batch_size, seq_length), dtype=paddle.bool
...         )
...         if position_ids is None:
...             position_ids = paddle.arange(seq_length, dtype="int64").expand(
...                 (batch_size, seq_length)
...             )
...         attention_mask = _prepare_decoder_attention_mask(
...             attention_mask,
...             (batch_size, seq_length),
...             inputs_embeds.dtype,
...         )
...         hidden_states = inputs_embeds
...         for idx, (decoder_layer) in enumerate(self.layers):
...             layer_outputs = decoder_layer(
...                 hidden_states,
...                 position_ids,
...                 attention_mask,
...             )
...             hidden_states = layer_outputs
...         hidden_states = self.norm(hidden_states)
...         logits = paddle.matmul(hidden_states, self.weight)
...         loss = None
...         if labels is not None:
...             masked_lm_loss = self.loss_func(
...                 logits.astype("float32"),
...                 labels.unsqueeze(2),
...             )
...             binary_sequence = paddle.where(
...                 masked_lm_loss > 0,
...                 paddle.ones_like(masked_lm_loss),
...                 paddle.zeros_like(masked_lm_loss),
...             )
...             count = paddle.sum(binary_sequence)
...             if count == 0:
...                 loss = paddle.sum(masked_lm_loss * binary_sequence)
...             else:
...                 loss = paddle.sum(masked_lm_loss * binary_sequence) / count
...         return (loss, logits)

>>> model = Model() # There is no distributed code or markup in Model
>>> input_seqs = np.random.randint(
...     low=0, high=1024, size=(BATCH_SIZE * BATCH_NUM, SEQ_LENGTH)
... ).astype("int64")
>>> labels = np.random.randint(
...     low=0, high=1024, size=(BATCH_SIZE * BATCH_NUM, SEQ_LENGTH)
... ).astype("int64")
>>> dataset = RandomDataset(
...     input_seqs, labels, BATCH_SIZE * BATCH_NUM
... )
>>> sampler = paddle.io.BatchSampler(
...     dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
... )
>>> loader = paddle.io.DataLoader(
...     dataset, batch_sampler=sampler
... )
>>> opt = paddle.optimizer.SGD(
...     learning_rate=0.1, parameters=model.parameters()
... )
>>> input_seq_spec = paddle.static.InputSpec(
...     [BATCH_SIZE, SEQ_LENGTH], 'float32', 'input_seq', True
... )
>>> dist_config = ToDistributedConfig()
>>> dist_config.sequence_parallel = True

>>> # wrap model, opt, dataloader by using **to_distributed**
>>> dist_model, dist_opt, dist_loader = to_distributed(
...     model,
...     opt,
...     loader,
...     device_num=8,
...     node_num=1,
...     config=dist_config,
... )

>>> for epoch in range(EPOCHS):
...     dist_model.train()
...     for i, data in enumerate(dist_loader()):
...         inputs, labels = data
...         loss, _ = dist_model(inputs, labels=labels)
...         print(f"epoch {epoch}, step {i}: loss {loss}")
...         loss.backward()
...         dist_opt.step()
...         dist_opt.clear_grad()
>>> # This example needs to be executed in a multi-card environment, e.g.:
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py