首页 文章

PyTorch DataLoader

提问于
浏览 13
2

我正在尝试使用多个 torch.utils.data.DataLoader 来创建应用了不同变换的数据集 . 目前,我的代码是粗略的

d_transforms = [
    transforms.RandomHorizontalFlip(),
    # Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
    dataset = datasets.MNIST('./data', 
            train=train, 
            download=True, 
            transform=d_transforms[i]
    loaders.append(
        DataLoader(dataset, 
            shuffle=True, 
            pin_memory=True, 
            num_workers=1)
        )

这有效,但速度极慢 . kernprof表明我的代码中几乎所有时间都用在像这样的行上

x, y = next(iter(train_loaders[i]))

我怀疑这是因为我正在使用 DataLoader 的多个实例,每个实例都有自己的工作程序,它试图读取相同的数据文件 .

我的问题是,有什么更好的方法呢?理想情况下,我会继承 torch.utils.data.DataSet 并指定我想在采样时应用的变换,但由于 __getitem__ 无法接受参数,这似乎不可能 .

1 回答

  • 2

    __getitem__ 确实接受了一个参数,该参数是您要加载的内容的索引 . 例如 .

    transform = transforms.Compose(
        [transforms.ToTensor(),
         normalize])
    
    class CountDataset(Dataset):
    
    def __init__(self, file,transform=None):
    
        self.transform = transform
        #self.vocab = vocab
        with open(file,'rb') as f:
            self.data = pickle.load(f)
        self.y = self.data['answers']
        self.I = self.data['images']
    
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        img_name = self.I[idx]
        label = self.y[Idx]
        fname = '/'.join(img_name.split("/")[-2:]) #/train2014/xx.jpg
        DIR = '/hdd/manoj/VQA/Images/mscoco/'
        img_full_path = os.path.join(DIR,fname)
        img = Image.open(img_full_path).convert("RGB")
        img_tensor = self.transform(img.resize((224,224)))
        return img_tensor,label
    
    
    testset = CountDataset(file = 'testdat.pkl',
                            transform = transform)
    
    
    testloader = DataLoader(testset, batch_size=32,
                             shuffle=False, num_workers=4)
    

    你不要在循环中调用数据加载器 .

相关问题