首页 文章

如何在Tensorflow中使RNN单元的权重无法处理?

提问于
浏览
5

我正在尝试制作一个Tensorflow图,其中图的一部分已经预先训练并在预测模式下运行,而其余的训练 . 我已经定义了我预先训练好的细胞:

rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)

state0 = tf.Variable(pretrained_state0,trainable=False)
state1 = tf.Variable(pretrained_state1,trainable=False)
pretrained_state = [state0, state1]

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell, 
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state = pretrained_state)

将初始变量设置为 trainable=False 没有帮助 . 这些仅用于初始化权重,因此权重仍然会发生变化 .

我仍然需要在训练步骤中运行优化器,因为我的模型的其余部分需要训练 . 但是,如何防止优化器更改此rnn单元格中的权重?

是否有一个rnn_cell相当于 trainable=False

1 回答

  • 2

    您可以使用 tf.stop_gradient() 来防止图表的 pretrained 部分更新其权重,或者您可以使用 optimiser() 来指定应该训练图表的哪些部分 . 第二种方法涉及:

    #Create variable scope for the trainable parts of the graph: tf.variable_scope('train').
    
     # get trainable variables
     t_vars = tf.trainable_variables()
     train_vars = [var for var in t_vars if var.name.startswith('train')]
     # train only the variables of a particular scope
     opt = optimizer.minimize(cost, var_list=train_vars)
    

相关问题