首页 文章

MNIST Tensorflow:如何将形式[i]的张量操纵为形式[... 0,0,0,1,0,0 ...]的张量,其中1位于第i位置?

提问于
浏览
0

我想转换表单的张量(称之为logits)

int32 - [batch_size]

到形式的张量(称之为标签)

[batch_size, 10]

例如,对于batch_size = 3

logits=[1,6,9]
labels=[[0,1,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,0,0,1]]

出现这个问题是因为我想在tensorflow mnist示例中将成本函数更改为二次函数(https://github.com/tensorflow/tensorflow/tree/r0.9/tensorflow/examples/tutorials/mnist)我使用了fully_connected_feed.py和mnist.py . 在mnist.py中我想改变:

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='xentropy')
    loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

loss= tf.reduce_sum(tf.squared_difference(logits,labels))

但问题在于:

Logits tensor, float - [batch_size, 10];  
Labels tensor, int64 - [batch_size].

所以我需要“矢量化”标签!?有谁知道如何做到这一点?

1 回答

  • 1

    标签“矢量化”称为单热编码 .

    你正在寻找tf.one_hot功能 .

    这个功能需要:

    • 索引列表(您的 logits 向量)

    • A depth 参数:这是单热矢量的深度(单热编码标签的长度)

    • on_valueoff_value 您可以根据需要进行更改(但默认值为1和0是您要查找的内容) .

    • dtype 那是张量输出类型 .

    因此,您可以使用以下方法对标签进行单热编码:

    one_hot_labels = tf.one_hot(logits, 10, dtype=tf.uint8)
    

    one_hot_labelstf.Tensor 对象 .

    如果您需要从python访问其内容,请记住eval(或运行它) .

    这是一个玩具示例:

    import tensorflow as tf.
    tf.InteractiveSession()
    logits=[1,6,9]
    one_hot_labels = tf.one_hot(logits, 10, dtype=tf.uint8)
    print(one_hot_labels.eval())
    

    输出:

    [[0 1 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 0 0 0]
     [0 0 0 0 0 0 0 0 0 1]]
    

相关问题