我对Tensorflow很新,并且一直在努力通过阅读tensorflow.org上的指南和文档来学习基础知识 .

我已经学习了如何使用 tf.datatf.estimator API的基础知识,并试图让他们在MNIST的基本分类模型上一起工作 .

我正在使用此脚本加载MNIST:https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py

我对数据集函数进行了修改,以返回特征字典而不是向量:

def dataset(directory, images_file, labels_file):
  """Download and parse MNIST dataset."""

  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  check_image_file_header(images_file)
  check_labels_file_header(labels_file)

  def decode_image(image):
    # Normalize from [0, 255] to [0.0, 1.0]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [784])
    return image / 255.0

  def decode_label(label):
    label = tf.decode_raw(label, tf.uint8)  # tf.string -> [tf.uint8]
    label = tf.reshape(label, [])  # label is a scalar
    return tf.to_int32(label)

  images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16).map(decode_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8).map(decode_label)
  return tf.data.Dataset.zip(({"image":images}, labels))

我在mf中使用premade估算器的MNIST分类器脚本如下:

import tensorflow as tf
import dataset

fc = [tf.feature_column.numeric_column("image", shape=784)]

mnist_classifier = tf.estimator.DNNClassifier(
    hidden_units=[512,512],
    feature_columns=fc,
    model_dir="models/mnist/dnn",
    n_classes=10)

def input_fn(train=False, batch_size=None):
    if train:
        ds = mnist.train("MNIST-data")
        ds = ds.shuffle(1000).repeat().batch(batch_size)
    else:
        ds = mnist.test("MNIST-data")
    return ds

mnist_classifier.train(
  input_fn=lambda:input_fn(True, 32),
  steps=10000)

eval_results = mnist_classifier.evaluate(input_fn=lambda:input_fn())

分类器在训练时不会崩溃,但在评估时,我面临以下追溯:

ValueError:无法使用输入形状重塑784个元素的张量,以形成[dnn / input_from_feature_columns / input_layer / image / Reshape'(op:'Reshape')的[784,784](614656个元素):[784,1],[2并且输入张量计算为部分形状:输入[1] = [784,784] .

什么可能导致这个问题?

我已经尝试打印列车和测试数据集的输出形状和类型,它们完全相同 .

我也尝试在tensorboard上查看模型,只有投影仪选项卡可用,没有标量或图形标签 .

谢谢!

PS:使用数据集和估算器API的TF教程的任何链接也都很棒 .