[ 参数不一致 ]torch.utils.data.random_split¶
torch.utils.data.random_split¶
torch.utils.data.random_split(dataset,
lengths,
generator=<torch._C.Generator object>)
paddle.io.random_split¶
paddle.io.random_split(dataset,
lengths,
generator=None)
两者参数除 lengths 外用法一致,具体如下:
参数差异¶
| PyTorch | PaddlePaddle | 备注 | | ————- | ———— |———————————————————————| | dataset | dataset | 表示可迭代数据集。 | | lengths | lengths | PyTorch:可为子集合长度列表,列表总和为原数组长度。也可为子集合所占比例列表,列表总和为 1.0。PaddlePaddle: 子集合长度列表,列表总和为原数组长度 | | generator | generator | 指定采样 data_source 的采样器。默认值为 None。 |
转写示例¶
lenghts: 子集合长度列表¶
# Pytorch 写法
lengths = [0.3, 0.3, 0.4]
datasets = torch.utils.data.random_split(dataset,
lengths,
generator=torch.manual_seed(0))
# Paddle 写法
lengths = [0.3, 0.3, 0.4]
lengths = [length * len(dataset) for length in lengths]
datasets = paddle.io.random_split(dataset,
lengths,
generator=paddle.seed(0))