首页 文章

转换不适用于数据集

提问于
浏览
2

我是pytorch的新手,想要了解一些东西 .

我正在加载MNIST如下:

transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(size, interpolation=2),
     # transforms.Grayscale(num_output_channels=1),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize((mean), (std))])


trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

然而,当我探索数据集,即 trainloader.dataset.train_data[0] 时,我得到了一个范围[0,255]的张量与形状(28,28) .

我错过了什么?这是因为转换不是直接应用于数据加载器而是仅在运行时?如何以其他方式浏览我的数据?

1 回答

  • 2

    当调用 Dataset__getitem__ 方法时,将应用变换 . 例如,查看 MNIST 数据集类的 __getitem__ 方法:https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]
    
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')
    
        if self.transform is not None:
            img = self.transform(img)
    
        if self.target_transform is not None:
            target = self.target_transform(target)
    
        return img, target
    

    当您为训练集索引 MNIST 实例时,会调用 __getitem__ 方法,例如:

    trainset[0]
    

    有关 __getitem__ 的更多信息:https://docs.python.org/3.6/reference/datamodel.html#object.getitem

    ResizeRandomHorizontalFlip 应该在 ToTensor 之前的原因是它们作用于PIL Images并且Pytorch中的所有数据集一致性首先将数据加载为 PIL Image . 事实上,你可以看到他们在这里通过以下方式强制执行:

    img = Image.fromarray(img.numpy(), mode='L')
    

    获得相应索引的 PIL Image 后,将应用变换

    if self.transform is not None:
        img = self.transform(img)
    

    ToTensorPIL Image 转换为 torch.TensorNormalize 减去均值并除以您提供的标准差 .

    最终,一些变换应用于标签

    if self.target_transform is not None:
        target = self.target_transform(target)
    

    最后返回处理过的图像和处理过的标签 . 所有这些都发生在一个 trainset[key] 电话中 .

    import torch
    from torchvision.transforms import *
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader
    
    transform_train = Compose([Resize(28, interpolation=2),
                               RandomHorizontalFlip(p=0.5),
                               ToTensor(),
                               Normalize([0.], [1.])])
    
    trainset = MNIST(root='./data', train=True, download=True,
                     transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
    print(trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max())
    

    节目

    (torch.Size([1, 28, 28]), tensor(0.), tensor(1.))
    

相关问题