首页 文章

如何在PyTorch中使用LSTM进行强化学习?

提问于
浏览
2

由于观察不能揭示整个状态,我需要使用循环神经网络进行强化,以便网络对过去发生的事情有某种记忆 . 为简单起见,我们假设我们使用LSTM .

现在内置的PyTorch LSTM需要你输入一个形状为 Time x MiniBatch x Input D 的输入,并输出一个形状张量 Time x MiniBatch x Output D .

然而,在强化学习中,为了知道 t+1 时的输入,我需要知道 t 时的输出,因为我在环境中进行操作 .

那么是否可以使用内置的PyTorch LSTM在强化学习环境中进行BPTT?如果是的话,我怎么能这样做?

1 回答

  • 2

    也许您可以将循环中的输入序列提供给LSTM . 像这样的东西:

    h, c = Variable(torch.zeros()), Variable(torch.zeros())
    for i in range(T):
        input = Variable(...)
        _, (h, c) = lstm(input, (h,c))
    

    每个时间步都可以使用(h,c)和输入来评估操作 . 只要你不破坏计算图,你可以反向传播,因为变量保留了所有的历史记录 .

相关问题