我已经使用tensorflow高级API tf.estimator.Estimator基于此Tutorial创建了张量流模型 . 我没有从互联网上下载MNIST数据,而是从本地CSV文件中读取它 . 它工作正常 . 但是,当我迁移代码以使用tf.contrib.learn.Estimator时,我遇到了预测数量的问题 . 我的test_data是28000张图片,但由于模型预测,我得到了28032个预测值 .
预测代码:
predict_input_fn = tf.contrib.learn.io.numpy_input_fn(
x={"x": test_data},
num_epochs=1,
shuffle=False)
predictions = list(mnist_classifier.predict(input_fn=predict_input_fn))
test_data是形状的数组(28000,28,28,1),但len(预测)= 28032任何人都知道可能是什么问题?