首页 文章

如何将numpy数组的迭代器提供给tensorflow Estimator / Evaluable

提问于
浏览
0

我有一个迭代器函数,它产生一批特性并标记为numpy数组的元组 .

def batch_iter():for ...:yield(np_features,np_labels)

然后我尝试像张量估算器一样

# the cnn_model_fn will print out shapes of various tensor when
# constructing the model
classifier = learn.Estimator(
    model_fn=cnn_model_fn, model_dir="/tmp/convnet_model")
for train_data, train_labels in batch_iter():
    classifier.fit(
        input_fn=lambda: (tf.constant(train_data), tf.constant(train_labels)),
        steps=1,
        monitors=[logging_hook])

(带注释的)日志看起来像

conv1 shape (100, 16, 20, 32)
pool1 shape (100, 8, 10, 32)
conv2 shape (100, 8, 10, 64)
pool2 shape (100, 4, 5, 64)
onehot label shape (100, 5)
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/convnet_model/model.ckpt.    # checkpoint is saved in every iteration
INFO:tensorflow:step = 1, loss = 1618.76
INFO:tensorflow:Loss for final step: 1618.76.
conv1 shape (100, 16, 20, 32)    # the model_fn is called in every iteration
pool1 shape (100, 8, 10, 32)
conv2 shape (100, 8, 10, 64)
pool2 shape (100, 4, 5, 64)
onehot label shape (100, 5)
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from /tmp/convnet_model/model.ckpt-1  # checkpoint is restored in every iteration
INFO:tensorflow:Saving checkpoints for 2 into /tmp/convnet_model/model.ckpt.
INFO:tensorflow:step = 2, loss = 69370.6
INFO:tensorflow:Loss for final step: 69370.6.
conv1 shape (100, 16, 20, 32)
pool1 shape (100, 8, 10, 32)
conv2 shape (100, 8, 10, 64)
pool2 shape (100, 4, 5, 64)
onehot label shape (100, 5)
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from /tmp/convnet_model/model.ckpt-2
INFO:tensorflow:Saving checkpoints for 3 into /tmp/convnet_model/model.ckpt.
INFO:tensorflow:step = 3, loss = 289303.0
INFO:tensorflow:Loss for final step: 289303.0.
...

读取批次并且随着循环的迭代,损失会下降 . 但是,似乎在每次迭代中都会保存和恢复检查点,并在每次迭代中调用model_fn . 所以我觉得这不对 .

将迭代器提供给Estimator / Evaluable的正确方法是什么?

1 回答

  • 1

    在你的input_fn中你可以使用 tf.contrib.training.python_input

相关问题