
Adding an accuracy summary for the whole train/test dataset in TensorFlow

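I am trying to use TensorBoard to visualize my training procedure. My goal is that, when each epoch completes, I test the network's accuracy on the whole validation dataset and store the result in a summary file so that I can visualize it in TensorBoard.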

I know TensorFlow has summary_op for this, but it seems to work only on the single batch fed to sess.run(summary_op). I need to compute the accuracy over the whole dataset. How can I do that?

Is there an example that does this?

3 Answers

  • 8

    Define a tf.scalar_summary that accepts a placeholder:

    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    accuracy_summary = tf.scalar_summary('accuracy', accuracy_value_)
    

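    Then compute the accuracy over the whole dataset (define a routine that computes the accuracy for each batch in the dataset and takes the average) and save it in a Python variable; let's call it va.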

    Once you have the value of va, just run the accuracy_summary op, feeding the accuracy_value_ placeholder:

    sess.run(accuracy_summary, feed_dict={accuracy_value_: va})
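
The batch-averaging routine described above can be sketched in plain Python. This is a minimal sketch, not the answerer's code: `full_dataset_accuracy`, `predict_batch` and `toy_predict` are hypothetical names, and `predict_batch` stands in for whatever call (for example a `sess.run` on a prediction op) returns predicted labels for one batch. It counts correct predictions across batches rather than averaging per-batch accuracies, so a smaller final batch does not skew the result.

```python
def full_dataset_accuracy(predict_batch, inputs, labels, batch_size):
    """Accuracy over the whole dataset, evaluated one batch at a time.

    predict_batch is a hypothetical callable that maps a list of inputs
    to a list of predicted labels (a stand-in for a sess.run call).
    """
    if len(inputs) != len(labels):
        raise ValueError("inputs and labels must have the same length")
    if not inputs:
        raise ValueError("dataset is empty")

    correct = 0
    for start in range(0, len(inputs), batch_size):
        batch_x = inputs[start:start + batch_size]
        batch_y = labels[start:start + batch_size]
        predictions = predict_batch(batch_x)
        # Count matches instead of averaging per-batch accuracies, so the
        # last (possibly smaller) batch is weighted correctly.
        correct += sum(int(p == t) for p, t in zip(predictions, batch_y))
    return correct / float(len(inputs))


# Toy stand-in model: predicts 1 for positive inputs, 0 otherwise.
def toy_predict(batch):
    return [1 if x > 0 else 0 for x in batch]


va = full_dataset_accuracy(toy_predict, [3, -1, 2, -5, 4], [1, 0, 0, 0, 1],
                           batch_size=2)
print(va)
```

The resulting `va` is the value that would then be fed to the `accuracy_value_` placeholder.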
    
  • 0

    I implemented a naive one-layer model as an example that classifies the MNIST dataset and visualizes the validation accuracy in TensorBoard; it works for me.

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
    import os
    
    # number of training steps (each iteration below processes one mini-batch, not a full epoch)
    num_epoch = 1000
    model_dir = '/tmp/tf/onelayer_model/accu_info'
    # mnist dataset location, change if you need
    data_dir = '../data/mnist'
    
    # load MNIST dataset without one hot
    dataset = read_data_sets(data_dir, one_hot=False)
    
    # Create placeholder for input images X and labels y
    X = tf.placeholder(tf.float32, [None, 784])
    # labels are integer class indices because one_hot=False
    y = tf.placeholder(tf.int32)
    
    # One layer model graph
    W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))
    b = tf.Variable(tf.constant(0.1, shape=[10]))
    logits = tf.nn.relu(tf.matmul(X, W) + b)
    
    init = tf.initialize_all_variables()
    
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y)
    # loss function
    loss = tf.reduce_mean(cross_entropy)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    
    _, top_1_op = tf.nn.top_k(logits)
    top_1 = tf.reshape(top_1_op, shape=[-1])
    correct_classification = tf.cast(tf.equal(top_1, y), tf.float32)
    # accuracy function
    acc = tf.reduce_mean(correct_classification)
    
    # define the summary ops that are written by the SummaryWriter
    acc_summary = tf.scalar_summary('valid_accuracy', acc)
    valid_summary_op = tf.merge_summary([acc_summary])
    
    with tf.Session() as sess:
        # initialize all the variables
        sess.run(init)
    
        print("Writing Summaries to %s" % model_dir)
        train_summary_writer = tf.train.SummaryWriter(model_dir, sess.graph)
    
        # load validation dataset
        valid_x = dataset.validation.images
        valid_y = dataset.validation.labels
    
        for epoch in range(num_epoch):
            batch_x, batch_y = dataset.train.next_batch(100)
            feed_dict = {X: batch_x, y: batch_y}
            _, acc_value, loss_value = sess.run(
                [train_op, acc, loss], feed_dict=feed_dict)
            vsummary = sess.run(valid_summary_op,
                                feed_dict={X: valid_x,
                                           y: valid_y})
    
            # Write validation accuracy summary
            train_summary_writer.add_summary(vsummary, epoch)
    
  • 0

    If you use tf.metrics ops, which keep internal counters, you can process your validation set in batches. Here is a simplified example:

    # create_model(), labels, predictions, logs_path and gstep_op are assumed to be defined elsewhere
    model = create_model()
    tf.summary.scalar('cost', model.cost_op)
    acc_value_op, acc_update_op = tf.metrics.accuracy(labels, predictions)
    
    summary_common = tf.summary.merge_all()
    
    summary_valid = tf.summary.merge([
        tf.summary.scalar('accuracy', acc_value_op),
        # other metrics here...
    ])
    
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(logs_path + '/train',
                                             sess.graph)
        valid_writer = tf.summary.FileWriter(logs_path + '/valid')
    

    While training, write only the common summary, using your train writer:

    summary = sess.run(summary_common)
    train_writer.add_summary(summary, tf.train.global_step(sess, gstep_op))
    train_writer.flush()
    

    After every validation, write both summaries using the valid writer:

    gstep, summaryc, summaryv = sess.run([gstep_op, summary_common, summary_valid])
    valid_writer.add_summary(summaryc, gstep)
    valid_writer.add_summary(summaryv, gstep)
    valid_writer.flush()
    

    When using tf.metrics, don't forget to reset the internal counters (local variables) before every validation step.
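
To illustrate why the reset matters, here is a minimal plain-Python sketch of a streaming accuracy metric. It is not TensorFlow code: the class `StreamingAccuracy` and its methods are hypothetical and only mimic the idea behind tf.metrics.accuracy, which accumulates a running total and count in local variables. In TensorFlow 1.x, one common way to reset such counters is to run tf.local_variables_initializer(), with the caveat that this reinitializes every local variable in the graph.

```python
class StreamingAccuracy(object):
    """Hypothetical stand-in for a metric with internal counters."""

    def __init__(self):
        self.total = 0.0  # number of correct predictions seen so far
        self.count = 0.0  # number of examples seen so far

    def update(self, labels, predictions):
        """Accumulate one batch (plays the role of the update op)."""
        self.total += sum(1.0 for t, p in zip(labels, predictions) if t == p)
        self.count += len(labels)

    def result(self):
        """Accuracy over everything seen since the last reset (the value op)."""
        return self.total / self.count if self.count else 0.0

    def reset(self):
        """Zero the counters (the role of re-initializing local variables)."""
        self.total = 0.0
        self.count = 0.0


metric = StreamingAccuracy()

# First validation pass: two batches, 3 of 4 predictions correct.
metric.update([1, 0], [1, 0])
metric.update([1, 1], [1, 0])
first_pass = metric.result()

# Without a reset, the next pass mixes in the counts from the first pass.
metric.update([0, 0], [0, 0])
stale = metric.result()

# With a reset, the second pass reflects only its own batches.
metric.reset()
metric.update([0, 0], [0, 0])
second_pass = metric.result()

print(first_pass, stale, second_pass)
```

In this sketch the value computed without a reset blends both passes, which is the same effect that stale local variables would have on the accuracy summary written after each validation.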
