首页 文章

保存并加载Tensorflow模型

提问于
浏览
1

我想保存Tensorflow(0.12.0)模型,包括图形和变量值,然后加载并执行它 . 我已经阅读了文档和其他帖子,但无法使基础工作 . 我正在使用this page in the Tensorflow docs中的技术 . 码:

保存一个简单的模型:

myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print sess.run(myVar)
    saver0 = tf.train.Saver()
    saver0.save(sess, './myModel.ckpt')
    saver0.export_meta_graph('./myModel.meta')

稍后,加载并执行模型:

with tf.Session() as sess:
    saver1 = tf.train.import_meta_graph('./myModel.meta')
    saver1.restore(sess, './myModel.meta')
    print sess.run(myVar)

Question 1: 保存代码似乎有效,但加载代码会产生此错误:

W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

如何解决这个问题?

Question 2: 我将此行包含在TF文档中的模式中...

tf.add_to_collection('modelVariables', myVar)

......但为什么这条线是必要的?默认情况下 expert_meta_graph 不导出整个图表?如果没有,那么在保存之前是否需要将图中的每个变量添加到集合中?或者我们只是将那些将在恢复后访问的变量添加到集合中?

---------------------- 2017年1月12日更新------------------------ -----

部分成功基于Kashyap的建议,但仍然存在一个谜 . 下面的代码有效但仅当我包含包含 tf.add_to_collectiontf.get_collection 的行时 . 如果没有这些行,'load'模式会在最后一行中抛出错误: NameError: name 'myVar' is not defined . 我的理解是,默认情况下 Saver.save 保存并恢复图中的所有变量,那么为什么需要指定将在集合中使用的变量的名称?我认为这与将Tensorflow的变量名称映射到Python名称有关,但这里的游戏规则是什么?对于哪些变量需要这样做?

mode = 'load' # or 'save'
if mode == 'save':
    myVar = tf.Variable(7.1)
    init_op = tf.global_variables_initializer()
    saver0 = tf.train.Saver()
    tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
    with tf.Session() as sess:
        sess.run(init_op)
        print sess.run(myVar)
        saver0.save(sess, './myModel')
if mode == 'load':
    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph('./myModel.meta')
        saver1.restore(sess, tf.train.latest_checkpoint('./'))
        myVar = tf.get_collection('myVar')[0]  ### WHY NECESSARY?
        print sess.run(myVar)

2 回答

  • 1

    我一直试图弄清楚同样的事情,并且能够通过使用 Supervisor 成功地做到这一点 . 它会自动加载所有变量和图形等 . 这是文档 - https://www.tensorflow.org/programmers_guide/supervisor . 以下是我的代码 -

    sv = tf.train.Supervisor(logdir="/checkpoint', save_model_secs=60)
        with sv.managed_session() as sess:
            if not sv.should_stop(): 
                #Do run/eval/train ops on sess as needed. Above works for both saving and loading
    

    如你所见,这比使用 Saver 对象和处理单个变量等简单得多,只要图形保持不变(我的理解是当我们想要为不同的图重用预先训练的模型时 Saver 会很方便) .

  • 1

    Question1

    这个问题已经得到了彻底的回答here . 您不必显式调用 export_meta_graph . 拨打 save method . 这也将生成 .meta 文件(因为save方法将在内部调用 export_meta_graph 方法 . )

    例如

    saver0.save(sess, './myModel.ckpt')

    将生成 myModel.ckpt 文件以及 myModel.ckpt.meta 文件 .

    然后您可以使用恢复模型

    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph('./myModel.ckpt.meta')
        saver1.restore(sess, './myModel')
        print sess.run(myVar)
    

    Question2

    集合用于存储自定义信息,例如学习率,您使用的正则化因子以及其他信息,这些信息将在您导出图形时存储 . Tensorflow本身定义了一些像"TRAINABLE_VARIABLES"这样的集合,用于获取您构建的模型的所有可训练变量 . 您可以选择导出图表中的所有集合,也可以指定要在 export_meta_graph 函数中导出的集合 .

    是tensorflow将导出您定义的所有变量 . 但是,如果您需要任何其他需要导出到图表的信息,则可以将它们添加到集合中 .

相关问题