ConcatDataset¶
将多个数据集拼接为一个。
此 API 可用于集成多个不同的数据集。
参数¶
datasets (sequence) - 待拼接的数据集序列。
返回¶
Dataset,由 datasets
拼接而成的数据集。
代码示例¶
>>> import numpy as np
>>> import paddle
>>> from paddle.io import Dataset, ConcatDataset
>>> # define a random dataset
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([32]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
>>> for i in range(len(dataset)):
... image, label = dataset[i]
... # do something