首页 文章

Tensorflow模型占位符恢复

提问于
浏览
1

我想在Tensorflow中恢复,修改和重用(相当)复杂的模型,但是在使用占位符时,如何正确地传递feed_dict有一些困难 . 代码如下:

input_dir = "parallel_win_10_40_conv_3l_rnn"
input_file = "parallel_win_10_40_conv_3l_rnn"
saver = tf.train.import_meta_graph("./result/cnn_rnn_parallel/tune_rnn_layer/"+input_dir+"/model_"+input_file+".meta")

# # Method 1
# all_placeholders = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
# cnn_in, rnn_in, Y = all_placeholders[0], all_placeholders[1], all_placeholders[2]
# keep_prob, phase_train = all_placeholders[3], all_placeholders[4]

# Method 2
cnn_in = tf.placeholder(tf.float32, shape=[None, input_height, input_width, input_channel_num], name='cnn_in')
rnn_in = tf.placeholder(tf.float32, shape=[None, n_time_step, n_input_ele], name='rnn_in')
Y = tf.placeholder(tf.float32, shape=[None, n_labels], name = 'Y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
phase_train = tf.placeholder(tf.bool, name='phase_train')

with tf.Session() as session:
    saver.restore(session, "./result/cnn_rnn_parallel/tune_rnn_layer/"+input_dir+"/model_"+input_file)

    test_cnn_batch = np.zeros(shape=[accuracy_batch_size], dtype=float)
    test_rnn_batch = np.zeros(shape=[accuracy_batch_size], dtype=float)

    offset = (accuracy_batch_size) % (test_y.shape[0] - accuracy_batch_size)
    test_cnn_batch = cnn_test_x[offset:(offset + accuracy_batch_size), :, :, :, :]
    test_cnn_batch = test_cnn_batch.reshape(len(test_cnn_batch) * window_size, input_height, input_width, 1)
    test_rnn_batch = rnn_test_x[offset:(offset + accuracy_batch_size), :, :]
    test_batch_y = test_y[offset:(offset + accuracy_batch_size), :]

    print(session.run('fin_m:0', feed_dict={cnn_in: test_cnn_batch, rnn_in: test_rnn_batch,
                                        Y: test_batch_y, keep_prob: 1.0, phase_train: False}))

当我使用方法1时,我收到一个错误:

TypeError:无法将feed_dict键解释为Tensor:无法将操作转换为Tensor .

当我使用方法2时,我得到一个不同的错误:

InvalidArgumentError(请参见上面的回溯):您必须使用dtype float为占位符张量'cnn_in'提供值

这两个错误都让我感到困惑,因为在保存模型之前占位符的定义完全相同,所以它们不应该具有相同的类型(Operation或Tensor)?对于第二种方法,test_cnn_batch是一个带有浮点值的ndarray . 我认为这可能是因为模型中的cnn_in是在saver = tf.train.import_meta_graph行中定义的(根据错误信息) . 我认为重新定义之后可能有所帮助,但没有骰子 .

这里发生了什么?这样做的正确方法是什么?我已经阅读了许多相关的问题,但它们没有直接解决这些问题 .

任何帮助表示赞赏 .

1 回答

  • 0

    您的方法1是错误的,因为您需要获取 Tensors ,而不是定义占位符的操作 . 由于你循环了 get_operations() 的结果,你得到的是操作,而不是张量 .

    我们的方法2也是错误的,因为你没有得到图中的占位符,你定义的新的那些没有连接到你的计算图的其余部分 .

    您必须做的是找到占位符的 names ,然后按名称从图中获取它们 . 错误代码已经显示了您需要的名称之一:

    InvalidArgumentError(请参见上面的回溯):您必须使用dtype float为占位符张量'cnn_in'提供值

    然后你可以这样做:

    cnn_in = tf.get_default_graph().get_tensor_by_name('cnn_in')
    

    Note :您可能需要在张量名称后附加 :0

    从图表中为所需的所有占位符重复相同的过程 .

相关问题