
Where does TensorFlow store restored checkpoint values?

_ = importer.import_graph_def(input_graph_def, name='')

with session.Session() as sess:
    if input_saver_def:
        saver = saver_lib.Saver(saver_def=input_saver_def)
        saver.restore(sess, input_checkpoint)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(','),
        variable_names_blacklist=variable_names_blacklist)

In the code above, the importer loads the GraphDef into the current default graph, and the saver restores previously trained values. My question is: where are these trained values actually stored? In the session, in input_graph_def, in the current graph structure (tf.get_default_graph()), or in the saver?
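To make the distinction concrete, here is a minimal plain-Python sketch (no TensorFlow; all classes and names are illustrative, not real TF internals) of the split between a graph, which only describes structure, and a session, which owns the actual variable values:

```python
class Graph:
    """Structure only: records which variables exist, holds no values."""
    def __init__(self):
        self.variables = []  # names of variables defined in the graph

    def get_variable(self, name):
        self.variables.append(name)
        return name

class Session:
    """Owns the runtime state: one value slot per graph variable."""
    def __init__(self, graph):
        self.graph = graph
        self.values = {}  # variable name -> current value

    def run(self, var_name):
        # Reading a variable looks it up in THIS session's state.
        return self.values[var_name]

def restore(session, checkpoint):
    # A restore is just an assignment into the session's state.
    for name, value in checkpoint.items():
        session.values[name] = value

graph = Graph()
a = graph.get_variable("a")

# Two sessions over the SAME graph can hold DIFFERENT values:
sess1 = Session(graph)
sess2 = Session(graph)
restore(sess1, {"a": 0.5})
restore(sess2, {"a": -1.0})
print(sess1.run(a))  # 0.5
print(sess2.run(a))  # -1.0
```

The point of the sketch: the GraphDef and tf.get_default_graph() hold only the structure, and the saver is only the save/restore mechanism; the trained values live in the session's runtime state, which is why sess.run is what reads them out.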

I checked the code of the convert_variables_to_constants method: https://github.com/tensorflow/tensorflow/blob/235192d47cfb375c0cc93c1deefb9e440715bf35/tensorflow/python/framework/graph_util_impl.py

It uses sess.run(variable_name) to fetch the loaded values. Where does this sess.run get the values from?
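sess.run fetches each variable's current value from the session's runtime state. A rough plain-Python sketch of what convert_variables_to_constants does with those values (a toy model for illustration, not the real implementation):

```python
def convert_variables_to_constants(sess_values, graph_nodes):
    """Replace every variable node with a constant holding the value the
    session currently has for it (a toy model of graph freezing).

    sess_values: dict name -> value, standing in for sess.run lookups
    graph_nodes: list of (kind, name) pairs, standing in for a GraphDef
    """
    frozen = []
    for kind, name in graph_nodes:
        if kind == "Variable":
            # The real code calls sess.run here: the value comes
            # from the session, not from the GraphDef.
            value = sess_values[name]
            frozen.append(("Const", name, value))
        else:
            frozen.append((kind, name, None))
    return frozen

nodes = [("Placeholder", "x"), ("Variable", "w"), ("MatMul", "y")]
print(convert_variables_to_constants({"w": 2.0}, nodes))
# [('Placeholder', 'x', None), ('Const', 'w', 2.0), ('MatMul', 'y', None)]
```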

1 Answer


    When we define a saver, we should pass it the variables to save (by default it takes all global variables).

    In [2]: import tensorflow as tf
    In [3]: a = tf.get_variable("a", [])
    In [4]: b = tf.get_variable("b", [])
    In [5]: saver_a = tf.train.Saver({"my_a_in_ckpt": a}) # here "my_a_in_ckpt" can be any apt name you like, it is the variable name stored only in the checkpoint (1)
    In [6]: init = tf.global_variables_initializer()
    In [7]: sess = tf.Session()
    In [8]: sess.run(init)
    In [9]: sess.run(a)
    Out[9]: 0.43891537
    In [10]: sess.run(b)
    Out[10]: 1.5962805
    In [11]: saver_a.save(sess, "./temp_model/temp")
    

    Here we first initialize all the variables and save a to a checkpoint. To restore the variable:

    In [2]: import tensorflow as tf
    In [3]: a = tf.get_variable("a", [])
    In [5]: saver_a = tf.train.Saver({"my_a_in_ckpt": a})  # here "my_a_in_ckpt" should match that as you defined in step (1)
    In [7]: sess = tf.Session()
    In [9]: saver_a.restore(sess, tf.train.latest_checkpoint("./temp_model"))
    INFO:tensorflow:Restoring parameters from ./temp_model/temp
    In [10]: sess.run(a)
    Out[10]: 0.43891537
    
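The name passed in the Saver's dict ("my_a_in_ckpt" above) is the key under which the value is stored in the checkpoint; the Python variable name is irrelevant. A plain-Python sketch of this name mapping (illustrative only, not the real checkpoint format):

```python
def save(var_values, name_map):
    """Build a checkpoint: keys are the names chosen in the Saver's dict."""
    return {ckpt_name: var_values[var] for ckpt_name, var in name_map.items()}

def restore(checkpoint, name_map):
    """Look each variable up under the SAME checkpoint name used at save time."""
    return {var: checkpoint[ckpt_name] for ckpt_name, var in name_map.items()}

values = {"a": 0.43891537}
ckpt = save(values, {"my_a_in_ckpt": "a"})    # like tf.train.Saver({"my_a_in_ckpt": a})
assert "my_a_in_ckpt" in ckpt                  # stored under the chosen name, not "a"
restored = restore(ckpt, {"my_a_in_ckpt": "a"})
print(restored)  # {'a': 0.43891537}
```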

    We can also save a and b to different places (this continues the first session, where b was initialized):

    In [12]: saver_b = tf.train.Saver({"b": b})
    In [13]: saver_b.save(sess, "./temp_model_b/temp")
    Out[13]: './temp_model_b/temp'
    

    And restore both of them into a new graph:

    In [2]: import tensorflow as tf
    In [3]: a = tf.get_variable("a", [])
    In [4]: b = tf.get_variable("b", [])
    In [5]: saver_b = tf.train.Saver({"b": b})
    In [6]: saver_a = tf.train.Saver({"my_a_in_ckpt": a})
    In [7]: sess = tf.Session()
    In [8]: saver_a.restore(sess, tf.train.latest_checkpoint("./temp_model"))
    INFO:tensorflow:Restoring parameters from ./temp_model/temp
    In [9]: saver_b.restore(sess, tf.train.latest_checkpoint("./temp_model_b"))
    INFO:tensorflow:Restoring parameters from ./temp_model_b/temp
    In [10]: sess.run(a)
    Out[10]: 0.43891537
    In [11]: sess.run(b)
    Out[11]: 1.5962805
    
