首页 文章

Estimator预测无限循环

提问于
浏览
6

我不明白如何使用TensorFlow Estimator API进行单一预测 - 我的代码导致无限循环,不断预测相同的输入 .

根据documentation,当input_fn引发StopIteration异常时,预测应该停止:

input_fn:返回功能的输入函数,它是Tensor或SparseTensor的字符串功能名称字典 . 如果它返回一个元组,则第一个项目被提取为特征 . 预测将继续,直到input_fn引发输入结束异常(OutOfRangeError或StopIteration) .

这是我的代码中的相关部分:

classifier = tf.estimator.Estimator(model_fn=image_classifier, model_dir=output_dir,
                                    config=training_config, params=hparams)

def make_predict_input_fn(filename):
    queue = [ filename ]
    def _input_fn():
        if len(queue) == 0:
            raise StopIteration
        image = model.read_and_preprocess(queue.pop())
        return {'image': image}
    return _input_fn

predictions = classifier.predict(make_predict_input_fn('garden-rose-red-pink-56866.jpeg'))
for i, p in enumerate(predictions):
    print("Prediction %s: %s" % (i + 1, p["class"]))

我错过了什么?

1 回答

  • 0

    那是因为input_fn()需要是一个生成器 . 将您的函数更改为(yield而不是return):

    def make_predict_input_fn(filename):
        queue = [ filename ]
        def _input_fn():
            if len(queue) == 0:
                raise StopIteration
            image = model.read_and_preprocess(queue.pop())
            yield {'image': image}
        return _input_fn
    

相关问题