首页 文章

获取pytorch数据集的子集

提问于
浏览
1

我有一个网络,我想在一些数据集上训练(例如,说 CIFAR10 ) . 我可以通过创建数据加载器对象

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

我的问题如下:假设我想进行几次不同的训练迭代 . 假设我首先想要在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络,依此类推 . 为此,我需要能够访问这些图像 . 不幸的是,似乎 trainset 不允许这样的访问 . 也就是说,尝试 trainset[:1000] 或更多 trainset[mask] 会抛出错误 .

我可以做

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改了 trainset.train_data 所以我需要重新定义 trainset ) . 有没有办法避免它?

理想情况下,我希望有一些“等同”的东西

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)

1 回答

  • 3

    您可以为数据集加载器定义自定义采样器,避免重新创建数据集(只需为每个不同的采样创建一个新的加载器) .

    class YourSampler(Sampler):
        def __init__(self, mask):
            self.mask = mask
    
        def __iter__(self):
            return (self.indices[i] for i in torch.nonzero(self.mask))
    
        def __len__(self):
            return len(self.mask)
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    
    sampler1 = YourSampler(your_mask)
    sampler2 = YourSampler(your_other_mask)
    trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              sampler = sampler1, shuffle=False, num_workers=2)
    trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              sampler = sampler2, shuffle=False, num_workers=2)
    

    PS:我没有检查代码 .

    PS2:你可以在这里找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

相关问题