[ torch 参数更多 ]torch.utils.data.DataLoader¶
torch.utils.data.DataLoader¶
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor=2,
persistent_workers=False,
pin_memory_device='')
paddle.io.DataLoader¶
paddle.io.DataLoader(dataset,
feed_list=None,
places=None,
return_list=False,
batch_sampler=None,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=False,
timeout=0,
worker_init_fn=None)
参数映射¶
| PyTorch | PaddlePaddle | 备注 | | ————- | ———— | —————————————————— | | sampler | - | 表示数据集采集器,Paddle 无此参数。 | | pin_memory | - | 表示数据最开始是属于锁页内存,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | multiprocessing_context | - | ?,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | generator | - | 用于采样的伪随机数生成器,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | prefetch_factor | - | 表示每个 worker 预先加载的数据数量,Paddle 无此参数。 | | persistent_workers | - | 表示数据集使用一次后,数据加载器将会不会关闭工作进程,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | pin_memory_device | - | 数据加载器是否在返回 Tensor 之前将 Tensor 复制到设备固定存储器中,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | - | feed_list | 表示 feed 变量列表,PyTorch 无此参数。 | | - | places | 数据需要放置到的 Place 列表,PyTorch 无此参数。 | | - | return_list | 每个设备上的数据是否以 list 形式返回,PyTorch 无此参数。 | | - | use_buffer_reader | 表示是否使用缓存读取器,PyTorch 无此参数。 | | - | use_shared_memory | 表示是否使用共享内存来提升子进程将数据放入进程间队列的速度,PyTorch 无此参数。 |
功能差异¶
自定义数据采集器¶
PyTorch:可通过设置sampler
自定义数据采集器。 PaddlePaddle:PaddlePaddle 无此功能,可使用如下代码自定义一个 DataLoader 实现该功能。
class DataLoader(paddle.io.DataLoader):
def __init__(self,
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None):
if isinstance(dataset[0], (tuple, list)):
return_list = True
else:
return_list = False
super().__init__(
dataset,
feed_list=None,
places=None,
return_list=return_list,
batch_sampler=batch_sampler,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
num_workers=num_workers,
use_buffer_reader=True,
use_shared_memory=False,
timeout=timeout,
worker_init_fn=worker_init_fn)
if sampler is not None:
self.batch_sampler.sampler = sampler