parallelize

paddle.distributed.parallelize(model: paddle.nn.Layer, optimizer: paddle.optimizer.Optimizer | None = None, mesh: paddle.distributed.ProcessMesh | None = None, config: _ParallelizeConfig | None = None) → tuple[paddle.nn.Layer, paddle.optimizer.Optimizer] [source]

Parallelize the model and optimizer from a single-card version to a distributed version.

Parameters
  • model (paddle.nn.Layer) – the model to be parallelized.

  • optimizer (paddle.optimizer.Optimizer, optional) – the optimizer to be parallelized. Can be None if there is no optimizer to parallelize.

  • mesh (paddle.distributed.ProcessMesh, optional) – the process mesh for parallelizing the model and the optimizer. Best practice: call dist.auto_parallel.set_mesh to set the global mesh before calling parallelize and keep the mesh parameter as None. If mesh is not None, the mesh passed to parallelize will override the mesh set by set_mesh.

  • config (dict, optional) –

    a dict containing the parallel config. The keys of the dict can be chosen from dp_config, mp_config and pp_config, which determine the parallel method for data parallel, tensor parallel and pipeline parallel respectively. A valid config can look like this: {“dp_config”: for more information refer to the dp_config section of this doc, “mp_config”: for more information refer to the mp_config section of this doc, “pp_config”: for more information refer to the pp_config section of this doc}.

    dp_config (dict): a dict specifying the data parallel config. The key of dp_config is sharding_level.

    The value of sharding_level can be chosen from 0/1/2/3, which mean pure data parallel, sharding parallel stage 1, sharding parallel stage 2 and sharding parallel stage 3 respectively. A valid dp_config can look like this: {“sharding_level”: 2}.

    mp_config (dict): a dict specifying the tensor parallel config. The key of mp_config is parallelize_plan.

    The value of parallelize_plan is another dict, mapping a layer name or a param name to a specific parallel plan. Note that the layer name can be written as a regular expression. If mapping a param name to a specific plan, the name of the param must end with weight or bias. All valid parallel plans are ColWiseParallel, RowWiseParallel, SequenceParallelBegin, SequenceParallelDisable, SequenceParallelEnable, SequenceParallelEnd, PrepareLayerInput and PrepareLayerOutput. A valid mp_config can look like this: {“llama.embed_tokens”: dist.ColWiseParallel(), “llama.norm”: dist.SequenceParallelEnable(), “lm_head.weight”: dist.ColWiseParallel()}.

    pp_config (dict): a dict specifying the pipeline parallel config. The keys of pp_config are split_spec and global_spec.

    The split_spec can be a dict or a string. If split_spec is a dict, it maps a layer name to a SplitPoint; note that the layer name can be written as a regular expression. The pipeline parallel will split the model exactly at the points indicated by the map. If split_spec is a string, it contains the prefix of a set of layers, and the pipeline parallel will automatically split the model evenly at the target layers. The global_spec is a string indicating a layer that contains global tensors, which will be duplicated through all stages of the pipeline parallel. Valid pp_configs can look like these: {“split_spec”: “llama.layers”, “global_spec”: “llama.global_layer”} or {“split_spec”: {“llama.layers.1”: SplitPoint.END}}. A combined illustrative config is sketched after this parameter list.
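
For orientation, the sketch below combines all three sub-configs into one minimal config. It assumes a Llama-style model, so the layer-name patterns (e.g. llama.layers.*.self_attn.qkv_proj) are placeholders that must match your own sublayer names; the dict form of split_spec uses SplitPoint.END.

>>> import paddle.distributed as dist
>>> # Minimal sketch, assuming a Llama-like model; layer names are placeholders
>>> illustrative_config = {
...     # stage-2 sharding for data parallel
...     "dp_config": {"sharding_level": 2},
...     # tensor parallel: column-wise qkv projection, row-wise output projection
...     "mp_config": {
...         "parallelize_plan": {
...             "llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
...             "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
...         }
...     },
...     # pipeline parallel: dict form of split_spec, cutting right after llama.layers.1
...     "pp_config": {"split_spec": {"llama.layers.1": dist.SplitPoint.END}},
... }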

Note

If the mesh is None or none of dp_config, mp_config and pp_config is present in the config, this API will do nothing but return the model and optimizer passed in.
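
As a quick sketch of the behaviours above (model, opt, mesh and parallel_config stand in for objects like those built in the Examples section; the three calls are alternatives, not a sequence):

>>> # Option 1: set a global mesh once, then omit the mesh argument
>>> dist.auto_parallel.set_mesh(mesh)
>>> dist_model, dist_opt = dist.parallelize(model, opt, config=parallel_config)
>>> # Option 2: pass the mesh explicitly; it overrides the one set by set_mesh
>>> dist_model, dist_opt = dist.parallelize(model, opt, mesh=mesh, config=parallel_config)
>>> # If neither a mesh nor any of dp_config/mp_config/pp_config is given,
>>> # the inputs are returned unchanged
>>> same_model, same_opt = dist.parallelize(model, opt)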

Returns

the model and the optimizer after parallelization

Return type

model, optimizer

Examples

>>> import paddle
>>> import paddle.distributed as dist

>>> class ModelConfig:
...     def __init__(self):
...         self.vocab_size = 10
...         self.hidden_size = 20
...         self.intermediate_size = 20
...         self.num_layers = 2

>>> model_config = ModelConfig()

>>> class LlamaRMSNorm(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.weight = paddle.create_parameter(
...             shape=[model_config.hidden_size],
...             dtype=paddle.get_default_dtype(),
...         )
...
...     def forward(self, input):
...         pass

>>> class LlamaAttention(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...         self.qkv_proj = paddle.nn.Linear(
...             model_config.hidden_size,
...             model_config.hidden_size * 3,
...             bias_attr=False,
...         )
...
...         self.o_proj = paddle.nn.Linear(
...             model_config.hidden_size,
...             model_config.hidden_size,
...             bias_attr=False,
...         )
...
...     def forward(self, input):
...         pass

>>> class LlamaMLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.gate_up_proj = paddle.nn.Linear(
...             model_config.hidden_size,
...             model_config.intermediate_size * 2,
...             bias_attr=False
...         )
...
...         self.down_proj = paddle.nn.Linear(
...             model_config.intermediate_size, model_config.hidden_size, bias_attr=False
...         )
...
...     def forward(self, input):
...         pass

>>> class LlamaDecoderLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.self_attn = LlamaAttention()
...         self.mlp = LlamaMLP()
...         self.input_layernorm = LlamaRMSNorm()
...         self.post_attention_layernorm = LlamaRMSNorm()
...
...     def forward(self, input):
...         pass

>>> class LlamaModel(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.embedding = paddle.nn.Embedding(model_config.vocab_size, model_config.hidden_size)
...         decoder_layers = []
...         for _ in range(model_config.num_layers):
...             decoder_layers.append(LlamaDecoderLayer())
...
...         self.layers = paddle.nn.LayerList(decoder_layers)
...         self.norm = LlamaRMSNorm()
...
...     def forward(self, input):
...         pass

>>> class LlamaLMHead(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.weight = self.create_parameter(
...             shape=[model_config.hidden_size, model_config.vocab_size],
...             dtype=paddle.get_default_dtype(),
...         )
...
...     def forward(self, input):
...         pass

>>> class LlamaForCausalLM(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.llama = LlamaModel()
...         self.lm_head = LlamaLMHead()
...
...     def forward(self, input):
...         pass

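>>> # An 8-card mesh with three dimensions ("dp", "mp", "pp"), each of size 2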
>>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
>>> dist.auto_parallel.set_mesh(mesh)
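>>> # Sharding stage 1 for data parallel, a tensor-parallel plan keyed by
>>> # layer-name patterns, and an even pipeline split over "llama.layers"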
>>> parallel_config = {
...     "dp_config": {'sharding_level': 1},
...     "mp_config": {
...         "parallelize_plan": {
...             "llama.embed_tokens": [
...                 dist.ColWiseParallel(),
...                 dist.SequenceParallelBegin(),
...             ],
...             "llama.position_embedding": [
...                 dist.ColWiseParallel(),
...                 dist.SequenceParallelBegin(),
...             ],
...             "llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
...             "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
...             "llama.layers.*.self_attn": dist.SequenceParallelDisable(),
...             "llama.layers.*.mlp.gate_up_proj": dist.ColWiseParallel(),
...             "llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
...             "llama.layers.*.mlp": dist.SequenceParallelDisable(
...                 need_transpose=False
...             ),
...             "lm_head.weight": dist.ColWiseParallel(),
...             "lm_head": dist.SequenceParallelEnd(),
...         }
...     },
...     "pp_config": {'split_spec': "llama.layers"}
... }

>>> model = LlamaForCausalLM()
>>> optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
>>> dist_model, dist_optimizer = dist.parallelize(model, optimizer, config=parallel_config) # type: ignore[arg-type]
>>> # This case need to be executed in multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py