首页 文章

Tensorflow:如何将.meta,.data和.index模型文件转换为一个graph.pb文件

提问于
浏览
13

在tensorflow中,从头开始训练产生以下6个文件:

events.out.tfevents.1503494436.06L7-BRM738 model.ckpt-22480.meta checkpoint model.ckpt-22480.data-00000-of-00001 model.ckpt-22480.index graph.pbtxt

我想将它们(或只需要的)转换成一个文件 graph.pb ,以便能够将它转移到我的Android应用程序 .

我尝试了脚本 freeze_graph.py 但它需要作为输入已经 input.pb 文件,我没有 . (我之前只提到过这6个文件) . 如何获得这个 freezed_graph.pb 文件?我看到几个线程,但没有一个为我工作 .

4 回答

  • 0

    您可以使用此简单脚本来执行此操作 . 但是您必须指定输出节点的名称 .

    import tensorflow as tf
    
    meta_path = 'model.ckpt-22480.meta' # Your .meta file
    
    with tf.Session() as sess:
    
        # Restore the graph
        saver = tf.train.import_meta_graph(meta_path)
    
        # Load weights
        saver.restore(sess,tf.train.latest_checkpoint('.'))
    
        # Output nodes
        output_node_names =[n.name for n in tf.get_default_graph().as_graph_def().node]
    
        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names)
    
        # Save the frozen graph
        with open('output_graph.pb', 'wb') as f:
          f.write(frozen_graph_def.SerializeToString())
    
  • 2

    因为它可能对其他人有帮助,我也在回答github后回答这里;-) . 我想你可以尝试这样的东西(使用tensorflow / python / tools中的freeze_graph脚本):

    python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
    

    这里的重要标志是--input_binary = false,因为文件graph.pbtxt是文本格式 . 我认为它对应于所需的graph.pb,它是二进制格式的等价物 .

    关于output_node_names,这对我来说真的很困惑,因为我在这部分仍然有一些问题,但你可以在tensorflow中使用summarize_graph脚本,它可以将pb或pbtxt作为输入 .

    问候,

    斯蒂芬

  • 15

    我尝试了freezed_graph.py脚本,但output_node_name参数完全令人困惑 . 工作失败了 .

    所以我尝试了另一个: export_inference_graph.py . 它按预期工作!

    python -u /tfPath/models/object_detection/export_inference_graph.py \
      --input_type=image_tensor \
      --pipeline_config_path=/your/config/path/ssd_mobilenet_v1_pets.config \
      --trained_checkpoint_prefix=/your/checkpoint/path/model.ckpt-50000 \
      --output_directory=/output/path
    

    我使用的tensorflow安装包来自这里:https://github.com/tensorflow/models

  • 6

    首先,使用以下代码生成graph.pb文件 . 使用tf.Session()作为sess:

    # Restore the graph
        _ = tf.train.import_meta_graph(args.input)
    
        # save graph file
        g = sess.graph
        gdef = g.as_graph_def()
        tf.train.write_graph(gdef, ".", args.output, True)
    

    然后,使用汇总图获取输出节点名称 . 最后,使用

    python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
    

    生成冻结图 .

相关问题