首页 文章

使用张量流中的n / a标签停止渐变

提问于
浏览
0

我在以下场景中是'm implementing a Convolutional Neural Network in Tensorflow with python. I' m:我有一个标签张量 y (批量标签),如下所示:

y =   [[0,1,0]
       [0,0,1]
       [1,0,0]]

其中每一行都是一个 one-hot 向量,表示与对应示例相关的标签 . 现在在训练中,我希望使用该标签停止 loss gradient (设置为0)(第三个):

[1,0,0]

它表示不适用标签,而是计算批次中其他示例的丢失 . 对于我的损失计算,我使用这样的方法:

self.y_loss = kl_divergence(self.pred_y, self.y)

我发现这个function停止了渐变,但是如何有条理地将它应用于批处理元素呢?

1 回答

  • 2

    如果您不希望某些样本对渐变做出贡献,您可以在训练期间避免将它们送入网络 . 只需从训练集中删除带有该标签的样本即可 .

    或者,由于通过对每个样本的KL-发散进行求和来计算损失,因此如果应该考虑样本,则可以将每个样本的KL-发散乘以1,否则在对它们进行求和之前将其乘以0 . 您可以通过从1减去标签张量的第一列来获得所需的值向量乘以单个KL-发散: 1 - y[:,0]

    对于answer to your previous question中的 kl_divergence 函数,它可能如下所示:

    def kl_divergence(p, q) 
        return tf.reduce_sum(tf.reduce_sum(p * tf.log(p/q), axis=1)*(1-p[:,0]))
    

    其中p是groundtruth张量,q是预测

相关问题