
Tensorflow: How to save/restore a model?


After training a model in Tensorflow:

  • How do you save the trained model?

  • How do you restore this saved model later?

18 Answers

  • 6

    A new and shorter way: simple_save

    Many good answers already; for completeness, I'll add my two cents: simple_save. Also a standalone code example using the tf.data.Dataset API.

    Python 3; Tensorflow 1.7

    import tensorflow as tf
    from tensorflow.python.saved_model import tag_constants
    
    with tf.Graph().as_default():
        with tf.Session() as sess:
            ...
    
            # Saving
            inputs = {
                "batch_size_placeholder": batch_size_placeholder,
                "features_placeholder": features_placeholder,
                "labels_placeholder": labels_placeholder,
            }
            outputs = {"prediction": model_output}
            tf.saved_model.simple_save(
                sess, 'path/to/your/location/', inputs, outputs
            )
    

    Restore:

    graph = tf.Graph()
    with graph.as_default():
        with tf.Session() as sess:
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                'path/to/your/location/',
            )
            batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
            features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
            labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
            prediction = graph.get_tensor_by_name('dense/BiasAdd:0')
    
            sess.run(prediction, feed_dict={
                batch_size_placeholder: some_value,
                features_placeholder: some_other_value,
                labels_placeholder: another_value
            })
    

    Standalone example

    Original blog post

    The following code generates random data for the sake of the demonstration.

    • We first create the placeholders. They will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator's generated tensor, called input_tensor, which will serve as input to our model.

    • The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.

    • The loss is softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the "trained" model with tf.saved_model.simple_save. If you run the code as is, the model will be saved in a folder called simple/ in your current working directory.

    • In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.

    • Lastly, we run an inference for both batches in the dataset and check that the saved and restored models both yield the same values. They do!

    Code:

    import os
    import shutil
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.saved_model import tag_constants
    
    
    def model(graph, input_tensor):
        """Create the model which consists of
        a bidirectional rnn (GRU(10)) followed by a dense classifier
    
        Args:
            graph (tf.Graph): Tensors' graph
            input_tensor (tf.Tensor): Tensor fed as input to the model
    
        Returns:
            tf.Tensor: the model's output layer Tensor
        """
        cell = tf.nn.rnn_cell.GRUCell(10)
        with graph.as_default():
            ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=cell,
                cell_bw=cell,
                inputs=input_tensor,
                sequence_length=[10] * 32,
                dtype=tf.float32,
                swap_memory=True,
                scope=None)
            outputs = tf.concat((fw_outputs, bw_outputs), 2)
            mean = tf.reduce_mean(outputs, axis=1)
            dense = tf.layers.dense(mean, 5, activation=None)
    
            return dense
    
    
    def get_opt_op(graph, logits, labels_tensor):
        """Create optimization operation from model's logits and labels
    
        Args:
            graph (tf.Graph): Tensors' graph
            logits (tf.Tensor): The model's output without activation
            labels_tensor (tf.Tensor): Target labels
    
        Returns:
            tf.Operation: the operation performing a stem of Adam optimizer
        """
        with graph.as_default():
            with tf.variable_scope('loss'):
                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                        logits=logits, labels=labels_tensor, name='xent'),
                        name="mean-xent"
                        )
            with tf.variable_scope('optimizer'):
                opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
            return opt_op
    
    
    if __name__ == '__main__':
        # Set random seed for reproducibility
        # and create synthetic data
        np.random.seed(0)
        features = np.random.randn(64, 10, 30)
        labels = np.eye(5)[np.random.randint(0, 5, (64,))]
    
        graph1 = tf.Graph()
        with graph1.as_default():
            # Random seed for reproducibility
            tf.set_random_seed(0)
            # Placeholders
            batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
            features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
            labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
            # Dataset
            dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
            dataset = dataset.batch(batch_size_ph)
            iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
            dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
            input_tensor, labels_tensor = iterator.get_next()
    
            # Model
            logits = model(graph1, input_tensor)
            # Optimization
            opt_op = get_opt_op(graph1, logits, labels_tensor)
    
            with tf.Session(graph=graph1) as sess:
                # Initialize variables
                tf.global_variables_initializer().run(session=sess)
                for epoch in range(3):
                    batch = 0
                    # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                    sess.run(
                        dataset_init_op,
                        feed_dict={
                            features_data_ph: features,
                            labels_data_ph: labels,
                            batch_size_ph: 32
                        })
                    values = []
                    while True:
                        try:
                            if epoch < 2:
                                # Training
                                _, value = sess.run([opt_op, logits])
                                print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                                batch += 1
                            else:
                                # Final inference
                                values.append(sess.run(logits))
                                print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                                batch += 1
                        except tf.errors.OutOfRangeError:
                            break
                # Save model state
                print('\nSaving...')
                cwd = os.getcwd()
                path = os.path.join(cwd, 'simple')
                shutil.rmtree(path, ignore_errors=True)
                inputs_dict = {
                    "batch_size_ph": batch_size_ph,
                    "features_data_ph": features_data_ph,
                    "labels_data_ph": labels_data_ph
                }
                outputs_dict = {
                    "logits": logits
                }
                tf.saved_model.simple_save(
                    sess, path, inputs_dict, outputs_dict
                )
                print('Ok')
        # Restoring
        graph2 = tf.Graph()
        with graph2.as_default():
            with tf.Session(graph=graph2) as sess:
                # Restore saved values
                print('\nRestoring...')
                tf.saved_model.loader.load(
                    sess,
                    [tag_constants.SERVING],
                    path
                )
                print('Ok')
                # Get restored placeholders
                labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
                features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
                batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
                # Get restored model output
                restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
                # Get dataset initializing operation
                dataset_init_op = graph2.get_operation_by_name('dataset_init')
    
                # Initialize restored dataset
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    }
    
                )
                # Compute inference for both batches in dataset
                restored_values = []
                for i in range(2):
                    restored_values.append(sess.run(restored_logits))
                    print('Restored values: ', restored_values[i][0])
    
        # Check if original inference and restored inference are equal
        valid = all((v == rv).all() for v, rv in zip(values, restored_values))
        print('\nInferences match: ', valid)
    

    This will print:

    $ python3 save_and_restore.py
    
    Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
    Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
    Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
    Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
    Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
    Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]
    
    Saving...
    INFO:tensorflow:Assets added to graph.
    INFO:tensorflow:No assets to write.
    INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
    Ok
    
    Restoring...
    INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
    Ok
    Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
    Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]
    
    Inferences match:  True
    
  • 10

    I am improving my answer, adding more details about saving and restoring models.

    In (and after) Tensorflow version 0.11:

    Save the model:

    import tensorflow as tf
    
    #Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1= tf.Variable(2.0,name="bias")
    feed_dict ={w1:4,w2:8}
    
    #Define a test operation that we will restore
    w3 = tf.add(w1,w2)
    w4 = tf.multiply(w3,b1,name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    #Create a saver object which will save all the variables
    saver = tf.train.Saver()
    
    #Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    #Prints 24 which is (w1+w2)*b1 = (4+8)*2.0
    
    #Now, save the graph
    saver.save(sess, 'my_test_model',global_step=1000)
    

    Restore the model:

    import tensorflow as tf
    
    sess=tf.Session()    
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    
    
    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2, which is the value of bias that we saved
    
    
    # Now, let's access and create placeholders variables and
    # create feed-dict to feed new data
    
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict ={w1:13.0,w2:17.0}
    
    #Now, access the op that you want to run. 
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    
    print(sess.run(op_to_restore, feed_dict))
    #This will print 60 which is calculated as (w1+w2)*b1 = (13+17)*2.0
    

    This, and some more advanced use-cases, are explained very well here:

    A quick complete tutorial to save and restore Tensorflow models

  • 1

    In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, per https://www.tensorflow.org/programmers_guide/meta_graph.

    Save the model

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # `save` method will call `export_meta_graph` implicitly.
    # you will get saved graph files: my-model.meta
    

    Restore the model

    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)
    
  • 123

    For TensorFlow versions < 0.11.0RC1:

    The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.

    Here's an example of linear regression where there is a training loop that saves variable checkpoints and an evaluation section that restores variables saved in a prior run and computes predictions. Of course, you can also restore variables and continue training if you'd like.

    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    
    w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
    b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
    y_hat = tf.add(b, tf.matmul(x, w))
    
    ...more setup for optimization and what not...
    
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if FLAGS.train:
            for i in xrange(FLAGS.training_steps):
                ...training loop...
                if (i + 1) % FLAGS.checkpoint_steps == 0:
                    saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                               global_step=i+1)
        else:
            # Here's where you're restoring the variables w and b.
            # Note that the graph is exactly as it was when the variables were
            # saved in a prior training run.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                ...no checkpoint found...
    
            # Now you can run the model to get predictions
            batch_x = ...load some data...
            predictions = sess.run(y_hat, feed_dict={x: batch_x})
    

    Here are the docs for Variable, which include saving and restoring. And here are the docs for the Saver.

  • 224

    My environment: Python 3.6, Tensorflow 1.3.0

    Though there are many solutions, most of them are based on tf.train.Saver. When we load a .ckpt saved by Saver, we have to either redefine the tensorflow network or use some weird and hard-to-remember names, e.g. 'placehold_0:0' or 'dense/Adam/Weight:0'. Here I recommend using tf.saved_model; a simplest example is given below, and you can learn more from Serving a TensorFlow Model:

    Save the model:

    import tensorflow as tf
    
    # define the tensorflow network and do some trains
    x = tf.placeholder("float", name="x")
    w = tf.Variable(2.0, name="w")
    b = tf.Variable(0.0, name="bias")
    
    h = tf.multiply(x, w)
    y = tf.add(h, b, name="y")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # save the model
    export_path =  './savedmodel'
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    
    prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'x_input': tensor_info_x},
          outputs={'y_output': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    
    builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              prediction_signature 
      },
      )
    builder.save()
    

    Load the model:

    import tensorflow as tf
    sess=tf.Session() 
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    input_key = 'x_input'
    output_key = 'y_output'
    
    export_path =  './savedmodel'
    meta_graph_def = tf.saved_model.loader.load(
               sess,
              [tf.saved_model.tag_constants.SERVING],
              export_path)
    signature = meta_graph_def.signature_def
    
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name
    
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    
    y_out = sess.run(y, {x: 3.0})
    
  • 36

    There are two parts to the model: the model definition, saved by the Supervisor as graph.pbtxt in the model directory, and the numerical values of the tensors, saved into checkpoint files like model.ckpt-1003418.

    The model definition can be restored using tf.import_graph_def, and the weights are restored using the Saver.
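
    For instance, restoring just the graph definition from the text-format graph.pbtxt might look like this minimal sketch (the file name is an assumption, and as noted below, the weights still need a workaround):

    import tensorflow as tf
    from google.protobuf import text_format

    # Parse the text-format GraphDef written alongside the checkpoints
    with open('graph.pbtxt') as f:
        graph_def = text_format.Parse(f.read(), tf.GraphDef())

    # Import the nodes into the current (empty) default graph
    tf.import_graph_def(graph_def, name='')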

    However, the Saver uses a special collection that holds the list of variables, and this collection is not initialized using import_graph_def, so at the moment you can't use the two together (it's on our roadmap to fix). For now, you have to use Ryan Sepassi's approach: manually construct a graph with identical node names, and use the Saver to load the weights into it.

    (Alternatively, you could hack it by using import_graph_def, creating the variables manually, calling tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, and then using the Saver.)

  • -1

    You can also take this easier way.

    Step 1: initialize all your variables

    W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
    B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
    
    # Similarly, W2, B2, W3, .....
    

    Step 2: save the session inside a model Saver and save it

    model_saver = tf.train.Saver()
    
    # Train the model and save it in the end
    model_saver.save(session, "saved_models/CNN_New.ckpt")
    

    Step 3: restore the model

    with tf.Session(graph=graph_cnn) as session:
        model_saver.restore(session, "saved_models/CNN_New.ckpt")
        print("Model restored.") 
        print('Initialized')
    

    Step 4: check your variable

    W1 = session.run(W1)
    print(W1)
    

    While running in a different python instance, use

    with tf.Session() as sess:
        # Initialize the variables
        sess.run(tf.global_variables_initializer())

        # Restore latest checkpoint (restoring after initializing,
        # so the saved values overwrite the fresh ones)
        saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

        # Get default graph (supply your custom graph if you have one)
        graph = tf.get_default_graph()

        # It will give the tensor object
        W1 = graph.get_tensor_by_name('W1:0')

        # To get the value (numpy array)
        W1_value = sess.run(W1)
    
  • 172

    In most cases, saving and restoring from disk with a tf.train.Saver is your best option:

    ... # build your model
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        ... # train the model
        saver.save(sess, "/tmp/my_great_model")
    
    with tf.Session() as sess:
        saver.restore(sess, "/tmp/my_great_model")
        ... # use the model
    

    You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can use to restore the model's state:

    saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
    
    with tf.Session() as sess:
        saver.restore(sess, "/tmp/my_great_model")
        ... # use the model
    

    However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save a checkpoint every time the model improves during training (as measured on the validation set), and then if there is no progress for some time, you want to roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then just restore them later:

    ... # build your model
    
    # get a handle on the graph nodes we need to save/restore the model
    graph = tf.get_default_graph()
    gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
    init_values = [assign_op.inputs[1] for assign_op in assign_ops]
    
    with tf.Session() as sess:
        ... # train the model
    
        # when needed, save the model state to memory
        gvars_state = sess.run(gvars)
    
        # when needed, restore the model state
        feed_dict = {init_value: val
                     for init_value, val in zip(init_values, gvars_state)}
        sess.run(assign_ops, feed_dict=feed_dict)
    

    A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value to any op, not just placeholders, so this works fine.

  • 37

    As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating the variables, and then using a Saver.

    I implemented this for my personal use, so I thought I'd share the code here.

    链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

    (This is, of course, a hack, and there is no guarantee that models saved this way will stay readable in future versions of TensorFlow.)

  • 18

    If it is an internally saved model, you just specify a restorer for all the variables as

    restorer = tf.train.Saver(tf.all_variables())
    

    and use it to restore the variables in the current session:

    restorer.restore(self._sess, model_file)
    

    For an external model, you need to specify the mapping from its variable names to your variable names. You can view the model's variable names using the command

    python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
    

    The inspect_checkpoint.py script can be found in the './tensorflow/python/tools' folder of the Tensorflow source.
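
    As a minimal sketch of such a mapping (the checkpoint name 'pretrained_scope/weights' and the variable my_weights here are hypothetical), the var_list argument of Saver accepts a dict from names in the checkpoint to your variables:

    # Map the name stored in the checkpoint to a variable in the current graph
    restorer = tf.train.Saver({'pretrained_scope/weights': my_weights})
    restorer.restore(sess, model_file)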

    To specify the mapping, you can also use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here.

  • 16

    Here is my simple solution for the two basic cases, which differ on whether you want to load the graph from a file or build it during runtime.

    This answer holds for Tensorflow 0.12+ (including 1.0).

    Rebuilding the graph in code

    Saving

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')
    

    Loading

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # now you can use the graph, continue training or whatever
    

    Loading also the graph from a file

    When using this technique, make sure all your layers/variables have explicitly set unique names. Otherwise, Tensorflow will make the names unique itself, and they'll thus differ from the names stored in the file. It's not a problem with the previous technique, because the names are "mangled" the same way in both loading and saving.

    Saving

    graph = ... # build the graph
    
    for op in [ ... ]:  # operators you want to use after restoring the model
        tf.add_to_collection('ops_to_restore', op)
    
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')
    

    Loading

    with ... as sess:  # your session object
        saver = tf.train.import_meta_graph('my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection
    
  • 7

    You can also check out the examples in TensorFlow/skflow, which offers save and restore methods that can help you easily manage your models. It has parameters that also let you control how frequently you back up your model.

  • 52

    If you use tf.train.MonitoredTrainingSession as your default session, it will use session hooks to handle saving and restoring for you.
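
    A minimal sketch of what this looks like (the checkpoint directory and train_op are assumptions; passing checkpoint_dir makes the session periodically save checkpoints and automatically resume from the latest one on startup):

    import tensorflow as tf

    # ... build your graph here, producing a `train_op` ...
    tf.train.get_or_create_global_step()  # needed by the checkpoint hook

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir='/tmp/my_model',     # hypothetical path
            save_checkpoint_secs=600) as sess:  # save every 10 minutes
        while not sess.should_stop():          # run until a hook requests a stop
            sess.run(train_op)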

  • 13

    As described in this issue:

    # use './model_name.ckpt'
    saver.restore(sess, './my_model_final.ckpt')
    

    instead of

    saver.restore(sess, 'my_model_final.ckpt')
    
  • 9

    All the answers here are great, but I want to add two things.

    First, to elaborate on @user7505159's answer, the "./" can be important to add to the beginning of the file name that you are restoring.

    For example, you can save a graph without "./" in the file name like so:

    # Some graph defined up here with specific names
    
    saver = tf.train.Saver()
    save_file = 'model.ckpt'
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_file)
    

    But in order to restore the graph, you may need to prepend a "./" to the file name:

    # Same graph defined up here
    
    saver = tf.train.Saver()
    save_file = './' + 'model.ckpt' # String addition used for emphasis
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, save_file)
    

    You will not always need the "./", but it can cause problems depending on your environment and version of TensorFlow.

    I also want to mention that sess.run(tf.global_variables_initializer()) can be important before restoring the session.

    If you are receiving an error about uninitialized variables when trying to restore a saved session, make sure you include sess.run(tf.global_variables_initializer()) before the saver.restore(sess, save_file) line. It can save you a headache.

  • 53

    Use tf.train.Saver to save a model. Remember, you need to specify the var_list if you want to reduce the model size. The var_list can be tf.trainable_variables or tf.global_variables; a short sketch follows below.
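
    A minimal sketch (assuming a graph and a session sess already exist; the path is made up):

    # Saving only the trainable variables keeps the checkpoint small,
    # e.g. it leaves out optimizer slot variables such as Adam's moments.
    saver = tf.train.Saver(var_list=tf.trainable_variables())
    saver.save(sess, './small_model.ckpt')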

  • 1

    As of the newer Tensorflow versions, tf.train.Checkpoint is the preferable way of saving and restoring a model:

    Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to tf.train.Saver, which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code.

    Here is an example:

    import tensorflow as tf
    import os
    
    tf.enable_eager_execution()
    
    checkpoint_directory = "/tmp/training_checkpoints"
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
    for _ in range(num_training_steps):
      optimizer.minimize( ... )  # Variables will be restored on creation.
    status.assert_consumed()  # Optional sanity checks.
    checkpoint.save(file_prefix=checkpoint_prefix)
    

    More information and example here.
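
    For a self-contained variant of the example above (the model, optimizer, and path below are made up for illustration):

    import tensorflow as tf

    tf.enable_eager_execution()

    # Hypothetical model and optimizer, just to have something to checkpoint
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
    optimizer = tf.train.AdamOptimizer(0.1)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

    model(tf.zeros([1, 3]))  # run the model once so its variables exist
    save_path = checkpoint.save('/tmp/ckpt/model')  # hypothetical prefix

    # Later: restore into freshly created objects with the same structure
    new_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
    new_optimizer = tf.train.AdamOptimizer(0.1)
    tf.train.Checkpoint(optimizer=new_optimizer, model=new_model).restore(save_path)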

  • 6

    Eager mode is also an unresolved save/restore question which nobody seems to have answered, even though the documentation surely should address it. Here is the non-working code I wrote, which tries to use the Saver class from tensorflow.contrib.eager as tfe. My code definitely saved to disk... something was saved. The problem is restoring. I even added explicit code to first manually re-create everything, and then load the learned parameters:

    optimizer = tf.train.AdamOptimizer() #ga
    global_step = tf.train.get_or_create_global_step()  # what is a global_step?
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(4,)),  # input shape required
      tf.keras.layers.Dense(10, activation=tf.nn.relu, kernel_initializer='glorot_uniform'),
      tf.keras.layers.Dense(3)
    ])
    #s = tfe.Saver([optimizer, model, global_step])
    s = tfe.Saver([model])
    s.restore(file_prefix="/tmp/iris-1")
    

    It restores something and then throws a ValueError:

    INFO:tensorflow:Restoring parameters from /tmp/iris-1
    
    ---------------------------------------------------------------------------
    ValueError...
    --> names, slices, dtypes = zip(*restore_specs)
    
