首页 文章

将训练有素的Tensorflow模型保存到另一台机器上进行推理[重复]

提问于
浏览
0

这个问题在这里已有答案:

我对机器学习和Tensorflow框架比较陌生 . 我试图通过使用MNIST手写数字数据集来严格影响我所训练的模型,使用MNIST手写数字数据集并对我创建的测试示例进行推断 . 但是,我正在使用GPU进行远程计算机培训,并尝试将数据保存到目录中,以便我可以在本地计算机上传输数据和推理

似乎我能够用 tf.saved_model.simple_save 保存一些模型,但是,我不确定如何使用保存的数据进行推理并使用数据在给定新图像的情况下进行预测 . 似乎有多种方法可以保存模型,但我不确定使用Tensorflow框架的惯例或"correct way"是什么 .

到目前为止,这是我认为我需要的那条线,但我不确定它是否正确 .

tf.saved_model.simple_save(sess, 'mnist_model',                                                                                 
                inputs={'x': self.x},                                                                                                   
                outputs={'y_': self.y_, 'y_conv':self.y_conv})

如果有人能指出我如何正确保存经过训练的模型以及使用哪些变量来使用保存的模型进行推理,我真的很感激 .

1 回答

  • 1

    您可以这样做的方法是在图形定义中创建 tf.train.Saver() 对象,然后使用该对象将网络保存到指定目录 . 然后,可以将该目录中的权重从远程计算机下载到本地目录并在本地还原 . 这是一个小例子网络:

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    
    
    # >>>> Config. Vars <<<<
    
    TRAIN_STEPS = 1000
    
    SAVE_EVERY  = 100
    
    
    # >>>> Network <<<<
    
    inputs = tf.placeholder(tf.float32, shape=[None, 784])
    
    labels = tf.placeholder(tf.float32, shape=[None, 10])
    
    h1     = tf.layers.dense(inputs, 256, activation=tf.nn.relu, use_bias=True)
    
    logits = tf.layers.dense(h1, 10, use_bias=True)
    
    predictions = tf.nn.softmax(logits)
    
    prediction_ids = tf.argmax(predictions, axis=1)
    
    # >>>> Loss & Optimisation <<<<
    
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
    
    opt  = tf.train.AdamOptimizer().minimize(loss)
    
    # >>>> Utilities <<<<
    
    init  = tf.global_variables_initializer()
    
    saver = tf.train.Saver()
    
    
    with tf.Session() as sess:
    
        sess.run(init)
    
        # >>>> Training - run on remote, comment out locally <<<<
    
        for i in range(TRAIN_STEPS):
    
            print("Train step {}".format(i), end="\r")
    
            batch_data, batch_labels = mnist.train.next_batch(batch_size=128)
    
            feed_dict = {
                inputs: batch_data,
                labels: batch_labels
            }
    
            l, _ = sess.run([loss, opt], feed_dict=feed_dict)
    
            if i % SAVE_EVERY == 0:
    
                saver.save(sess, "saved_model/network_weights.ckpt")
    
    
        # >>>> Using the network - run locally to use the network <<<
    
        saver.restore(sess, "saved_model/network_weights.ckpt")
    
        test_data, test_labels = mnist.test.images, mnist.test.labels
    
        feed_dict = {
            inputs: test_data,
            labels: test_labels
        }
    
        preds = sess.run(prediction_ids, feed_dict=feed_dict)
    
        print(preds)
    

    因此,一旦在网络中定义了保护程序,就可以使用它将权重保存到指定的目录 - 在本例中是“saved_models”目录,在运行此特定代码之前,您需要创建该目录 .

    恢复模型就像调用 saver.restore() 然后传递会话以及存储权重的路径一样简单 . 因此,您可以在远程计算机上运行此代码,将"saved_models"目录下载到本地计算机,然后运行此代码并将训练部分注释掉以实际使用该模型 .

相关问题