load_state_dict¶
- paddle.distributed.load_state_dict(state_dict, path, process_group=None, coordinator_rank=0, unique_id=None, offload=False, mw_name_compatibility=True) [source] ¶
Load the checkpoint at the specified path into the given state_dict.
Parameters¶
state_dict (dict[str, paddle.Tensor]): The state_dict to load into; it is updated in place.
path (str): The directory containing the checkpoint files.
process_group (paddle.distributed.collective.Group, optional): ProcessGroup used for cross-rank synchronization. Defaults to None, which means the global process group containing all devices is used.
coordinator_rank (int, optional): The rank that coordinates checkpointing. Defaults to 0, meaning rank 0 acts as the coordinator.
unique_id (int, optional): Unique ID of the checkpoint, used to distinguish different checkpoint versions. Defaults to None, in which case the largest ID found at the specified path is used, i.e. the latest checkpoint version is loaded.
offload (bool, optional): Whether to offload the checkpoint to CPU. Defaults to False, meaning no offload is performed.
mw_name_compatibility (bool, optional): Whether to make parameter names compatible between dynamic-graph and static-graph semi-automatic parallelism. Defaults to True, meaning names are made compatible.
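The default unique_id behavior described above (pick the largest ID at the path) can be illustrated with a small sketch. Note this is a hypothetical helper for illustration only: the integer-named-subdirectory layout and the `latest_unique_id` function are assumptions, not the actual on-disk format used by Paddle.

```python
import os
import tempfile

def latest_unique_id(path):
    # Hypothetical helper mirroring the documented default: when
    # unique_id is None, the newest checkpoint version (the largest
    # integer ID found under the checkpoint directory) is selected.
    ids = [int(name) for name in os.listdir(path) if name.isdigit()]
    return max(ids) if ids else None

# Simulate a checkpoint directory holding versions 0, 1, and 2.
ckpt_dir = tempfile.mkdtemp()
for uid in (0, 1, 2):
    os.makedirs(os.path.join(ckpt_dir, str(uid)))

print(latest_unique_id(ckpt_dir))
```

Passing an explicit `unique_id` instead selects a specific version rather than the newest one.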
Returns¶
None
Code Example¶
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, ckpt_path)
>>> w1_to_load = paddle.zeros_like(w1)
>>> sharded_w1_to_load = dist.shard_tensor(w1_to_load, mesh, [dist.Replicate()])
>>> state_dict_to_load = {"w1": sharded_w1_to_load}
>>> dist.load_state_dict(state_dict_to_load, ckpt_path)
>>> print(f"state_dict_to_load:{state_dict_to_load}")
state_dict_to_load:{'w1': Tensor(shape=[4, 8], dtype=int64, place=Place(gpu:0), stop_gradient=True, dist_attr={process_mesh: {shape: [2], process_ids: [0,1], dim_names: [d0]}, dims_mappings: [-1,-1], batch_dim: 0, dynamic_dims: [0,0], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].}, GlobalDenseTensor=
[[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
[8 , 9 , 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31]])}