首页 文章

cifar10.load_data()需要很长时间才能下载数据

提问于
浏览
2

嗨,我下载了cifar-10数据集 .

在我的代码中,它加载数据集如下 .

import cv2
import numpy as np

from keras.datasets import cifar10
from keras import backend as K
from keras.utils import np_utils

nb_train_samples = 3000 # 3000 training samples
nb_valid_samples = 100 # 100 validation samples
num_classes = 10

def load_cifar10_data(img_rows, img_cols):

    # Load cifar10 training and validation sets
    (X_train, Y_train), (X_valid, Y_valid) = cifar10.load_data()

    # Resize trainging images
    if K.image_dim_ordering() == 'th':
        X_train = np.array([cv2.resize(img.transpose(1,2,0), (img_rows,img_cols)).transpose(2,0,1) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img.transpose(1,2,0), (img_rows,img_cols)).transpose(2,0,1) for img in X_valid[:nb_valid_samples,:,:,:]])
    else:
        X_train = np.array([cv2.resize(img, (img_rows,img_cols)) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img, (img_rows,img_cols)) for img in X_valid[:nb_valid_samples,:,:,:]])

    # Transform targets to keras compatible format
    Y_train = np_utils.to_categorical(Y_train[:nb_train_samples], num_classes)
    Y_valid = np_utils.to_categorical(Y_valid[:nb_valid_samples], num_classes)

    return X_train, Y_train, X_valid, Y_valid

但是下载数据集需要很长时间 . 相反,我手动下载了'cifar-10-python.tar.gz' . 那么如何将其加载到变量(X_train,Y_train),(X_valid,Y_valid)而不是使用cifar10.load_data()?

1 回答

  • 0

    请原谅我的英语 . 我也试图手动加载cifar-10数据集 . 在以下代码中,我将 cifar-10-python.tar.gz 解压缩到一个文件夹,并将文件夹中的文件 data_batch_1 加载到4个数组中: x_trainy_trainx_testy_test . 20%的 data_batch_1 用于验证 x_testy_test ,其余用于训练为 x_trainy_train .

    import pickle
    import numpy
    # load data
    with open('cifar-10-batches-py\\data_batch_1','rb') as f:
        dict1 = pickle.load(f,encoding='bytes')
    
    x = dict1[b'data']
    x = x.reshape(len(x), 3, 32, 32).astype('float32')
    
    y = numpy.asarray(dict1[b'labels'])
    
    x_test = x[0:int(0.2 * x.shape[0]), :, :, :]
    y_test = y[0:int(0.2 * y.shape[0])]
    x_train = x[int(0.2 * x.shape[0]):x.shape[0], :, :, :]
    y_train = y[int(0.2 * y.shape[0]):y.shape[0]]
    

相关问题