首页 文章

正在覆盖Tensorflow检查点

提问于
浏览
0

我正在使用Tensorflow在输入集上训练模型(生成对抗网络),我想每50个时期保存模型的参数 .

假设我想训练1000个时期的模型,并且每50个时期保存模型的参数,最终会有20个不同的检查点文件 .

通过拥有一个Session和一个Saver对象,我只需使用以下代码即可 .

if num_epoch % 50 == 0:
    saver.save(sess=sess, path='RGAN-1/sv/' + type_exp, global_step=num_epoch)

问题是,检查点被覆盖了,在实验结束时,我只有最后6个检查点,而我应该有20个检查点 .

我不知道为什么会这样 .

1 回答

  • 2

    tf.train.Savermax_to_keep 参数默认设置为5 . 您可以传递0以保留所有检查点:

    saver = tf.train.Saver(..., max_to_keep=0)
    

    有关完整参数列表,请参见the docs .

相关问题