首页 文章

张量流量保护器可以在具有相同结构的不同图形中使用

提问于
浏览
1

网络结构已加载到默认全局图中 . 我想创建另一个具有相同结构的图形,并在此图中加载检查点 .

如果代码是这样的,它会抛出错误: ValueError: No variables to save 在最后一行 . 但是,第二行工作正常 . 为什么? as_graph_def() 返回的 GraphDef 是否包含变量定义/名称?

inference_graph_def = tf.get_default_graph().as_graph_def()
saver = tf.train.Saver()
with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def)
    saver1 = tf.train.Saver()

如果这样的代码,它会在最后一行抛出错误 Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist . 但是,第三行删除后,它可以正常工作 .

inference_graph_def = tf.get_default_graph().as_graph_def()
saver = tf.train.Saver()
with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def)    
    with session.Session() as sess:        
        saver.restore(sess, checkpoint_path)

那么,这是否意味着Saver无法在不同的图形中工作,即使它们具有相同的结构?

任何帮助将不胜感激〜

1 回答

  • 3

    下面是一个使用 MetaGraphDef 的示例,它与 GraphDef 保存变量集合不同,使用以前保存的图形初始化新图形 .

    import tensorflow as tf
    
    CHECKPOINT_PATH = "/tmp/first_graph_checkpoint"
    with tf.Graph().as_default():
      some_variable = tf.get_variable(
        name="some_variable",
        shape=[2],
        dtype=tf.float32)
      init_op = tf.global_variables_initializer()
      first_meta_graph = tf.train.export_meta_graph()
      first_graph_saver = tf.train.Saver()
      with tf.Session() as session:
        init_op.run()
        print("Initialized value in first graph", some_variable.eval())
        first_graph_saver.save(
            sess=session,
            save_path=CHECKPOINT_PATH)
    
    with tf.Graph().as_default():
      tf.train.import_meta_graph(first_meta_graph)
      second_graph_saver = tf.train.Saver()
      with tf.Session() as session:
        second_graph_saver.restore(
          sess=session,
          save_path=CHECKPOINT_PATH)
        print("Variable value after restore", tf.global_variables()[0].eval())
    

    打印类似于:

    Initialized value in first graph [-0.98926258 -0.09709156]
    Variable value after restore [-0.98926258 -0.09709156]
    

    请注意,检查点仍然很重要!加载 MetaGraph 不会恢复 Variables 的值(它不包含这些值),只是跟踪它们存在的记录(集合) . SavedModel format解决此问题,将 MetaGraph 与检查点和其他元数据捆绑在一起运行它们 .

    编辑:根据大众需求,这是一个用 GraphDef 做同样事情的例子 . 我不推荐它 . 由于在加载 GraphDef 时没有恢复任何集合,我们必须手动指定 Variables 我们希望 Saver 恢复; "import/"默认命名方案很容易修复 name=''name='' 参数,但如果你想让 Saver 工作"automatically",则需要手动填写变量集合 . 相反,我选择在创建 Saver 时手动指定映射 .

    import tensorflow as tf
    
    CHECKPOINT_PATH = "/tmp/first_graph_checkpoint"
    with tf.Graph().as_default():
      some_variable = tf.get_variable(
        name="some_variable",
        shape=[2],
        dtype=tf.float32)
      init_op = tf.global_variables_initializer()
      first_graph_def = tf.get_default_graph().as_graph_def()
      first_graph_saver = tf.train.Saver()
      with tf.Session() as session:
        init_op.run()
        print("Initialized value in first graph", some_variable.eval())
        first_graph_saver.save(
            sess=session,
            save_path=CHECKPOINT_PATH)
    
    with tf.Graph().as_default():
      tf.import_graph_def(first_graph_def)
      variable_to_restore = tf.get_default_graph().get_tensor_by_name(
          "import/some_variable:0")
      second_graph_saver = tf.train.Saver(var_list={
          "some_variable": variable_to_restore
      })
      with tf.Session() as session:
        second_graph_saver.restore(
          sess=session,
          save_path=CHECKPOINT_PATH)
        print("Variable value after restore", variable_to_restore.eval())
    

相关问题