I am struggling to train my CNN on an entire TFRecord dataset using the train method of tf.estimator.Estimator.

I tried running train in a loop, as follows:

estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir=MODEL_FOLDER)
input_fn = generate_input_fn(path, [], batch_size=128,
                             shuffle=True, num_epochs=None)
while True:
    estimator.train(
        input_fn=input_fn, steps=1, hooks=[logging_hook])

My input_fn looks like this:

def generate_input_fn(file_pattern, given_labels, batch_size=1,
                      num_epochs=None, shuffle=False):
    def _input_fn():
        print("_input_fn: file pattern: " + file_pattern)

        filenames_tensor = tf.train.match_filenames_once(file_pattern)
        filename_queue = tf.train.string_input_producer(
            filenames_tensor,
            num_epochs=num_epochs,
            shuffle=shuffle)

        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(
            serialized_example,
            features={
                'image/width': tf.FixedLenFeature([], tf.int64),
                'image/height': tf.FixedLenFeature([], tf.int64),
                'image/class/label': tf.FixedLenFeature([LABELS_SIZE], tf.int64),
                'image/encoded': tf.FixedLenFeature([], tf.string),
                'image/format': tf.FixedLenFeature([], tf.string),
                'image/name': tf.FixedLenFeature([], tf.string)
            })

        labels = features['image/class/label']
        filename = features['image/name']

        image = tf.image.decode_jpeg(
            features["image/encoded"], channels=IMAGE_CHANNELS)
        image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS])

        image = tf.image.resize_image_with_crop_or_pad(
            image, IMAGE_HEIGHT, IMAGE_WIDTH)

        image_batch, batch_labels, filename_batch = tf.train.shuffle_batch(
            [image, labels, filename],
            batch_size,
            num_threads=8,
            capacity=5000,
            min_after_dequeue=1000
            # allow_smaller_final_batch=True
        )

        # Scale pixel values from [0, 255] to [-1, 1] so that the
        # "center" of the image range is roughly 0.
        image_batch = tf.to_float(image_batch) / 255
        image_batch = (image_batch * 2) - 1

        features = {
            "image": image_batch,
            "filename": filename_batch
        }

        return features, batch_labels
    return _input_fn

In my model_fn, I have the following code:

logits = tf.Print(logits, [logits], "Logits: ")
features['filename'] = tf.Print(features['filename'], [features['filename']], 'Filename: ')
tf.summary.text('filename', features['filename'])

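But when I print the filenames in my model_fn, it looks like I get the same batch on every run. So far I have tried:

* Changing steps, but then the filenames are not printed at all (only the logits)???
* Moving the reader to the outer scope of generate_input_fn, but then it complains that the input tensors come from a different graph.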
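To make the symptom concrete, here is a minimal pure-Python sketch (no TensorFlow) of one possible explanation. It assumes that every call to estimator.train invokes input_fn again and rebuilds the input pipeline from the start of the data, which would also fit the "different graph" error above. The names make_input_fn and fake_train are hypothetical stand-ins, not TensorFlow APIs, and shuffling is left out to keep the output deterministic.

```python
# Pure-Python illustration only; nothing here is a TensorFlow API.
# make_input_fn and fake_train are hypothetical stand-ins.

def make_input_fn(records, batch_size):
    """Return an input_fn that builds a fresh batch iterator on every call."""
    def _input_fn():
        # A new iterator always starts at the first record, much like a
        # newly created reader/queue would.
        for start in range(0, len(records), batch_size):
            yield records[start:start + batch_size]
    return _input_fn


def fake_train(input_fn, steps):
    """Mimic one train() call: rebuild the pipeline, then consume `steps` batches."""
    batches = input_fn()  # assumption: the pipeline is rebuilt on each call
    return [next(batches) for _ in range(steps)]


records = ["img_%d.jpg" % i for i in range(6)]
input_fn = make_input_fn(records, batch_size=2)

# Looping with steps=1 restarts the pipeline each time: same batch every call.
looped = [fake_train(input_fn, steps=1)[0] for _ in range(3)]

# A single call with steps=3 advances through the data.
single = fake_train(input_fn, steps=3)

print(looped)
print(single)
```

If this assumption holds for my setup, the real-code equivalent might be a single estimator.train call with a larger steps (or max_steps) value instead of steps=1 inside a while loop. With shuffling enabled the batches would probably not be identical element for element, but they would still be drawn from the same initial records.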

Any idea what I am doing wrong? Thanks for your help!