[ torch 参数更多 ]torch.utils.data.DistributedSampler¶
torch.utils.data.DistributedSampler¶
torch.utils.data.DistributedSampler(dataset,
num_replicas=None,
rank=None,
shuffle=True,
seed=0,
drop_last=False)
paddle.io.DistributedBatchSampler¶
paddle.io.DistributedBatchSampler(dataset=None,
batch_size,
num_replicas=None,
rank=None,
shuffle=False,
drop_last=False)
PyTorch 参数更多,具体如下:
参数映射¶
| PyTorch | PaddlePaddle | 备注 | | —– | ———- | ———- | | dataset | dataset | 被采样的数据集。 | | - | batch_size | 每 mini-batch 中包含的样本数,PyTorch 无此参数,Paddle 需设置为 1。 | | num_replicas | num_replicas | 分布式训练时的进程个数。 | | rank | rank | num_replicas 个进程中的进程序号。 | | shuffle | shuffle | 是否需要在生成样本下标时打乱顺序。与 PyTorch 默认值不同, Paddle 应设置为 True
。 | | seed | - | 如果 shuffle=True,则使用随机种子对采样器进行随机排序,此数字在分布式组中的所有进程中应相同,Paddle 无此参数,一般对网络训练结果影响不大,可直接删除。 | | drop_last | drop_last | 是否需要丢弃最后无法凑整一个 mini-batch 的样本。 |