Hi, I downloaded the CIFAR-10 dataset.
My code loads the dataset as follows:
import cv2
import numpy as np
from keras import backend as K
from keras.datasets import cifar10
from keras.utils import to_categorical

nb_train_samples = 3000  # 3000 training samples
nb_valid_samples = 100  # 100 validation samples
num_classes = 10

def load_cifar10_data(img_rows, img_cols):
    # Load the CIFAR-10 training and validation sets
    (X_train, Y_train), (X_valid, Y_valid) = cifar10.load_data()

    # Resize the images; cv2.resize expects dsize as (width, height)
    if K.image_data_format() == 'channels_first':
        # (3, H, W) -> (H, W, 3) for OpenCV, then back to (3, H, W)
        X_train = np.array([cv2.resize(img.transpose(1, 2, 0), (img_cols, img_rows)).transpose(2, 0, 1)
                            for img in X_train[:nb_train_samples]])
        X_valid = np.array([cv2.resize(img.transpose(1, 2, 0), (img_cols, img_rows)).transpose(2, 0, 1)
                            for img in X_valid[:nb_valid_samples]])
    else:
        X_train = np.array([cv2.resize(img, (img_cols, img_rows)) for img in X_train[:nb_train_samples]])
        X_valid = np.array([cv2.resize(img, (img_cols, img_rows)) for img in X_valid[:nb_valid_samples]])

    # Convert the integer labels to one-hot vectors, the format Keras expects
    Y_train = to_categorical(Y_train[:nb_train_samples], num_classes)
    Y_valid = to_categorical(Y_valid[:nb_valid_samples], num_classes)
    return X_train, Y_train, X_valid, Y_valid
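However, downloading the dataset takes a long time, so I downloaded 'cifar-10-python.tar.gz' manually instead. How can I load it into the variables (X_train, Y_train), (X_valid, Y_valid) without calling cifar10.load_data()?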
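A minimal sketch of one way to do this, reading the batches straight out of the downloaded archive with Python's tarfile and pickle modules. It assumes the standard layout of cifar-10-python.tar.gz (pickled files data_batch_1 to data_batch_5 and test_batch, each a dict with b"data" as an (N, 3072) uint8 array and b"labels" as a list of ints). The function names load_cifar10_from_tarball and _unpickle_batch are hypothetical, and the goal is only to return arrays in roughly the same layout as cifar10.load_data() (uint8 images, labels of shape (N, 1)).

```python
import pickle
import tarfile

import numpy as np


def _unpickle_batch(fileobj, channels_last=True):
    """Hypothetical helper: read one pickled CIFAR-10 batch, return (images, labels)."""
    # The batches were pickled under Python 2, so keys come back as bytes.
    batch = pickle.load(fileobj, encoding="bytes")
    # Each row holds 3072 values: 1024 red, 1024 green, then 1024 blue,
    # each channel a 32x32 image stored row by row.
    images = np.asarray(batch[b"data"], dtype=np.uint8).reshape(-1, 3, 32, 32)
    if channels_last:
        images = images.transpose(0, 2, 3, 1)  # (N, 3, 32, 32) -> (N, 32, 32, 3)
    labels = np.asarray(batch[b"labels"], dtype=np.uint8).reshape(-1, 1)
    return images, labels


def load_cifar10_from_tarball(path="cifar-10-python.tar.gz", channels_last=True):
    """Hypothetical replacement for cifar10.load_data() that reads a local archive."""
    batches = {}
    with tarfile.open(path, "r:gz") as tar:
        for member in tar.getmembers():
            name = member.name.rsplit("/", 1)[-1]  # drop the folder prefix
            if name.startswith("data_batch_") or name == "test_batch":
                batches[name] = _unpickle_batch(tar.extractfile(member), channels_last)
    # Concatenate data_batch_1 ... data_batch_5 in order to form the training set.
    train_names = sorted(n for n in batches if n.startswith("data_batch_"))
    X_train = np.concatenate([batches[n][0] for n in train_names])
    Y_train = np.concatenate([batches[n][1] for n in train_names])
    X_valid, Y_valid = batches["test_batch"]
    return (X_train, Y_train), (X_valid, Y_valid)
```

In load_cifar10_data you would then replace the cifar10.load_data() call with load_cifar10_from_tarball('cifar-10-python.tar.gz'). If your backend uses channels-first ('th') ordering, pass channels_last=False, because that branch of load_cifar10_data expects (3, H, W) images.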
1 Answer
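Please excuse my English. I was also trying to load the CIFAR-10 dataset manually. In the following code, I extract cifar-10-python.tar.gz into a folder and load the file data_batch_1 from that folder into four arrays: x_train, y_train, x_test and y_test. 20% of data_batch_1 is used for validation as x_test and y_test, and the rest is used for training as x_train and y_train.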
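A minimal sketch of the approach described above, assuming the archive was extracted so that data_batch_1 sits in a folder named cifar-10-batches-py (the archive's default layout). Which 20% of the batch goes to validation is not stated, so this sketch assumes the last 20%. The function name load_batch_split and its valid_fraction parameter are hypothetical.

```python
import pickle

import numpy as np


def load_batch_split(batch_path, valid_fraction=0.2):
    """Hypothetical helper: split one CIFAR-10 batch file into train/validation arrays."""
    with open(batch_path, "rb") as f:
        # Batches were pickled under Python 2, so keys come back as bytes.
        batch = pickle.load(f, encoding="bytes")
    # (N, 3072) -> (N, 3, 32, 32) -> channels-last (N, 32, 32, 3)
    x = np.asarray(batch[b"data"], dtype=np.uint8).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    y = np.asarray(batch[b"labels"], dtype=np.uint8).reshape(-1, 1)
    # Assumption: the first 80% is used for training, the last 20% for validation.
    n_valid = int(round(len(x) * valid_fraction))
    n_train = len(x) - n_valid
    x_train, y_train = x[:n_train], y[:n_train]
    x_test, y_test = x[n_train:], y[n_train:]
    return x_train, y_train, x_test, y_test
```

For example, load_batch_split(os.path.join("cifar-10-batches-py", "data_batch_1")) would return 8000 training and 2000 validation images from the 10000 in that batch.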