首页 文章

将tensorflow dynamic_rnn的输出馈送到后续层

提问于
浏览
1

我已经开始在tensorflow中使用RNN,我得到了一般原则,但实现的某些方面并不十分清楚 .

我的理解:假设我正在训练一个序列到序列的网络,其中输入的大小与输出的大小相同(可能类似于在每个时间步长预测一段文本中的下一个字符) . 我的重复层使用LSTM单元,之后我想要一个完全连接的层来为预测添加更多深度 .

在静态RNN中,按照TF约定,您应该在时间维度上将输入数据取消堆叠,并将其作为列表提供给 static_rnn 方法,如下所示:

import tensorflow as tf

num_input_features = 32
num_output_features = 32

lstm_size = 128
max_seq_len = 5

# input/output:
x = tf.placeholder(tf.float32, [None, max_seq_len, num_input_features])

x_series = tf.unstack(x, axis=1) # a list of length max_seq_len

# recurrent layer:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
rnn_outputs, final_state = tf.nn.static_rnn(lstm_cell, x_series, dtype=tf.float32)

这将为您提供一个输出列表,每个时间步长一个 . 然后,如果您想在每一步对RNN的输出进行一些额外的计算,您可以对输出列表的每个元素执行此操作:

# output layer:

w = tf.Variable(tf.random_normal([lstm_size, num_output_features]))
b = tf.Variable(tf.random_normal([num_output_features]))

z_series = [tf.matmul(out, w) + b for out in rnn_outputs]
yhat_series = [tf.nn.tanh(z) for z in z_series]

然后我可以再次叠加 yhat_series 并将其与某些标签 y 进行比较,以获得我的成本函数 .

这里's what I don't get:在动态RNN中,输入到 dynamic_rnn 方法的输入是具有自己的时间维度的张量(默认情况下为1):

# input/output:
x = tf.placeholder(tf.float32, [None, max_seq_len, num_input_features])

# x_series = tf.unstack(x, axis=1) # dynamic RNN does not need this

# recurrent layer:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
dyn_rnn_outputs, dyn_final_state = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)

那么 dyn_rnn_output 不是一个列表,而是一个形状的张量(?,max_seq_len,lstm_size) . 什么's the best way to handle feeding this tensor to a subsequent dense layer? I can't将RNN输出乘以我的权重矩阵,并且对RNN输出进行取消堆叠感觉就像dynamic_rnn API旨在避免的尴尬 .

我错过了一个很好的方法吗?

1 回答

  • 1

    任何想要解决这个问题的人的更新:

    有一个tensorflow函数, tf.contrib.rnn.OutputProjectionWrapper ,似乎专门用于将一个密集层附加到RNN单元的输出,但是将其作为RNN单元本身的一部分包装起来,然后您可以通过调用 tf.nn.dynamic_rnn 来展开:

    lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    proj = tf.contrib.rnn.OutputProjectionWrapper(lstm_cell, num_output_features)
    dyn_rnn_outputs, dyn_final_state = tf.nn.dynamic_rnn(proj, x, dtype=tf.float32)
    

    但更一般地说,如果你想对RNN的输出进行操作,通常的做法似乎是通过展开批次和时间维度来重塑 rnn_outputs ,在张量上执行操作,并将它们重新打开以进行最终操作 . 输出 .

相关问题