首页 文章

仅当模型显示TensorFlow的改进时才保存模型检查点

提问于
浏览
1

您是否知道在使用实验中包含的Estimator时是否有办法选择保存哪个模型?因为每个'save_checkpoints_steps',模型都会被保存,但这个模型不一定是最好的 .

def model_fn(features, labels, mode, params):
    predict = model_predict_()
    loss = model_loss()
    train_op = model_train_op(loss, mode)       
    predictions = {"predictions": predict}

    return tf.estimator.EstimatorSpec(
        mode = mode,
        predictions = predictions,
        loss = loss,
        train_op = train_op,
    )

def experiment_fn(run_config, hparams):
    estimator = tf.estimator.Estimator(
        model_fn = model_fn, 
        config = run_config,
        params = hparams
    )

    return learn.Experiment(
      estimator = estimator,
      train_input_fn = train_input_fn,
      eval_input_fn = eval_input_fn,
      eval_metrics = None,
      train_steps = 1000,
    )

ex = learn_runner.run(
        experiment_fn = experiment_fn,
        run_config = run_config,
        schedule = "train_and_evaluate",
        hparams =  hparams
)

输出如下:

INFO:tensorflow:将401的检查点保存到 . \ model.ckpt中 . INFO:tensorflow:global_step / sec:0.157117 INFO:tensorflow:step = 401,loss = 2.95048(636.468 sec)INFO:tensorflow:在2017-09-05-20:06:07开始评估INFO:tensorflow:从中恢复参数 . \ model.ckpt-401 INFO:tensorflow:Evaluation [1/1] INFO:tensorflow:在2017-09-05-20:06:09完成评估信息:tensorflow:为全局步骤401保存dict:global_step = 401,损失= 7.20411 INFO:tensorflow:验证(步骤401):global_step = 401,loss = 7.20411 INFO:tensorflow:training loss = 2.95048,step = 401(315.393 sec)INFO:tensorflow:将451的检查点保存到 . \ model.ckpt中 . 信息:tensorflow:在2017-09-05-20:11:32开始评估INFO:tensorflow:从 . \ model.ckpt-451恢复参数信息:tensorflow:评估[1/1]

你会看到每次它保存最后一个模型,这不一定是最好的 .

1 回答

  • 2

    为您的训练过程中断的事件保存检查点 . 如果您没有检查点,则需要从头开始重新启动 . 对于需要数周训练的大型车型来说,这是一个大问题 .

    一旦您的训练完成并且您对模型感到满意(用您的话说,"it is the best"),您可以使用https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_savedmodel明确保存它 . 调用此方法位于您用于创建 ExperiemntEstimator 上 . 请注意,此方法保存"model for inference",这意味着将从其中删除所有渐变操作并且不保存 .

    EDIT: In Reply to Nicolas's comment: 您可以使用创建估算器时传递的keep_checkpoint_every_n_hours选项来定期保存快照以及最新的快照 . 如果您发现您的模型在10小时前达到了最佳性能,那么您应该可以从大致那个时间找到快照 .

相关问题