我想保存Tensorflow(0.12.0)模型,包括图形和变量值,然后加载并执行它 . 我已经阅读了文档和其他帖子,但无法使基础工作 . 我正在使用this page in the Tensorflow docs中的技术 . 码:
保存一个简单的模型:
myVar = tf.Variable(7.1)
tf.add_to_collection('modelVariables', myVar) # why?
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print sess.run(myVar)
saver0 = tf.train.Saver()
saver0.save(sess, './myModel.ckpt')
saver0.export_meta_graph('./myModel.meta')
稍后,加载并执行模型:
with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.meta')
saver1.restore(sess, './myModel.meta')
print sess.run(myVar)
Question 1: 保存代码似乎有效,但加载代码会产生此错误:
W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
如何解决这个问题?
Question 2: 我将此行包含在TF文档中的模式中...
tf.add_to_collection('modelVariables', myVar)
......但为什么这条线是必要的?默认情况下 expert_meta_graph
不导出整个图表?如果没有,那么在保存之前是否需要将图中的每个变量添加到集合中?或者我们只是将那些将在恢复后访问的变量添加到集合中?
---------------------- 2017年1月12日更新------------------------ -----
部分成功基于Kashyap的建议,但仍然存在一个谜 . 下面的代码有效但仅当我包含包含 tf.add_to_collection
和 tf.get_collection
的行时 . 如果没有这些行,'load'模式会在最后一行中抛出错误: NameError: name 'myVar' is not defined
. 我的理解是,默认情况下 Saver.save
保存并恢复图中的所有变量,那么为什么需要指定将在集合中使用的变量的名称?我认为这与将Tensorflow的变量名称映射到Python名称有关,但这里的游戏规则是什么?对于哪些变量需要这样做?
mode = 'load' # or 'save'
if mode == 'save':
myVar = tf.Variable(7.1)
init_op = tf.global_variables_initializer()
saver0 = tf.train.Saver()
tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
with tf.Session() as sess:
sess.run(init_op)
print sess.run(myVar)
saver0.save(sess, './myModel')
if mode == 'load':
with tf.Session() as sess:
saver1 = tf.train.import_meta_graph('./myModel.meta')
saver1.restore(sess, tf.train.latest_checkpoint('./'))
myVar = tf.get_collection('myVar')[0] ### WHY NECESSARY?
print sess.run(myVar)
2 回答
我一直试图弄清楚同样的事情,并且能够通过使用
Supervisor
成功地做到这一点 . 它会自动加载所有变量和图形等 . 这是文档 - https://www.tensorflow.org/programmers_guide/supervisor . 以下是我的代码 -如你所见,这比使用
Saver
对象和处理单个变量等简单得多,只要图形保持不变(我的理解是当我们想要为不同的图重用预先训练的模型时Saver
会很方便) .Question1
这个问题已经得到了彻底的回答here . 您不必显式调用
export_meta_graph
. 拨打save method
. 这也将生成.meta
文件(因为save方法将在内部调用export_meta_graph
方法 . )例如
saver0.save(sess, './myModel.ckpt')
将生成
myModel.ckpt
文件以及myModel.ckpt.meta
文件 .然后您可以使用恢复模型
Question2
集合用于存储自定义信息,例如学习率,您使用的正则化因子以及其他信息,这些信息将在您导出图形时存储 . Tensorflow本身定义了一些像"TRAINABLE_VARIABLES"这样的集合,用于获取您构建的模型的所有可训练变量 . 您可以选择导出图表中的所有集合,也可以指定要在
export_meta_graph
函数中导出的集合 .是tensorflow将导出您定义的所有变量 . 但是,如果您需要任何其他需要导出到图表的信息,则可以将它们添加到集合中 .