save_state_dict
paddle.distributed.save_state_dict(state_dict, path, process_group=None, coordinator_rank=0) -> None [source]
Save the state_dict of a model to the given path.
Parameters
state_dict (Dict[str, paddle.Tensor]) – The state_dict to save.
path (str) – The directory in which to save the state_dict.
process_group (paddle.distributed.collective.Group, optional) – The process group used for cross-rank synchronization. Defaults to the global process group, which contains all devices.
coordinator_rank (int, optional) – The rank used to save non-distributed values. Rank 0 is used by default.
Examples
>>> import paddle
>>> import paddle.distributed as dist
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, "./checkpoint")