
Exporting a TensorFlow graph from Python for use in C++


How exactly should a Python model be exported for use in C++?

I'm trying to do something similar to this tutorial: https://www.tensorflow.org/versions/r0.8/tutorials/image_recognition/index.html

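I'm trying to import my own TF model in the C++ API instead of the Inception one. I adjusted the input size and paths, but strange errors keep popping up. I spent the whole day reading Stack Overflow and other forums, to no avail.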

I have tried two ways of exporting the graph.

Method 1: meta graph.

# ... loading inputs, setting up the model, etc. ...

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())


for i in range(num_steps):
    x_batch, y_batch = batch(50)
    if i % 10 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: x_batch, y_: y_batch, keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: x_batch, y_: y_batch, keep_prob: 0.5})

print("test accuracy %g" % accuracy.eval(feed_dict={
    x: features_test, y_: labels_test, keep_prob: 1.0}))

saver = tf.train.Saver(tf.all_variables())
checkpoint = '/home/sander/tensorflow/tensorflow/examples/cat_face/data/model.ckpt'
saver.save(sess, checkpoint)

tf.train.export_meta_graph(
    filename='/home/sander/tensorflow/tensorflow/examples/cat_face/data/cat_graph.pb',
    meta_info_def=None,
    graph_def=sess.graph_def,
    saver_def=saver.as_saver_def(),  # saver.restore() returns None, not a SaverDef
    collection_list=None, as_text=False)

When I try to run the program, method 1 produces the following error:

[libprotobuf ERROR google/protobuf/src/google/protobuf/wire_format_lite.cc:532] String field 'tensorflow.NodeDef.op' contains invalid UTF-8 data when parsing a protocol buffer. Use the 'bytes' type if you intend to send raw bytes.
E tensorflow/examples/cat_face/main.cc:281] Not found: Failed to load compute graph at 'tensorflow/examples/cat_face/data/cat_graph.pb'

I also tried another way of exporting the graph:

Method 2: write_graph:

tf.train.write_graph(sess.graph_def, 
'/home/sander/tensorflow/tensorflow/examples/cat_face/data/', 
'cat_graph.pb', as_text=False)

This version actually seems to load something, but I get an error about an uninitialized variable:

Running model failed: Failed precondition: Attempting to use uninitialized value weight1
[[Node: weight1/read = Identity[T=DT_FLOAT, _class=["loc:@weight1"], _device="/job:localhost/replica:0/task:0/cpu:0"](weight1)]]

2 Answers

  • Answer 1

    First, write the graph definition to a file:

    import tensorflow as tf

    with tf.Session() as sess:
        # Build the network here
        # as_text defaults to True, so mymodel.pb is a text-format GraphDef
        tf.train.write_graph(sess.graph.as_graph_def(), "C:\\output\\", "mymodel.pb")
    

    Then, save the model with a Saver:

    saver = tf.train.Saver(tf.global_variables()) 
    saver.save(sess, "C:\\output\\mymodel.ckpt")
    

    You will then have 2 files in your output directory: mymodel.ckpt and mymodel.pb.
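Depending on your TensorFlow version, saver.save() may write more files than that. With the V2 checkpoint format (the default in TF 1.x), saving to mymodel.ckpt produces mymodel.ckpt.index, one or more mymodel.ckpt.data-XXXXX-of-YYYYY shards and mymodel.ckpt.meta, plus a small "checkpoint" bookkeeping file. In every case --input_checkpoint takes the prefix mymodel.ckpt, not an individual file. The stdlib-only sketch below uses hypothetical helper names and only inspects file names, to tell which layout is on disk before you freeze:

```python
import os


def checkpoint_layout(prefix):
    """Return 'v2', 'v1' or None for the checkpoint saved under `prefix`.

    Hypothetical helper: it only looks at file names, it does not parse them.
    V2 layout: <prefix>.index plus one or more <prefix>.data-XXXXX-of-YYYYY shards.
    V1 layout: a single file named exactly <prefix>.
    """
    folder = os.path.dirname(prefix) or "."
    base = os.path.basename(prefix)
    names = os.listdir(folder) if os.path.isdir(folder) else []
    has_index = base + ".index" in names
    has_data = any(n.startswith(base + ".data-") for n in names)
    if has_index and has_data:
        return "v2"
    if base in names:
        return "v1"
    return None


def has_meta_graph(prefix):
    """True if saver.save() also wrote the MetaGraphDef file (<prefix>.meta)."""
    return os.path.isfile(prefix + ".meta")


if __name__ == "__main__":
    # Run from the output directory, e.g. C:\output
    prefix = "mymodel.ckpt"
    print(checkpoint_layout(prefix), has_meta_graph(prefix))
```

Either way, the prefix (not one of the individual files) is what gets passed to the freeze step.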

    Download freeze_graph.py from here and run the following command in C:\output\. Change the output node name if yours is different.

    python freeze_graph.py --input_graph mymodel.pb --input_checkpoint mymodel.ckpt --output_node_names softmax/Reshape_1 --output_graph mymodelforc.pb

    You can use mymodelforc.pb directly from C++.
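If you would rather drive the freeze step from a script than type the command, the minimal sketch below builds the same argument list and hands it to subprocess. It assumes freeze_graph.py sits in the working directory, and build_freeze_command is a hypothetical helper name. One detail to check against your copy of the script: in TF 1.x, freeze_graph.py's --input_binary flag defaults to false. That matches a mymodel.pb written by tf.train.write_graph with its default as_text=True. If you wrote the graph with as_text=False, pass input_binary=True.

```python
import subprocess
import sys


def build_freeze_command(input_graph, input_checkpoint, output_node_names,
                         output_graph, input_binary=False,
                         script="freeze_graph.py"):
    """Build the argv list for freeze_graph.py (hypothetical helper).

    output_node_names: list of node names, e.g. ["softmax/Reshape_1"];
    they are joined with commas, which is what freeze_graph.py expects.
    """
    return [
        sys.executable, script,
        "--input_graph", input_graph,
        "--input_checkpoint", input_checkpoint,  # checkpoint prefix, not a single file
        "--output_node_names", ",".join(output_node_names),
        "--output_graph", output_graph,
        # '=' form so the boolean value is parsed unambiguously
        "--input_binary=" + ("true" if input_binary else "false"),
    ]


if __name__ == "__main__":
    cmd = build_freeze_command("mymodel.pb", "mymodel.ckpt",
                               ["softmax/Reshape_1"], "mymodelforc.pb")
    subprocess.check_call(cmd)  # raises CalledProcessError if freezing fails
```

Run from C:\output, this is equivalent to the freeze_graph.py command above. The subprocess call only happens under `__main__`, so the helper can be imported and reused without triggering a freeze.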

    You can load the proto file with the following C++ code:

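    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/platform/env.h"

    using namespace tensorflow;

    Session* session = nullptr;
    // TF_CHECK_OK aborts with the error message if a call fails
    TF_CHECK_OK(NewSession(SessionOptions(), &session));

    // Read the frozen, binary GraphDef written by freeze_graph.py
    GraphDef graph_def;
    TF_CHECK_OK(ReadBinaryProto(Env::Default(), "C:\\output\\mymodelforc.pb", &graph_def));

    // Add the graph (weights are baked in as constants) to the session
    TF_CHECK_OK(session->Create(graph_def));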
    

    Now you can use the session for inference.

    You can feed the input and run inference as follows:

    // Same dimensions and dtype as your network's input placeholder
    tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, height, width, channel }));
    std::vector<tensorflow::Tensor> finalOutput;

    // Fill input_tensor with your input data, e.g. through input_tensor.tensor<float, 4>()

    std::string InputName = "input";              // name of your input placeholder
    std::string OutputName = "softmax/Reshape_1"; // name of your output node

    TF_CHECK_OK(session->Run({ { InputName, input_tensor } }, { OutputName }, {}, &finalOutput));

    // finalOutput[0] now holds the inference output you are looking for
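A frequent source of confusion here is node names versus tensor names. freeze_graph.py's --output_node_names (and convert_variables_to_constants) expect bare node names such as softmax/Reshape_1, while a Python tensor's .name is softmax/Reshape_1:0 (output 0 of that node). The C++ Session::Run accepts either form and treats a bare node name as output 0. The stdlib sketch below uses hypothetical helper names and normalizes names before they are passed to the freeze step:

```python
def tensor_to_node_name(name):
    """Strip an optional ':<index>' output suffix from a tensor name.

    'softmax/Reshape_1:0' -> 'softmax/Reshape_1'; bare node names pass through.
    """
    name = name.strip()
    node, sep, index = name.rpartition(":")
    if sep and index.isdigit():
        return node
    return name


def output_node_names_arg(names):
    """Join tensor or node names into the comma-separated string of node names
    used by freeze_graph.py's --output_node_names (hypothetical helper)."""
    if isinstance(names, str):
        names = names.split(",")
    return ",".join(tensor_to_node_name(n) for n in names if n.strip())
```

For example, output_node_names_arg(["softmax/Reshape_1:0"]) gives the softmax/Reshape_1 value used in the freeze command above.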
    
  • Answer 2

    You can try this (change the name of the output layer):

    import os
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    
    def load_graph_def(model_path, sess=None):
        sess = sess if sess is not None else tf.get_default_session()
        saver = tf.train.import_meta_graph(model_path + '.meta')
        saver.restore(sess, model_path)
    
    
    def freeze_graph(sess, output_layer_name, output_graph):
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()
    
        # Exporting the graph
        print("Exporting graph...")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_layer_name.split(","))
    
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    
    
    def freeze_from_checkpoint(checkpoint_file, output_layer_name):
    
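        # Write the frozen graph next to the checkpoint, e.g. model.ckpt -> model.ckpt.pb
        model_folder = os.path.dirname(checkpoint_file)
        output_graph = os.path.join(model_folder, os.path.basename(checkpoint_file) + '.pb')
    
        with tf.Session() as sess:
            # Restore the graph structure (.meta) and the trained weights into this session
            load_graph_def(checkpoint_file, sess)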
    
            freeze_graph(sess, output_layer_name, output_graph)
    
    
    if __name__ == '__main__':
        freeze_from_checkpoint(
            checkpoint_file='/home/sander/tensorflow/tensorflow/examples/cat_face/data/model.ckpt',
            output_layer_name='???')
    
