首页 文章

如何使用tfslim记录验证丢失和准确性

提问于
浏览
1

有什么方法可以在使用tf-slim时将验证丢失和准确性记录到tensorboard?当我使用keras时,以下代码可以为我执行此操作:

model.fit_generator(generator=train_gen(), validation_data=valid_gen(),...)

然后模型将评估每个时期后的验证损失和准确性,这非常方便 . 但是如何使用tf-slim实现这一目标?以下步骤使用原始张量流,这不是我想要的:

with tf.Session() as sess:
    for step in range(100000):
        sess.run(train_op, feed_dict={X: X_train, y: y_train})
        if n % batch_size * batches_per_epoch == 0:
            print(sess.run(train_op, feed_dict={X: X_train, y: y_train}))

现在,使用tf-slim训练模型的步骤是:

tf.contrib.slim.learning.train(
    train_op=train_op,
    logdir="logs",
    number_of_steps=10000,
    log_every_n_steps = 10,
    save_summaries_secs=1
)

那么如何使用上述细长的培训程序评估每个时代后的验证损失和准确性?

提前致谢!

1 回答

  • 1

    关于TF Slim回购(issue #5987)的问题仍在讨论中 . 该框架允许您轻松创建评估脚本,以便在培训之后/之后运行(下面的解决方案1),但有些人正在努力实现"classic cycle of batch training + validation"(解决方案2) .


    1.在另一个脚本中使用slim.evaluation

    TF Slim具有评估方法,例如 slim.evaluation.evaluation_loop() 您可以在另一个脚本(可以与您的训练并行运行)中使用,以定期加载模型的最新检查点并执行评估 . TF Slim页面包含一个很好的例子,这个脚本可能看起来如何:example .

    2.为slim.learning.train()提供自定义train_step_fn

    讨论的发起者提出的一个不完整的解决方案利用了您可以提供的自定义训练步骤功能 slim.learning.train()

    """
    Snippet from code by Kevin Malakoff @kmalakoff
    https://github.com/tensorflow/tensorflow/issues/5987#issue-192626454
    """
    # ...
    accuracy_validation = slim.metrics.accuracy(
        tf.argmax(predictions_validation, 1), 
        tf.argmax(labels_validation, 1)) # ... or whatever metrics needed
    
    def train_step_fn(session, *args, **kwargs):
      total_loss, should_stop = train_step(session, *args, **kwargs)
    
      if train_step_fn.step % FLAGS.validation_check == 0:
        accuracy = session.run(train_step_fn.accuracy_validation)
        print('Step %s - Loss: %.2f Accuracy: %.2f%%' % (str(train_step_fn.step).rjust(6, '0'), total_loss, accuracy * 100))
    
      # ...
    
      train_step_fn.step += 1
      return [total_loss, should_stop]
    
    train_step_fn.step = 0
    train_step_fn.accuracy_validation = accuracy_validation
    
    slim.learning.train(
      train_op,
      FLAGS.logs_dir,
      train_step_fn=train_step_fn,
      graph=graph,
      number_of_steps=FLAGS.max_steps
    )
    

相关问题