首页 文章

这是在训练期间重用tf.slim图进行验证的正确方法吗?

提问于
浏览
2

我想在训练期间的每个时期之后进行验证 .

我正在创建如下图:

import tensorflow as tf
from networks import densenet
from networks.densenet_utils import dense_arg_scope

with tf.variable_scope('scope') as scope:
    with slim.arg_scope(dense_arg_scope()):
        logits_train, _ = densenet(images, blocks=networks[
            'densenet_265'], num_classes=1000, data_name='imagenet', is_training=True, scope='densenet265',
                                 reuse=tf.AUTO_REUSE)
    scope.reuse_variables()
    with slim.arg_scope(dense_arg_scope()):
        logits_val, _ = densenet(images, blocks=networks[
            'densenet_265'], num_classes=1000, data_name='imagenet', is_training=False, scope='densenet265',
                                 reuse=tf.AUTO_REUSE)

为了在培训或验证期间获得 logits ,我会执行以下操作:

is_training = tf.Variable(True, trainable=False, dtype=tf.bool)
training_mode = tf.assign(is_training, True)
validation_mode = tf.assign(is_training, False)
logits = tf.cond(tf.equal(is_training, tf.constant(True, dtype=tf.bool)), lambda: logits_train,
                     lambda: logits_val)

但是,当我运行我的代码时,我收到OOM错误 . 我确信这不是因为批量大 . 这是因为,之前我犯了一个大错,并且在训练和验证过程中使用了相同的图表 . 当时批量大小为 32 且图像大小为 224x224x3 ,代码运行得非常好 .

我怀疑在使用 is_training=False 验证期间尝试重用图表时我犯了一些错误 .

densenet的代码取自以下两个文件:densenet_utils.py densenet.py

1 回答

  • 2

    您在logits_train和logits_val中创建了两个独立的网络,因此这会占用网络占用的内存的两倍 . (我假设它已正确设置并且变量正确共享,这可能是另一个问题,但这不会导致OOM,大数据是激活,而不是权重 . )

    没有必要这样做 . 使用相同的网络 logits_train 进行验证 . 原来参数 is_training 也可以采用布尔标量张量,因此您可以动态切换训练或推理模式 .

    因此,在您设置 images 占位符的位置,请将此行作为下一行:

    training_mode = tf.placeholder( shape = None, dtype = tf.bool )
    

    然后在上面的代码中,像这样设置你的网络:

    logits_train, _ = densenet(images, blocks=networks['densenet_265'],
        num_classes=1000, data_name='imagenet', is_training=training_mode,
        scope='densenet265', reuse=tf.AUTO_REUSE)
    

    请注意, is_training 参数的值填充了上面的张量 training_mode

    然后当你执行 sess.run( [ ... ] ) 命令(在上面的代码中不可见)时,你应该在_2512587中包含 training_mode ,如此(伪代码):

    result = sess.run( [ ??? ], feed_dict = { images : ???, training_mode : True / False } )
    

    请注意, training_mode 张量现在根据您是否正在进行培训填充False或True .

    这是基于我对 batch_normalizationdropout 层的研究 .

相关问题