首页 文章

在同一个Python会话中保存和恢复Tensorflow图

提问于
浏览
0

有几个相关的问题,但似乎没有解决我的具体问题 .

我写了一些保存和恢复TensorFlow模型的代码 . 如果我保存模型并在后续的python运行中恢复模型,一切都还可以 . 但是,如果我尝试在同一个Python实例中保存和恢复模型,我会收到以下错误:

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Foo/X:0", shape=(?, 4), dtype=float32) is not an element of this graph.

据我所知,恢复后的变量“Foo / X”位于图中:

[n.name for n in tf.get_default_graph().as_graph_def().node]

我的代码的基本思想是使用相同的TensorFlow API调用创建/重新创建图形,然后使用tf.train.Saver() . restore()来恢复训练状态 . 一个给出相同错误的简化示例(在函数Barfoo的最后一行):

import numpy as np
import tensorflow as tf

def Foobar():
    global R1
    with tf.variable_scope('Foo'):
        X = tf.placeholder("float", [None, 4], name = 'X')
        Y = tf.placeholder("float", [None], name = 'Y')
        W = tf.Variable(tf.ones([4, 1]), name = 'W')
        YH = tf.matmul(X, W, name = 'YH')
        L = tf.reduce_sum(tf.nn.l2_loss(YH - Y), name = 'L')
        O = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'O').minimize(L)
        init = tf.global_variables_initializer()
    S.run(init)
    for i in range(32):
        l1, _ = S.run([L, O], feed_dict = {X: x, Y: y})
        print(str(l1))
    R1 = S.run(YH, feed_dict = {X: np.ones((1, 4))})
    saver = tf.train.Saver()
    saver.save(S, "TFModel/savemodel")

def Barfoo():  
    global R2
    with tf.variable_scope('Foo'):
        X = tf.placeholder("float", [None, 4], name = 'X')
        Y = tf.placeholder("float", [None], name = 'Y')
        W = tf.Variable(tf.ones([4, 1]), name = 'W')
        YH = tf.matmul(X, W, name = 'YH')
        L = tf.reduce_sum(tf.nn.l2_loss(YH - Y), name = 'L')
        O = tf.train.AdamOptimizer(learning_rate = 0.001, name = 'O').minimize(L)
    saver = tf.train.Saver()
    saver.restore(S, tf.train.latest_checkpoint('TFModel/'))
    print(str([n.name for n in tf.get_default_graph().as_graph_def().node]))
    R2 = S.run(YH, feed_dict = {X: np.ones((1, 4))})

x = np.random.rand(32, 4)
y = x.sum(axis = 1) + np.random.rand(32) / 10
S = tf.Session()
R1, R2 = None, None     
Foobar()
tf.reset_default_graph()
Barfoo()
print('R1: ' + str(R1))
print('R2: ' + str(R2)

Why does this code give an error on trying to use the variable X in Barfoo? 如果我先运行Foobar,终止程序,然后运行Barfoo,为什么会有效呢?

1 回答

  • 0

    以粗体回答问题,即:

    Why does this code give an error on trying to use the variable X in Barfoo?

    使用S.run,您需要计算要计算的张量的名称,以及带有张量名称作为键的feed_dict的字典 . 您试图将张量对象本身作为键传递,而不是将它们的名称传递给它 . 相比:

    错误)

    R2 = S.run(YH, feed_dict = {X: np.ones((1, 4))})
    

    对)

    R2 = S.run("Foo/YH:0", feed_dict={"Foo/X:0": np.ones((1, 4))})
    

    请注意,我通知了张量的名称,而不是张量器本身(正如您的版本中所做的那样) . 只需更改上面的行即可使代码正常工作 .

    关于第二个问题,请更清楚一点如何重现它,所以我可以检查发生了什么 .

相关问题