The TensorFlow documentation for MNIST recommends several different ways of loading the MNIST dataset:
- https://www.tensorflow.org/versions/r1.2/get_started/mnist/beginners
- https://www.tensorflow.org/versions/r1.2/get_started/mnist/pros
All of the approaches described there produce many deprecation warnings with TensorFlow 1.8.
This is how I currently load MNIST and create training batches:
from tensorflow.examples.tutorials.mnist import input_data

class MNIST:
    def __init__(self, optimizer):
        ...
        self.mnist_dataset = input_data.read_data_sets("/tmp/data/", one_hot=True)
        self.test_data = self.mnist_dataset.test.images.reshape((-1, self.timesteps, self.num_input))
        self.test_label = self.mnist_dataset.test.labels
        ...

    def train_run(self, sess):
        batch_input, batch_output = self.mnist_dataset.train.next_batch(self.batch_size, shuffle=True)
        batch_input = batch_input.reshape((self.batch_size, self.timesteps, self.num_input))
        _, loss = sess.run(fetches=[self.train_step, self.loss], feed_dict={self.input_placeholder: batch_input, self.output_placeholder: batch_output})
        ...

    def test_run(self, sess):
        loss = sess.run(fetches=[self.loss], feed_dict={self.input_placeholder: self.test_data, self.output_placeholder: self.test_label})
        ...
How can I do exactly the same thing using the current, non-deprecated approach?
I couldn't find any documentation on this.
As far as I can tell, the new way is:
train, test = tf.keras.datasets.mnist.load_data()
self.mnist_train_ds = tf.data.Dataset.from_tensor_slices(train)
self.mnist_test_ds = tf.data.Dataset.from_tensor_slices(test)
But how do I use these datasets in the train_run and test_run methods?
1 Answer
Example of loading the MNIST dataset with the TF Dataset API.

Create an MNIST dataset to load the train, validation and test images:
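A minimal sketch of this loading step, assuming tf.keras.datasets.mnist.load_data() as the source (as in the question) and a 5,000-image validation split taken from the end of the training set, which mirrors the default split of the old input_data.read_data_sets. The helper names prepare_mnist and load_mnist are hypothetical. Labels are one-hot encoded to match one_hot=True, and images keep their (28, 28) shape, which corresponds to timesteps=28 and num_input=28 in the question (an assumption about those settings).

```python
import numpy as np


def prepare_mnist(x_train, y_train, x_test, y_test, num_valid=5000):
    """Scale images to [0, 1] float32, one-hot encode the labels, and split
    the last num_valid training images off as a validation set."""
    if not 0 < num_valid < len(x_train):
        raise ValueError("num_valid must be between 1 and len(x_train) - 1")

    def scale(images):
        return images.astype(np.float32) / 255.0

    def one_hot(labels, num_classes=10):
        return np.eye(num_classes, dtype=np.float32)[labels]

    x_train, y_train = scale(x_train), one_hot(y_train)
    train = (x_train[:-num_valid], y_train[:-num_valid])
    valid = (x_train[-num_valid:], y_train[-num_valid:])
    test = (scale(x_test), one_hot(y_test))
    return train, valid, test


def load_mnist(num_valid=5000):
    """Download (or read the cached copy of) MNIST and prepare it.
    Images keep their (28, 28) shape: 28 timesteps x 28 inputs per step."""
    import tensorflow as tf  # local import: prepare_mnist itself needs only NumPy
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    return prepare_mnist(x_train, y_train, x_test, y_test, num_valid)
```

Keeping the preprocessing in a NumPy-only function makes it easy to check on small synthetic arrays without downloading anything.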
You can create a dataset from NumPy inputs with either Dataset.from_tensor_slices or Dataset.from_generator. Dataset.from_tensor_slices embeds the whole dataset in the computational graph as constants, so we will use Dataset.from_generator instead.

A feedable iterator that can toggle between training and validation:
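A minimal sketch of both pieces, written against tf.compat.v1 (TF 1.15 or TF 2.x in graph mode) so it runs on current versions; on TF 1.8 itself the same calls are plain tf.placeholder, tf.data.Iterator.from_string_handle, and the dataset methods make_initializable_iterator() and output_types/output_shapes. The names generator_dataset and feedable_iterator are hypothetical, and both iterators are initializable here (a one-shot iterator for training is another common choice).

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode with tf.Session, as in TF 1.x


def generator_dataset(images, labels, batch_size, shuffle=False):
    """Batched dataset fed by a Python generator, so the NumPy arrays stay in
    Python memory instead of being embedded in the graph as constants."""
    def gen():
        # A fresh permutation every time the dataset is (re)iterated.
        order = np.random.permutation(len(images)) if shuffle else range(len(images))
        for i in order:
            yield images[i], labels[i]

    dataset = tf.data.Dataset.from_generator(
        gen,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape(images.shape[1:]),
                       tf.TensorShape(labels.shape[1:])))
    return dataset.batch(batch_size)


def feedable_iterator(train_ds, valid_ds):
    """One get_next() whose source is chosen at run time by feeding a handle."""
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle,
        tf.data.get_output_types(train_ds),
        tf.data.get_output_shapes(train_ds))
    next_images, next_labels = iterator.get_next()
    # Training repeats forever; validation ends with OutOfRangeError and is
    # restarted by re-running valid_iter.initializer.
    train_iter = tf.data.make_initializable_iterator(train_ds.repeat())
    valid_iter = tf.data.make_initializable_iterator(valid_ds)
    return handle, next_images, next_labels, train_iter, valid_iter
```

Each sess.run call then picks its data source through feed_dict={handle: train_handle} or feed_dict={handle: valid_handle}, where the handle strings come from sess.run(train_iter.string_handle()) and sess.run(valid_iter.string_handle()); no image data is fed through feed_dict any more.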
A sample run:
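A sketch of how train_run and test_run from the question could look on top of such a pipeline, again assuming the tf.compat.v1 API. The model is a hypothetical stand-in (a single softmax layer on the flattened image) in the place where the question's RNN over (timesteps, num_input) would go, the dataset helper is repeated so the block runs on its own, and start() is a hypothetical extra method that initializes variables and fetches the iterator handles. Unlike the original test_run, the test set is evaluated in batches and the reported loss is the unweighted mean of the batch losses.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode with tf.Session, as in TF 1.x


def _dataset(images, labels, batch_size, shuffle):
    # Generator-based pipeline, repeated here to keep this block self-contained.
    def gen():
        order = np.random.permutation(len(images)) if shuffle else range(len(images))
        for i in order:
            yield images[i], labels[i]
    dataset = tf.data.Dataset.from_generator(
        gen, (tf.float32, tf.float32),
        (tf.TensorShape(images.shape[1:]), tf.TensorShape(labels.shape[1:])))
    return dataset.batch(batch_size)


class MNIST:
    def __init__(self, train, test, batch_size=128, learning_rate=0.5):
        train_ds = _dataset(*train, batch_size, shuffle=True).repeat()
        test_ds = _dataset(*test, batch_size, shuffle=False)

        # Feedable iterator: the fed string handle picks train or test data.
        self.handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            self.handle, tf.data.get_output_types(train_ds),
            tf.data.get_output_shapes(train_ds))
        images, labels = iterator.get_next()
        self.train_iter = tf.data.make_initializable_iterator(train_ds)
        self.test_iter = tf.data.make_initializable_iterator(test_ds)

        # Hypothetical stand-in model: one softmax layer on the flattened image.
        n_features = int(np.prod(train[0].shape[1:]))
        n_classes = train[1].shape[1]
        w = tf.Variable(tf.zeros([n_features, n_classes]))
        b = tf.Variable(tf.zeros([n_classes]))
        logits = tf.matmul(tf.reshape(images, [-1, n_features]), w) + b
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
        self.train_step = tf.train.GradientDescentOptimizer(
            learning_rate).minimize(self.loss)

        self.init_op = tf.global_variables_initializer()
        self.handle_ops = [self.train_iter.string_handle(),
                           self.test_iter.string_handle()]

    def start(self, sess):
        sess.run([self.init_op, self.train_iter.initializer])
        self.train_handle, self.test_handle = sess.run(self.handle_ops)

    def train_run(self, sess):
        # The next training batch is pulled inside the graph; only the handle is fed.
        _, loss = sess.run([self.train_step, self.loss],
                           feed_dict={self.handle: self.train_handle})
        return loss

    def test_run(self, sess):
        # Restart the test iterator so each call covers the test set once.
        sess.run(self.test_iter.initializer)
        losses = []
        while True:
            try:
                losses.append(sess.run(self.loss,
                                       feed_dict={self.handle: self.test_handle}))
            except tf.errors.OutOfRangeError:
                # Unweighted mean: a smaller final batch counts like a full one.
                return float(np.mean(losses))
```

With the arrays from tf.keras.datasets.mnist.load_data() scaled to [0, 1] and one-hot encoded, call model.start(sess) once, then model.train_run(sess) per training step and model.test_run(sess) whenever an evaluation is needed.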