我正在尝试使用多个 torch.utils.data.DataLoader
来创建应用了不同变换的数据集 . 目前,我的代码是粗略的
d_transforms = [
transforms.RandomHorizontalFlip(),
# Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
dataset = datasets.MNIST('./data',
train=train,
download=True,
transform=d_transforms[i]
loaders.append(
DataLoader(dataset,
shuffle=True,
pin_memory=True,
num_workers=1)
)
这有效,但速度极慢 . kernprof表明我的代码中几乎所有时间都用在像这样的行上
x, y = next(iter(train_loaders[i]))
我怀疑这是因为我正在使用 DataLoader
的多个实例,每个实例都有自己的工作程序,它试图读取相同的数据文件 .
我的问题是,有什么更好的方法呢?理想情况下,我会继承 torch.utils.data.DataSet
并指定我想在采样时应用的变换,但由于 __getitem__
无法接受参数,这似乎不可能 .
1 回答
__getitem__
确实接受了一个参数,该参数是您要加载的内容的索引 . 例如 .你不要在循环中调用数据加载器 .