首页 文章

TensorFlow从文件保存到/加载图形

提问于
浏览
76

从我到目前为止收集的内容来看,有几种不同的方法可以将TensorFlow图转储到文件中,然后将其加载到另一个程序中,但我无法找到有关它们如何工作的明确示例/信息 . 我已经知道的是:

  • 使用 tf.train.Saver() 将模型的变量保存到检查点文件(.ckpt)中,稍后恢复(source

  • 将模型保存到.pb文件中并使用 tf.train.write_graph()tf.import_graph_def()source)将其加载回来

  • 从.pb文件加载模型,重新训练并使用Bazel将其转储到新的.pb文件中(source

  • 冻结图形以将图形和权重保存在一起(source

  • 使用 as_graph_def() 保存模型,对于权重/变量,将它们映射到常量(source

但是,我无法澄清有关这些不同方法的几个问题:

  • 关于检查点文件,它们是否只保存模型的训练权重?可以将检查点文件加载到新程序中,并用于运行模型,还是仅仅用作在特定时间/阶段将权重保存在模型中的方法?

  • 关于 tf.train.write_graph() ,是否也保存了权重/变量?

  • 关于Bazel,它是否只能从.pb文件中保存/加载以进行再培训?是否有一个简单的Bazel命令只是将图形转储到.pb?

  • 关于冻结,可以使用 tf.import_graph_def() 加载冻结图吗?

  • TensorFlow的Android演示从.pb文件加载Google的Inception模型 . 如果我想替换我自己的.pb文件,我该怎么做呢?我需要更改任何本机代码/方法吗?

  • 一般来说,所有这些方法之间究竟有什么区别?或者更广泛地说, as_graph_def() /.ckpt/.pb之间有什么区别?

简而言之,我正在寻找的方法是将图形(如,各种操作等)及其权重/变量保存到文件中,然后可以将其用于将图形和权重加载到另一个程序中,使用(不一定继续/再培训) .

关于这个主题的文档不是很简单,所以任何答案/信息将不胜感激 .

2 回答

  • 63

    有很多方法可以解决在TensorFlow中保存模型的问题,这可能会让它有点混乱 . 依次提出每个子问题:

    • 检查点文件(例如通过在tf.train.Saver对象上调用saver.save()生成)仅包含权重,以及在同一程序中定义的任何其他变量 . 要在另一个程序中使用它们,您必须重新创建关联的图形结构(例如,通过运行代码再次构建它,或调用tf.import_graph_def()),它告诉TensorFlow如何处理这些权重 . 请注意,调用 saver.save() 还会生成一个包含MetaGraphDef的文件,该文件包含一个图表以及如何将检查点的权重与该图表相关联的详细信息 . 有关详细信息,请参阅the tutorial .

    • tf.train.write_graph()只写图结构;不是重量 .

    • Bazel与读取或写入TensorFlow图无关 . (也许我误解了你的问题:随意在评论中澄清它 . )

    • 可以使用tf.import_graph_def()加载冻结图 . 在这种情况下,权重(通常)嵌入图表中,因此您无需加载单独的检查点 .

    • 主要的变化是更新输入到模型中的张量的名称,以及从模型中提取的张量的名称 . 在TensorFlow Android演示中,这将对应于传递给TensorFlowClassifier.initializeTensorFlow()inputNameoutputName 字符串 .

    • GraphDef 是程序结构,通常不会通过培训过程发生变化 . 检查点是培训过程状态的快照,通常在培训过程的每个步骤都会发生变化 . 因此,TensorFlow为这些类型的数据使用不同的存储格式,而低级API提供了不同的方法来保存和加载它们 . 更高级的库,例如MetaGraphDef库,Kerasskflow构建在这些机制上,以提供更方便的方法来保存和恢复整个模型 .

  • 1

    您可以尝试以下代码:

    with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def, name="")
    sess = tf.Session(graph=g_in)
    

相关问题