When I pass validation_data from a TensorFlow iterator, Keras appears to ignore the argument and keeps using the training data. Is my approach wrong, or is this a bug in Keras?
import tensorflow as tf
import keras
def _parse_function_x(filename):
    image = tf.random_uniform([tf.shape(filename)[0], 198,198,1]) # simulated image loading and manipulation
    image = tf.Print(image, [filename, tf.shape(image)], " - returning image for file -")
    return image
def _parse_function_y(label):
    return tf.one_hot(label, 10)
def _parse_function(filename, label):
    return _parse_function_x(filename), _parse_function_y(label)
flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]
filenames = tf.constant(flist)
labels = tf.constant([0, 5, 6, 1, 2, 3])
train_batch = 2
valid_batch = 3
train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.repeat().batch(train_batch)
train_x_dataset = train_x_dataset.map(_parse_function_x)
it_train_x = train_x_dataset.make_one_shot_iterator()
train_y_dataset = tf.data.Dataset.from_tensor_slices((labels))
train_y_dataset = train_y_dataset.repeat().batch(train_batch)
train_y_dataset = train_y_dataset.map(_parse_function_y)
it_train_y = train_y_dataset.make_one_shot_iterator()
vlist = ["val1", "val2", "val3"]
valid_filenames = tf.constant(vlist)
valid_labels = tf.constant([3, 2, 5])
valid_dataset = tf.data.Dataset.from_tensor_slices((valid_filenames, valid_labels))
valid_dataset = valid_dataset.repeat().batch(valid_batch)
valid_dataset = valid_dataset.map(_parse_function)
it_valid = valid_dataset.make_one_shot_iterator()
model = keras.applications.resnet50.ResNet50(include_top=True, weights=None, input_tensor=it_train_x.get_next(), pooling=None, classes=10, input_shape=(198,198,1))
model.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'],
              target_tensors=[it_train_y.get_next()])
model.fit(steps_per_epoch=len(flist) // train_batch, epochs=5, validation_data=it_valid.get_next(),
          validation_steps=len(vlist) // valid_batch, verbose=2)
Output:
2018-01-10 19:07:57.810850: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.372769: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:58.428576: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:58.744026: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.759726: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
- 4s - loss: 8.7110 - acc: 0.0000e+00 - val_loss: 2.8315 - val_acc: 0.0000e+00
Epoch 2/5
2018-01-10 19:07:58.815126: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:58.869334: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.923224: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
- 0s - loss: 12.0298 - acc: 0.0000e+00 - val_loss: 2.7070 - val_acc: 0.0000e+00
Epoch 3/5
2018-01-10 19:07:58.939015: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.005950: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.067022: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:59.120895: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg5 trimg6][2 198 198...]
- 0s - loss: 12.9786 - acc: 0.0000e+00 - val_loss: 3.7159 - val_acc: 0.0000e+00
Epoch 4/5
2018-01-10 19:07:59.136508: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.190424: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:59.259350: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.319021: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.334779: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
- 0s - loss: 13.9429 - acc: 0.0000e+00 - val_loss: 5.9738 - val_acc: 0.0000e+00
Epoch 5/5
2018-01-10 19:07:59.388996: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.443311: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.507233: I tensorflow/core/kernels/logging_ops.cc:79] - returning image for file -[trimg3 trimg4][2 198 198...]
- 0s - loss: 11.9854 - acc: 0.1667 - val_loss: 1.3048 - val_acc: 0.5000
Process finished with exit code 0
The files val1, val2, ... never appear in the log, yet Keras still somehow reports val_loss and val_acc.
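One way to read the log: the validation pipeline also goes through _parse_function_x (via _parse_function), so if the validation iterator were ever consumed, a tf.Print line containing val1 val2 val3 would appear. None does. Counting the tf.Print lines instead gives 5 + 3 + 4 + 5 + 3 = 20, which equals 5 epochs × (3 training steps + 1 validation step), and the 20 batches form one unbroken trimg1/2, trimg3/4, trimg5/6 cycle. That is consistent with every step, validation included, pulling from the training iterator, because the model's input was created with input_tensor=it_train_x.get_next(). The uneven split per epoch is likely just interleaving, since tf.Print writes to stderr while Keras progress goes to stdout. The pure-Python sketch below (simulate_pulls is a hypothetical helper, not a Keras API) models that assumption and reproduces the logged sequence:

```python
from itertools import cycle, islice


def simulate_pulls(files, batch_size, steps_per_epoch, validation_steps, epochs):
    """Hypothetical model of the suspected behaviour: every step, whether a
    training step or a validation step, takes the next batch from the one
    repeating training iterator that the model's input tensor is bound to.

    Assumes len(files) is a multiple of batch_size (true here: 6 files,
    batch 2), so repeat().batch() yields the same batches on every cycle.
    """
    batches = [tuple(files[i:i + batch_size])
               for i in range(0, len(files), batch_size)]
    pulls_per_epoch = steps_per_epoch + validation_steps
    return list(islice(cycle(batches), epochs * pulls_per_epoch))


flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]
vlist = ["val1", "val2", "val3"]
pulls = simulate_pulls(flist, batch_size=2,
                       steps_per_epoch=len(flist) // 2,   # 3 training steps
                       validation_steps=len(vlist) // 3,  # 1 validation step
                       epochs=5)
print(len(pulls))  # 20, the number of tf.Print lines in the log
print(pulls[:4])   # [('trimg1', 'trimg2'), ('trimg3', 'trimg4'), ('trimg5', 'trimg6'), ('trimg1', 'trimg2')]
```

If the validation steps really used it_valid, the count of training batches would be 15, not 20, and at least one line would name the val files.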
What is the correct way to feed validation data to Keras from a TensorFlow Dataset?
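A minimal sketch of an approach that avoids the problem, assuming TensorFlow 2.x, where tf.keras Model.fit accepts a tf.data.Dataset both as x and as validation_data. The key difference from the question's code is that the model gets an ordinary Input layer instead of input_tensor bound to one iterator, and images and labels travel together in a single zipped dataset, so fit and validation each read from whichever dataset they are given. The tiny CNN is a hypothetical stand-in for ResNet50 to keep the example fast; keras.applications.ResNet50(weights=None, input_shape=(198, 198, 1), classes=10) should slot in the same way. parse and make_dataset are illustrative names, not library functions.

```python
import tensorflow as tf

NUM_CLASSES = 10
IMG_SHAPE = (198, 198, 1)


def parse(filename, label):
    # Simulated image loading, as in the question: a random image per file.
    # filename is unused here; real code would read and decode the file.
    image = tf.random.uniform(IMG_SHAPE)
    return image, tf.one_hot(label, NUM_CLASSES)


def make_dataset(files, labels, batch_size):
    # One dataset of (image, one_hot_label) pairs; map runs per example,
    # then batch, then repeat so fit can draw a fixed number of steps.
    ds = tf.data.Dataset.from_tensor_slices((files, labels))
    return ds.map(parse).batch(batch_size).repeat()


flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]
labels = [0, 5, 6, 1, 2, 3]
vlist = ["val1", "val2", "val3"]
valid_labels = [3, 2, 5]
train_batch, valid_batch = 2, 3

train_ds = make_dataset(flist, labels, train_batch)
valid_ds = make_dataset(vlist, valid_labels, valid_batch)

# Input layer instead of input_tensor=..., so the graph is not tied to one iterator.
model = tf.keras.Sequential([
    tf.keras.Input(shape=IMG_SHAPE),
    tf.keras.layers.Conv2D(8, 3, strides=4, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="sgd", loss="categorical_crossentropy",
              metrics=["accuracy"])

history = model.fit(train_ds,
                    steps_per_epoch=len(flist) // train_batch,
                    epochs=5,
                    validation_data=valid_ds,
                    validation_steps=len(vlist) // valid_batch,
                    verbose=2)
```

Mapping before batching lets parse work on one example at a time, so it no longer needs tf.shape(filename)[0] to size the batch. Adding a tf.print inside parse is a quick way to confirm that the val files are now read during the validation step.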