save_state_dict

paddle.distributed.save_state_dict(state_dict: dict[str, Tensor], path: str, process_group: Group | None = None, coordinator_rank: int = 0, unique_id: int | None = None, async_save: bool = False) -> None

Save the state_dict of a model to the given directory path.

Parameters
  • state_dict (Dict[str, paddle.Tensor]) – The state_dict to save.

  • path (str) – The directory in which to save the state_dict.

  • process_group (paddle.distributed.collective.Group) – The process group used for cross-rank synchronization. Defaults to None, which selects the default process group containing all devices.

  • coordinator_rank (int) – The rank that saves non-distributed values. Defaults to 0.

  • unique_id (int) – The unique ID of the checkpoint, used to distinguish between different checkpoint versions. Default is None, in which case the ID is 0 on the first save to a given path and is incremented by 1 on each subsequent call to save_state_dict with the same path. If unique_id is given and a checkpoint with the same unique_id already exists, it is overwritten (a usage sketch follows the example below).

  • async_save (bool) – Whether to save the state_dict asynchronously. Defaults to False.

Examples

>>> # Run in a distributed job with two ranks (e.g. launched via paddle.distributed.launch).
>>> import paddle
>>> import paddle.distributed as dist
>>> w1 = paddle.arange(32).reshape([4, 8])
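>>> # A 1-D process mesh spanning ranks 0 and 1.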
>>> mesh = dist.ProcessMesh([0, 1])
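>>> # One placement per mesh dimension; a 1-D mesh takes a single placement.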
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
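>>> # Collective save across the mesh; the state_dict is written under ./checkpoint.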
>>> dist.save_state_dict(state_dict, "./checkpoint")
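The remaining parameters follow the same pattern. Below is a minimal sketch, not taken from the official docs, that reuses the state_dict from the example above; dist.new_group is an existing Paddle API, but pairing it with save_state_dict here is illustrative.

>>> # Explicit version number: an existing checkpoint with unique_id=3
>>> # under ./checkpoint would be overwritten.
>>> dist.save_state_dict(state_dict, "./checkpoint", unique_id=3)
>>> # Asynchronous save: the call returns while the write completes.
>>> dist.save_state_dict(state_dict, "./checkpoint", async_save=True)
>>> # Explicit process group and coordinator rank.
>>> group = dist.new_group([0, 1])
>>> dist.save_state_dict(state_dict, "./checkpoint", process_group=group, coordinator_rank=0)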