首页 文章

Tensorflow Estimator:缓存瓶颈

提问于
浏览
13

在遵循张量流图像分类教程时,首先它会缓存每个图像的瓶颈:

def: cache_bottlenecks())

我使用tensorflow的 Estimator 重写了训练 . 这真的简化了所有代码 . 但是我想在这里缓存瓶颈功能 .

这是我的 model_fn . 我想缓存 dense 层的结果,这样我就可以对实际的训练进行更改,而不必每次都计算瓶颈 .

我怎么能做到这一点?

def model_fn(features, labels, mode, params):
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    num_classes = len(params['label_vocab'])

    module = hub.Module(params['module_spec'], trainable=is_training and params['train_module'])
    bottleneck_tensor = module(features['image'])

    with tf.name_scope('final_retrain_ops'):
        logits = tf.layers.dense(bottleneck_tensor, units=num_classes, trainable=is_training)  # save this?

    def train_op_fn(loss):
        optimizer = tf.train.AdamOptimizer()
        return optimizer.minimize(loss, global_step=tf.train.get_global_step())

    head = tf.contrib.estimator.multi_class_head(n_classes=num_classes, label_vocabulary=params['label_vocab'])

    return head.create_estimator_spec(
        features, mode, logits, labels, train_op_fn=train_op_fn
    )

1 回答

  • 0

    TF无法像你编码一样工作 . 你应该:

    • 从原始网络导出文件瓶颈 .

    • 使用瓶颈结果作为输入,使用另一个网络来训练您的数据 .

相关问题