
PyTorch: DataParallel error with an RNN model


I am trying to use torch.nn.DataParallel with an RNN model. My model looks like this:

import torch.nn as nn
import torch.nn.functional as F


class EncoderRNN(nn.Module):
    def __init__(self, vocal_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocal_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input_batch, input_batch_length, hidden):
        embedded = self.embedding(input_batch)
        # Pack the padded batch so the GRU skips the padding positions
        packed_input = nn.utils.rnn.pack_padded_sequence(embedded, input_batch_length.cpu().numpy(), batch_first=True)
        # output is a PackedSequence here, not a plain tensor
        output, hidden = self.gru(packed_input, hidden)
        return output, hidden


class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, target_batch, target_batch_length, hidden, train=False):
        embedded = self.embedding(target_batch)
        output = F.relu(embedded)

        # NOTE: packed_target is only defined when train=True
        if train:
            # minus 1 to eliminate <EOS>
            packed_target = nn.utils.rnn.pack_padded_sequence(output, (target_batch_length - 1).cpu().numpy(),
                                                              batch_first=True)

        output, hidden = self.gru(packed_target, hidden)
        # output[0] is the flat data tensor of the PackedSequence
        output = self.softmax(self.out(output[0]))
        return output, hidden

I wrap the models in DataParallel when declaring them:

encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

The code runs on a server with 4 GPUs, and I get the following error message:

/home/cjunjie/NLP/DocSummarization/model.py:18: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). 
  output, hidden = self.gru(packed_input, hidden)
Traceback (most recent call last):
  File "train.py", line 144, in <module>
    train_iteration(encoder, decoder, fileDataSet)
  File "train.py", line 110, in train_iteration
    target_indices, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
  File "train.py", line 41, in train
    encoder_output, encoder_hidden = encoder(input_batch, input_batch_length, encoder_hidden)
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 74, in forward
    return self.gather(outputs, self.output_device)
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 86, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 65, in gather
    return gather_map(outputs)
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 60, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 60, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/utils/rnn.py", line 39, in __new__
    return super(PackedSequence, cls).__new__(cls, *args[0])
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 57, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/cjunjie/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 58, in forward
    assert all(map(lambda i: i.is_cuda, inputs))
AssertionError

I searched for the same problem, but none of the results had a solution. Can anyone help?

1 Answer

  • 0

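    To run code on the GPU, you need to copy both the variables and the model weights to CUDA. I suspect you have not copied the model weights to CUDA. To do that, call: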

    encoder.cuda()
    decoder.cuda()
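
For context, here is a minimal sketch of where that call could sit relative to the DataParallel wrapper. It is not the asker's training code: the sizes, the sample batch, and the decision to drop the `hidden` argument are assumptions made for illustration. The traceback above fails in the gather step while reassembling a PackedSequence, so this sketch also unpacks inside `forward` (using the `total_length` argument of `pad_packed_sequence`) and returns plain tensors with the batch on dimension 0. Whether that resolves the asker's exact error is untested here.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input_batch, input_batch_length):
        # Each replica only sees a slice of the batch, so remember the padded
        # length and restore it after the GRU; all slices then have equal width.
        total_length = input_batch.size(1)
        embedded = self.embedding(input_batch)
        packed = pack_padded_sequence(
            embedded, input_batch_length.cpu(),
            batch_first=True, enforce_sorted=False)
        packed_output, hidden = self.gru(packed)  # initial hidden defaults to zeros
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=total_length)
        # hidden is (num_layers, batch, hidden_size); move batch to dim 0 so
        # DataParallel can concatenate replica outputs along that dimension.
        return output, hidden.transpose(0, 1).contiguous()


# Falls back to CPU so the sketch also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the weights to the device first, then wrap (assumed sizes).
encoder = EncoderRNN(vocab_size=50, hidden_size=8).to(device)
if torch.cuda.device_count() > 1:
    encoder = nn.DataParallel(encoder)

# Hypothetical batch: 4 sequences padded to length 6.
input_batch = torch.randint(0, 50, (4, 6), device=device)
input_batch_length = torch.tensor([6, 4, 3, 2], device=device)

output, hidden = encoder(input_batch, input_batch_length)
print(output.shape, hidden.shape)
```

The caller would need to transpose `hidden` back to `(num_layers, batch, hidden_size)` before passing it to another GRU, and the same pattern would apply to the decoder.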
    
