parallelize
paddle.distributed.parallelize(model: paddle.nn.Layer, optimizer: paddle.optimizer.Optimizer | None = None, mesh: paddle.distributed.ProcessMesh | None = None, config: _ParallelizeConfig | None = None) -> tuple[paddle.nn.Layer, paddle.optimizer.Optimizer]
Parallelize the model and optimizer from a single-card version to a distributed version.
Parameters
model (paddle.nn.Layer) – the model to be parallelized.
optimizer (paddle.optimizer.Optimizer, optional) – the optimizer to be parallelized. Can be None if there is no optimizer to parallelize.
mesh (paddle.distributed.ProcessMesh, optional) – the process mesh used to parallelize the model and the optimizer. Best practice: call dist.auto_parallel.set_mesh to set the global mesh before calling parallelize and keep the mesh parameter as None. If mesh is not None, the mesh passed to parallelize overrides the mesh set by set_mesh (a short sketch of this override follows the examples below).
config (dict, optional) –
a dict containing the parallel config. The keys of the dict can be chosen from dp_config, mp_config and pp_config, which determine the parallel method for data parallel, tensor parallel and pipeline parallel respectively. A valid config can look like this: {"dp_config": see the dp_config section below, "mp_config": see the mp_config section below, "pp_config": see the pp_config section below}. A compact sketch of a full config is given after the key descriptions below.
dp_config (dict): a dict specifying the data parallel config. The key of dp_config is sharding_level.
The value of sharding_level can be chosen from 0/1/2/3, which mean pure data parallel, sharding parallel stage 1, sharding parallel stage 2 and sharding parallel stage 3 respectively. A valid dp_config can look like this: {"sharding_level": 2}.
mp_config (dict): a dict specifying the tensor parallel config. The key of mp_config is parallelize_plan.
The value of parallelize_plan is another dict, mapping a layer name or a parameter name to a specific parallel plan. Note that the layer name can be written in regular expression format. If a parameter name is mapped to a specific plan, the name of the parameter must end with weight or bias. The valid parallel plans are ColWiseParallel, RowWiseParallel, SequenceParallelBegin, SequenceParallelDisable, SequenceParallelEnable, SequenceParallelEnd, PrepareLayerInput and PrepareLayerOutput. A valid mp_config can look like this: {"llama.embed_tokens": dist.ColWiseParallel(), "llama.norm": dist.SequenceParallelEnable(), "lm_head.weight": dist.ColWiseParallel()}.
pp_config (dict): a dict specifying the pipeline parallel config. The keys of pp_config are split_spec and global_spec.
The split_spec can be a dict or a string. If split_spec is a dict, it maps a layer name to a SplitPoint; note that the layer name can be written in regular expression format. The pipeline parallel will split the model exactly at the points indicated by the map. If split_spec is a string, it contains the prefix of a set of layers, and the pipeline parallel will automatically split the model evenly at the target layers. The global_spec is a string indicating a layer that contains global tensors, which will be duplicated across all stages of the pipeline parallel. Valid pp_config values look like these: {"split_spec": "llama.layers", "global_spec": "llama.global_layer"} or {"split_spec": {"llama.layers.1": SplitPoint.END}}.
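Putting the three pieces together, the following is a minimal sketch of a complete config dict, not a definitive recipe; the layer names (llama.layers.*, lm_head.weight, etc.) are illustrative and assume a LLaMA-like module hierarchy such as the one built in the examples below.

import paddle.distributed as dist

# Illustrative parallel config; adjust keys and plans to the actual model.
parallel_config = {
    # data parallel with sharding stage 1
    "dp_config": {"sharding_level": 1},
    # tensor parallel plan; layer names may be regular expressions
    "mp_config": {
        "parallelize_plan": {
            "llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
            "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
            "llama.layers.*.mlp.gate_up_proj": dist.ColWiseParallel(),
            "llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
            "lm_head.weight": dist.ColWiseParallel(),
        }
    },
    # split the pipeline evenly across layers sharing this name prefix
    "pp_config": {"split_spec": "llama.layers"},
}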
Note
If mesh is None, or none of dp_config, mp_config and pp_config is present in the config, this API does nothing and returns the model and optimizer passed in.
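As a minimal illustration of this fallback (assuming no global mesh has been set via dist.auto_parallel.set_mesh, and using a single Linear layer purely as a stand-in model), the call below is expected to hand back the same objects unchanged:

import paddle
import paddle.distributed as dist

# No global mesh and no parallel config: per the note above, parallelize
# returns the model and optimizer untouched.
model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
same_model, same_optimizer = dist.parallelize(model, optimizer)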
Returns
The model and the optimizer after parallelization.
Return type
tuple[paddle.nn.Layer, paddle.optimizer.Optimizer]
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> class ModelConfig:
...     def __init__(self):
...         self.vocab_size = 10
...         self.hidden_size = 20
...         self.intermediate_size = 20
...         self.num_layers = 2
>>> model_config = ModelConfig()
>>> class LlamaRMSNorm(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.weight = paddle.create_parameter(
...             shape=[model_config.hidden_size],
...             dtype=paddle.get_default_dtype(),
...         )
...
...     def forward(self, input):
...         pass
>>> class LlamaAttention(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...         self.qkv_proj = paddle.nn.Linear(
...             model_config.hidden_size,
...             model_config.hidden_size * 3,
...             bias_attr=False,
...         )
...
...         self.o_proj = paddle.nn.Linear(
...             model_config.hidden_size,
...             model_config.hidden_size,
...             bias_attr=False,
...         )
...
...     def forward(self, input):
...         pass
>>> class LlamaMLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.gate_up_proj = paddle.nn.Linear(
...             model_config.hidden_size,
...             model_config.intermediate_size * 2,
...             bias_attr=False
...         )
...
...         self.down_proj = paddle.nn.Linear(
...             model_config.intermediate_size, model_config.hidden_size, bias_attr=False
...         )
...
...     def forward(self, input):
...         pass
>>> class LlamaDecoderLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.self_attn = LlamaAttention()
...         self.mlp = LlamaMLP()
...         self.input_layernorm = LlamaRMSNorm()
...         self.post_attention_layernorm = LlamaRMSNorm()
...
...     def forward(self, input):
...         pass
>>> class LlamaModel(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.embedding = paddle.nn.Embedding(model_config.vocab_size, model_config.hidden_size)
...         decoder_layers = []
...         for _ in range(model_config.num_layers):
...             decoder_layers.append(LlamaDecoderLayer())
...
...         self.layers = paddle.nn.LayerList(decoder_layers)
...         self.norm = LlamaRMSNorm()
...
...     def forward(self, input):
...         pass
>>> class LlamaLMHead(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.weight = self.create_parameter(
...             shape=[model_config.hidden_size, model_config.vocab_size],
...             dtype=paddle.get_default_dtype(),
...         )
...
...     def forward(self, input):
...         pass
>>> class LlamaForCausalLM(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.llama = LlamaModel()
...         self.lm_head = LlamaLMHead()
...
...     def forward(self, input):
...         pass
>>> mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
>>> dist.auto_parallel.set_mesh(mesh)
>>> parallel_config = {
...     "dp_config": {'sharding_level': 1},
...     "mp_config": {
...         "parallelize_plan": {
...             "llama.embed_tokens": [
...                 dist.ColWiseParallel(),
...                 dist.SequenceParallelBegin(),
...             ],
...             "llama.position_embedding": [
...                 dist.ColWiseParallel(),
...                 dist.SequenceParallelBegin(),
...             ],
...             "llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
...             "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
...             "llama.layers.*.self_attn": dist.SequenceParallelDisable(),
...             "llama.layers.*.mlp.gate_up_proj": dist.ColWiseParallel(),
...             "llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
...             "llama.layers.*.mlp": dist.SequenceParallelDisable(
...                 need_transpose=False
...             ),
...             "lm_head.weight": dist.ColWiseParallel(),
...             "lm_head": dist.SequenceParallelEnd(),
...         }
...     },
...     "pp_config": {'split_spec': "llama.layers"}
... }
>>> model = LlamaForCausalLM()
>>> optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
>>> dist_model, dist_optimizer = dist.parallelize(model, optimizer, config=parallel_config)  # type: ignore[arg-type]
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 {test_case}.py
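As noted in the description of the mesh parameter, a mesh passed directly to parallelize overrides any global mesh set by set_mesh. A minimal sketch of that variant, reusing the LlamaForCausalLM class and parallel_config defined in the example above (the explicit mesh here is illustrative):

>>> explicit_mesh = dist.ProcessMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=["dp", "mp", "pp"])
>>> model = LlamaForCausalLM()
>>> optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
>>> dist_model, dist_optimizer = dist.parallelize(
...     model, optimizer, mesh=explicit_mesh, config=parallel_config
... )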