首页 文章

如何在前进时从tensorflow slim模型VGG中提取特征?

提问于
浏览
1

我使用CASIA(人脸识别数据集)作为训练数据集,使用TensorFlow slim模型vgg训练了一个分类模型 . 我想通过使用LFW数据集来测试模型,它是一个面部匹配任务 . 所以我需要提取像fc7 / fc8这样的网络特征,而不是softmax图层,并比较特征之间的距离,以确定它们是否是同一个人 . 如何提取超薄型号的功能?

这是培训代码的一部分 .

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import vgg 
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

def tower_loss(scope):
    images, labels = read_and_decode()
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, end_points = vgg.vgg_16(images, num_classes=FLAGS.num_classes)
    _ = cal_loss(logits, labels)
    losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(losses, name='total_loss')
    return total_loss

2 回答

  • 1

    您可以尝试使用 tf.get_default_graph().get_tensor_by_name("VGG16/fc16:0") 或要提取的特定要素的张量名称 .

    要验证要提取的张量的名称,您可以尝试

    for operation in graph.get_operations():
        print operation.values()
    

    请记住将 :0 放在名称的末尾,因为它们表示您正在检索的项目是张量 .

  • 0

    获取苗条模型的 end_points 并提取该功能 .

相关问题