_ = importer.import_graph_def(input_graph_def, name='')
with session.Session() as sess:
if input_saver_def:
saver = saver_lib.Saver(saver_def=input_saver_def)
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(','),
variable_names_blacklist=variable_names_blacklist)
在上面的代码中,导入器用于将graphDef导入到当前默认图形,并且保护程序加载以前训练的值 . 问题是存储这些训练值的位置?在会话中,在input_graph_def中,在当前图形结构(tf.get_default_graph())中还是在保护程序中?
我检查方法 convert_variables_to_constants
的代码 . https://github.com/tensorflow/tensorflow/blob/235192d47cfb375c0cc93c1deefb9e440715bf35/tensorflow/python/framework/graph_util_impl.py
它使用sess.run(变量名)来获取加载的值 . 这个 sess.run
从哪里获取值?
1 回答
当我们定义保护程序时,我们应该将注释传递给它(默认情况下它是全局变量) .
这里我们首先初始化所有变量并将其保存为“./temp_model” . 要恢复变量:
我们可以将a和b保存到不同的地方:
并将它们恢复为图形: