我遇到了以下示例,我不知道可以按如下方式提供RNN状态 .
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.initial_state = cell.zero_state(args.batch_size, tf.float32)
在这段代码中,初始状态被声明为归零状态 . 据我所知,这不是占位符 . 它只是零张量的一个元素 .
然后在使用RNN模型生成初始状态的函数中,在session.run中输入 .
def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
state = sess.run(self.cell.zero_state(1, tf.float32))
for char in prime[:-1]:
x = np.zeros((1, 1))
x[0, 0] = vocab[char]
feed = {self.input_data: x, self.initial_state:state}
[state] = sess.run([self.final_state], feed)
由于self.initial_state不是占位符,如何才能获得win session.run?
这是我正在查看的代码的link .
2 回答
请注意,您可以输入任何变量,而不仅仅是占位符 . 因此,在这种情况下,您可以手动输入元组的每个组件:
在阅读类似的RNN代码时,我遇到了与您相同的问题 .
根据我的理解,
rnn_cell.zero_state
实际上会返回一个可以输入的张量元组 . 您的占位符也是张量 .所以,如果你这样做:
并且饲料字典允许您只要是张量或张量阵列就可以进食 .