RandomSampler¶
- class paddle.io. RandomSampler ( data_source, replacement=False, num_samples=None, generator=None ) [源代码] ¶
随机迭代样本,产生重排下标,如果 replacement = False
,则会采样整个数据集;如果 replacement = True
,则会按照 num_samples
指定的样本数采集。
参数¶
data_source (Dataset) - 此参数必须是 Dataset 或 IterableDataset 的一个子类实例或实现了
__len__
的 Python 对象,用于生成样本下标。默认值为 None。replacement (bool,可选) - 如果为
False
则会采样整个数据集,如果为True
则会按num_samples
指定的样本数采集。默认值为False
。num_samples (int,可选) - 按此参数采集对应的样本数。默认值为 None,此时设为
data_source
的长度。generator (Generator,可选) - 指定采样
data_source
的采样器。默认值为 None,不启用。
返回¶
RandomSampler,返回随机采样下标的采样器
代码示例¶
>>> import numpy as np
>>> from paddle.io import Dataset, RandomSampler
>>> np.random.seed(2023)
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([784]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> sampler = RandomSampler(data_source=RandomDataset(100))
>>> for index in sampler:
... print(index)
56
12
68
...
87