load_state_dict¶
paddle.distributed.load_state_dict(state_dict, path, process_group=None, coordinator_rank=0) → None [source]
Load the state_dict in place from a checkpoint path.
Parameters
state_dict (Dict[str, paddle.Tensor]) – The state_dict to load. It is modified in place after loading.
path (str) – The directory from which the checkpoint files are loaded.
process_group (paddle.distributed.collective.Group) – The process group used for cross-rank synchronization. Defaults to None, in which case the default process group containing all cards is used.
coordinator_rank (int) – The rank used to coordinate the checkpoint. Rank 0 is used by default.
Example
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, ckpt_path)
>>> w1_to_load = paddle.zeros_like(w1)
>>> sharded_w1_to_load = dist.shard_tensor(w1_to_load, mesh, [dist.Replicate()])
>>> state_dict_to_load = {"w1": sharded_w1_to_load}
>>> dist.load_state_dict(state_dict_to_load, ckpt_path)
>>> print(f"state_dict_to_load:{state_dict_to_load}")
state_dict_to_load:{'w1': Tensor(shape=[4, 8], dtype=int64, place=Place(gpu:0), stop_gradient=True, dist_attr={process_mesh: {shape: [2], process_ids: [0,1], dim_names: [d0]}, dims_mappings: [-1,-1], batch_dim: 0, dynamic_dims: [0,0], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].}, GlobalDenseTensor=
       [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
        [8 , 9 , 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]])}
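The example above relies on the default process group and coordinator rank. The sketch below is not part of the official example; it illustrates one way the optional process_group and coordinator_rank arguments could be passed explicitly, assuming the two-card distributed environment from the example above is already initialized and that an explicit group is created with dist.new_group.

>>> # A minimal sketch, assuming the example above has already run on a
>>> # two-card launch (so the distributed environment is initialized).
>>> custom_group = dist.new_group(ranks=[0, 1])  # explicit group over both cards
>>> dist.load_state_dict(
...     state_dict_to_load,
...     ckpt_path,
...     process_group=custom_group,
...     coordinator_rank=0,
... )

Passing a narrower group only synchronizes the ranks that actually hold shards of the state_dict, while coordinator_rank selects which of those ranks coordinates the checkpoint metadata.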