首页 文章

如何在PyTorch中正确实现批量输入LSTM网络?

提问于
浏览
13

这个release的PyTorch似乎为递归神经网络的可变长度输入提供了 PackedSequence . 但是,我发现正确使用它有点困难 .

使用 pad_packed_sequence 恢复由 pack_padded_sequence 馈送的RNN图层的输出,我们得到 T x B x N 张量 outputs ,其中 T 是最大时间步长, B 是批量大小, N 是隐藏大小 . 我发现对于批处理中的短序列,后续输出将全为零 .

这是我的问题 .

  • 对于单个输出任务,其中一个将需要所有序列的最后一个输出,简单的 outputs[-1] 将给出错误的结果,因为该张量包含许多短序列的零 . 需要按序列长度构建索引以获取所有序列的单个最后输出 . 有更简单的方法吗?

  • 对于多输出任务(例如seq2seq),通常会添加一个线性层 N x O 并将批输出 T x B x O 重新整形为 TB x O 并使用真实目标 TB (通常是语言模型中的整数)计算交叉熵损失 . 在这种情况下,批量输出中的这些零是否重要?

2 回答

  • 0

    Question 1 - Last Timestep

    这是我用来获取最后一个时间步的输出的代码 . 我不想知道这件事 . 我跟着这个discussion并 grab 了我的 last_timestep 方法的相关代码片段 . 这是我的前锋 .

    class BaselineRNN(nn.Module):
        def __init__(self, **kwargs):
            ...
    
        def last_timestep(self, unpacked, lengths):
            # Index of the last output for each sequence.
            idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                                   unpacked.size(2)).unsqueeze(1)
            return unpacked.gather(1, idx).squeeze()
    
        def forward(self, x, lengths):
            embs = self.embedding(x)
    
            # pack the batch
            packed = pack_padded_sequence(embs, list(lengths.data),
                                          batch_first=True)
    
            out_packed, (h, c) = self.rnn(packed)
    
            out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)
    
            # get the outputs from the last *non-masked* timestep for each sentence
            last_outputs = self.last_timestep(out_unpacked, lengths)
    
            # project to the classes using a linear layer
            logits = self.linear(last_outputs)
    
            return logits
    

    Question 2 - Masked Cross Entropy Loss

    是的,默认情况下,零填充时间步长(目标)很重要 . 但是,它很容易掩盖它们 . 您有两种选择,具体取决于您使用的PyTorch版本 .

    • PyTorch 0.2.0:现在pytorch支持使用 ignore_index 参数直接在CrossEntropyLoss中进行屏蔽 . 例如,在语言建模或seq2seq中,我添加零填充,我掩盖零填充的单词(目标),就像这样:

    loss_function = nn.CrossEntropyLoss(ignore_index = 0)

  • 7

    几天前,我发现this method使用索引来完成与单行相同的任务 .

    我有我的数据集批次( [batch size, sequence length, features] ),所以对我来说:

    unpacked_out = unpacked_out[np.arange(unpacked_out.shape[0]), lengths - 1, :]
    

    其中 unpacked_outtorch.nn.utils.rnn.pad_packed_sequence 的输出 .

    我将它与方法described here进行了比较,它看起来类似于Christos Baziotis在上面使用的 last_timestep() 方法(也推荐here),结果在我的情况下是相同的 .

相关问题