首页 文章

TensorFlow渐变:通过tf.gradients获得不必要的0.0渐变

提问于
浏览
3

我们假设我有以下变量

embeddings = tf.Variable(tf.random_uniform(dtype = tf.float32,shape = [self.vocab_size,self.embedding_dim],minval = -0.001,maxval = 0.001))sent_1 = construct_sentence(word_ids_1)sent_2 = construct_sentence(word_ids_2 )

construct_sentence 是一种基于占位符 word_ids_1word_ids_2 获取句子表示的方法

让我们假设我有一些损失:

loss = construct_loss(sent_1,sent_2,label)

现在,当我尝试使用渐变时:

gradients_wrt_w = tf.gradients(丢失,嵌入)

我没有获得关于 construct_sentenceconstruct_loss 中涉及的特定变量的渐变,而是获得变量 embeddings 中每个嵌入的渐变(对于那些不涉及丢失和句子表示的嵌入,渐变为0) .

How can I get the gradients wrt variables I am only interested in?

此外,由于涉及的偏导数,我得到一些变量(具有相同的值)的重复 . 由于嵌入是一个2D变量我不能像这样做一个简单的查找:

tf.gradients(loss,tf.nn.embedding_lookup(embeddings,word_ids))

这引入了巨大的性能减慢,因为我正在处理大量的字嵌入,并且我希望每次只使用一些字嵌入 .

此外,我得到了很多重复的渐变(因为偏导数),我尝试使用tf.AggregationMethod,但没有成功 .

1 回答

  • 0

    你不能做 tf.gradients(loss, tf.nn.embedding_lookup(embeddings, word_ids)) ,但你可以直接做 tf.gradients(loss, embeddings) ,这会给你一个只包含受影响的单词id的渐变的 tf.IndexedSlices 对象 .

    关于与重复单词id对应的渐变的聚合,这在调用 optimizer.apply_gradients 时自动完成,但您可以使用 tf.unsorted_segment_sumtf.unique 重现此操作,如下所示:

    embedding_table = tf.random_uniform((10, 5))
    word_ids = tf.placeholder(shape=(None), dtype=tf.int32)
    temp_emb = tf.nn.embedding_lookup(embedding_table, word_ids)
    loss = tf.reduce_sum(temp_emb, axis=0)
    
    g = tf.gradients(loss, embedding_table)[0].values
    repeating_indices = tf.gradients(loss, embedding_table)[0].indices # This is the same as word_ids.
    
    unique_indices, idx_in_repeating_indices = tf.unique(repeating_indices)
    
    agg_gradients = tf.unsorted_segment_sum(g,
                                            idx_in_repeating_indices,
                                            tf.shape(unique_indices)[0])
    
    sess = tf.Session()
    unique_indices_v, agg_gradients_v, _ = \
        sess.run([unique_indices, agg_gradients, loss],
             feed_dict={word_ids: np.array([6, 1, 5, 1, 1, 5])})
    
    
    print(unique_indices_v)
    print(agg_gradients_v)
    

    给出上面的例子:

    [6 1 5]
    
    [[1. 1. 1. 1. 1.]
     [3. 3. 3. 3. 3.]
     [2. 2. 2. 2. 2.]]
    

相关问题