[ torch 参数更多 ]torch.utils.data.DataLoader¶
torch.utils.data.DataLoader¶
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor=2,
persistent_workers=False,
pin_memory_device='')
paddle.io.DataLoader¶
paddle.io.DataLoader(dataset,
feed_list=None,
places=None,
return_list=False,
batch_sampler=None,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=False,
timeout=0,
worker_init_fn=None)
参数映射¶
| PyTorch | PaddlePaddle | 备注 | | ————- | ———— | —————————————————— | | dataset | dataset | 表示数据集。 | | batch_size | batch_size | 每 mini-batch 中样本个数。 | | shuffle | shuffle | 生成 mini-batch 索引列表时是否对索引打乱顺序。 | | sampler | - | 表示数据集采集器,Paddle 无此参数,暂无转写方式。 | | batch_sampler | batch_sampler | mini-batch 索引列表。 | | num_workers | num_workers | 用于加载数据的子进程个数。 | | collate_fn | collate_fn | 用于指定如何将样本列表组合为 mini-batch 数据。 | | pin_memory | - | 表示数据最开始是属于锁页内存,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | drop_last | drop_last | 是否丢弃因数据集样本数不能被 batch_size 整除而产生的最后一个不完整的 mini-batch。 | | timeout | timeout | 从子进程输出队列获取 mini-batch 数据的超时时间。 | | worker_init_fn | worker_init_fn | 子进程初始化函数。 | | multiprocessing_context | - | 用于设置多进程的上下文,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | generator | - | 用于采样的伪随机数生成器,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | prefetch_factor | - | 表示每个 worker 预先加载的数据数量,Paddle 无此参数,暂无转写方式。 | | persistent_workers | - | 表示数据集使用一次后,数据加载器将会不会关闭工作进程,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | pin_memory_device | - | 数据加载器是否在返回 Tensor 之前将 Tensor 复制到设备固定存储器中,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | - | feed_list | 表示 feed 变量列表,PyTorch 无此参数,Paddle 保持默认即可 | | - | places | 数据需要放置到的 Place 列表,PyTorch 无此参数,Paddle 保持默认即可 | | - | return_list | 每个设备上的数据是否以 list 形式返回,PyTorch 无此参数,Paddle 保持默认即可 | | - | use_buffer_reader | 表示是否使用缓存读取器,PyTorch 无此参数,Paddle 保持默认即可 | | - | use_shared_memory | 表示是否使用共享内存来提升子进程将数据放入进程间队列的速度,PyTorch 无此参数,Paddle 保持默认即可 |