在pytorch中,我们可以给出一个打包序列作为RNN的输入 . 从official doc开始,RNN的输入可以如下 .
input(seq_len,batch,input_size):包含输入序列特征的张量 . 输入也可以是打包的可变长度序列 .
例
packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
outputs, hidden = self.rnn(packed, hidden)
outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs)
这里, embedded
是批输入的嵌入表示 .
我的问题是,如何对RNN中的打包序列进行计算?如何通过打包表示计算批量填充序列的隐藏状态?
1 回答
基于this relevent question的matthew_zeng的答案:未计算填充元素的输出,隐藏将是最后一次有效输入后的隐藏状态 .