load_state_dict
paddle.distributed.load_state_dict(state_dict: dict[str, Tensor], path: str, process_group: Group | None = None, coordinator_rank: int = 0, unique_id: int | None = None, offload: bool = False, mw_name_compatibility: bool = True) → None
Load the state_dict in place from a checkpoint path.
Parameters
state_dict (Dict[str, paddle.Tensor]) – The state_dict to load. It is modified in place during loading.
path (str) – The directory from which to load the checkpoint files.
process_group (paddle.distributed.collective.Group, optional) – ProcessGroup used for cross-rank synchronization. Default is None, which uses the default process group containing all cards.
coordinator_rank (int) – The rank used to coordinate the checkpoint. Rank 0 is used by default.
unique_id (int) – The unique id of the checkpoint, used to distinguish between checkpoint versions. Default is None, in which case the largest id found under the given path is used, so the newest checkpoint version is loaded.
offload (bool) – Whether to offload the checkpoint data from GPU to CPU. Default is False. See the sketch after this list.
mw_name_compatibility (bool) – Whether to enable name compatibility between dynamic-graph and static-graph semi-automatic parallel. Default is True.
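For instance, unique_id and offload can be combined to pin a particular checkpoint version and keep the loaded values on CPU. A minimal sketch, assuming a checkpoint was previously saved under ./checkpoint with unique_id=0 (both the path and the id are illustrative):

>>> import paddle
>>> import paddle.distributed as dist
>>> state_dict = {"w1": paddle.zeros([4, 8], dtype="int64")}
>>> dist.load_state_dict(
...     state_dict,
...     "./checkpoint",
...     unique_id=0,   # pin this saved version instead of taking the newest
...     offload=True,  # move the loaded values from GPU to CPU
... )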
Example
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, ckpt_path)
>>> w1_to_load = paddle.zeros_like(w1)
>>> sharded_w1_to_load = dist.shard_tensor(w1_to_load, mesh, [dist.Replicate()])
>>> state_dict_to_load = {"w1": sharded_w1_to_load}
>>> dist.load_state_dict(state_dict_to_load, ckpt_path)
>>> print(f"state_dict_to_load:{state_dict_to_load}")
state_dict_to_load:{'w1': Tensor(shape=[4, 8], dtype=int64, place=Place(gpu:0), stop_gradient=True,
    dist_attr={process_mesh: {shape: [2], process_ids: [0,1], dim_names: [d0]}, dims_mappings: [-1,-1], batch_dim: 0, dynamic_dims: [0,0], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].},
    GlobalDenseTensor=
    [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
     [8 , 9 , 10, 11, 12, 13, 14, 15],
     [16, 17, 18, 19, 20, 21, 22, 23],
     [24, 25, 26, 27, 28, 29, 30, 31]])}
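Because the example shards w1 across a two-card ProcessMesh, it is meant to run on two devices under a distributed launcher rather than at a plain Python prompt. An illustrative invocation, assuming the snippet above is saved as demo.py (the script name is hypothetical, and depending on the Paddle version the device flag is spelled --devices or --gpus):

python -m paddle.distributed.launch --devices=0,1 demo.py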