When I provide validation_data from a TensorFlow iterator, Keras seems to ignore the argument and keeps using the training data. Is my approach wrong, or is this a bug in Keras?

import tensorflow as tf
import keras

def _parse_function_x(filename):
  image = tf.random_uniform([tf.shape(filename)[0], 198,198,1]) # simulated image loading and manipulation
  image = tf.Print(image, [filename, tf.shape(image)], " - returning image for file -")
  return image

def _parse_function_y(label):
    return tf.one_hot(label, 10)

def _parse_function(filename, label):
    return _parse_function_x(filename), _parse_function_y(label)

flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]

filenames = tf.constant(flist)
labels = tf.constant([0, 5, 6, 1, 2, 3])

train_batch = 2
valid_batch = 3

train_x_dataset = tf.data.Dataset.from_tensor_slices((filenames))
train_x_dataset = train_x_dataset.repeat().batch(train_batch)
train_x_dataset = train_x_dataset.map(_parse_function_x)
it_train_x = train_x_dataset.make_one_shot_iterator()

train_y_dataset = tf.data.Dataset.from_tensor_slices((labels))
train_y_dataset = train_y_dataset.repeat().batch(train_batch)
train_y_dataset = train_y_dataset.map(_parse_function_y)
it_train_y = train_y_dataset.make_one_shot_iterator()

vlist = ["val1", "val2", "val3"]

valid_filenames = tf.constant(vlist)
valid_labels = tf.constant([3, 2, 5])
valid_dataset = tf.data.Dataset.from_tensor_slices((valid_filenames, valid_labels))
valid_dataset = valid_dataset.repeat().batch(valid_batch)
valid_dataset = valid_dataset.map(_parse_function)
it_valid = valid_dataset.make_one_shot_iterator()

model = keras.applications.resnet50.ResNet50(include_top=True, weights=None, input_tensor=it_train_x.get_next(), pooling=None, classes=10, input_shape=(198,198,1))

model.compile(optimizer='sgd',
              loss='categorical_crossentropy',
              metrics=['accuracy'],
              target_tensors=[it_train_y.get_next()])


model.fit(steps_per_epoch=len(flist) // train_batch, epochs=5, validation_data=it_valid.get_next(),
          validation_steps=len(vlist)  // valid_batch, verbose=2)

Result:

2018-01-10 19:07:57.810850: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.372769: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:58.428576: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:58.744026: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.759726: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 4s - loss: 8.7110 - acc: 0.0000e+00 - val_loss: 2.8315 - val_acc: 0.0000e+00
Epoch 2/5
2018-01-10 19:07:58.815126: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:58.869334: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:58.923224: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 0s - loss: 12.0298 - acc: 0.0000e+00 - val_loss: 2.7070 - val_acc: 0.0000e+00
Epoch 3/5
2018-01-10 19:07:58.939015: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.005950: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.067022: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:59.120895: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
 - 0s - loss: 12.9786 - acc: 0.0000e+00 - val_loss: 3.7159 - val_acc: 0.0000e+00
Epoch 4/5
2018-01-10 19:07:59.136508: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.190424: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
2018-01-10 19:07:59.259350: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.319021: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.334779: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 0s - loss: 13.9429 - acc: 0.0000e+00 - val_loss: 5.9738 - val_acc: 0.0000e+00
Epoch 5/5
2018-01-10 19:07:59.388996: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg5 trimg6][2 198 198...]
2018-01-10 19:07:59.443311: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg1 trimg2][2 198 198...]
2018-01-10 19:07:59.507233: I tensorflow/core/kernels/logging_ops.cc:79]  - returning image for file -[trimg3 trimg4][2 198 198...]
 - 0s - loss: 11.9854 - acc: 0.1667 - val_loss: 1.3048 - val_acc: 0.5000

Process finished with exit code 0

The val1, val2, ... files seem to be ignored, yet Keras still somehow computes val_loss and the other validation metrics.
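One way to check whether the validation steps really read from the training iterator is to replay the `tf.Print` output against a plain-Python emulation of `repeat().batch(2)` over `flist`. This is a sketch under stated assumptions, not TensorFlow code: `repeat_batch` is a hypothetical helper that mimics repeat-then-batch ordering, and `LOGGED_FIRST` is the first filename index of each batch, copied by hand from the log above. The per-epoch counts in the log are uneven (5, 3, 4, 5, 3), so only the overall sequence and total are compared. If the 20 logged batches match the first 20 batches of a single training stream, and 20 equals 5 epochs × (3 training steps + 1 validation step), that is consistent with each validation step also pulling a batch from `it_train_x` instead of `it_valid`.

```python
import itertools
from typing import Iterator, List


def repeat_batch(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Hypothetical helper emulating Dataset.from_tensor_slices(items).repeat().batch(batch_size).

    Because repeat() comes before batch(), a batch may span the end of one
    pass over `items` and the start of the next.
    """
    source = itertools.cycle(items)
    while True:
        yield [next(source) for _ in range(batch_size)]


flist = ["trimg1", "trimg2", "trimg3", "trimg4", "trimg5", "trimg6"]
vlist = ["val1", "val2", "val3"]
train_batch, valid_batch, epochs = 2, 3, 5

steps_per_epoch = len(flist) // train_batch   # 3, as passed to model.fit
validation_steps = len(vlist) // valid_batch  # 1, as passed to model.fit

# First file index of every batch printed by tf.Print, epochs 1..5 (copied from the log).
LOGGED_FIRST = [1, 3, 5, 1, 3,  5, 1, 3,  5, 1, 3, 5,  1, 3, 5, 1, 3,  5, 1, 3]
logged = [[f"trimg{i}", f"trimg{i + 1}"] for i in LOGGED_FIRST]

# The same number of batches drawn from one emulated training stream.
stream = repeat_batch(flist, train_batch)
expected = [next(stream) for _ in range(len(logged))]

print("batches logged:", len(logged))
print("epochs * (train + validation steps):", epochs * (steps_per_epoch + validation_steps))
print("all logged batches come from one training stream:", logged == expected)
```

Under these assumptions the script prints 20, 20 and `True`. Because `_parse_function_x` (which contains the `tf.Print`) is also applied to the validation dataset through `_parse_function`, the absence of any `val*` filename in the log suggests the validation pipeline was never evaluated.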

What is the correct way to feed validation data to Keras from a TensorFlow Dataset?