首页 文章

如何将PyTorch的torch.inverse()函数应用于批处理中的每个样本?

提问于
浏览
5

这似乎是一个基本问题,但我无法解决这个问题 .

在我的神经网络的前向传递中,我有一个8x3x3形状的输出张量,其中8是我的批量大小 . 我们可以假设每个3x3张量是非奇异矩阵 . 我需要找到这些矩阵的逆 . PyTorch inverse()函数仅适用于方形矩阵 . 由于我现在有8x3x3,如何以可区分的方式将此功能应用于批处理中的每个矩阵?

如果我遍历样本并将反转追加到python列表,然后我将其转换为PyTorch张量,那么在backprop期间它是否会出现问题? (我问,因为将PyTorch张量转换为numpy以执行某些操作然后返回张量将不会在backprop期间为此类操作计算渐变)

当我尝试做类似的事情时,我也会收到以下错误 .

a = torch.arange(0,8).view(-1,2,2)
b = [m.inverse() for m in a]
c = torch.FloatTensor(b)

TypeError:'torch.FloatTensor'对象不支持索引

2 回答

  • 0

    EDIT:

    As of Pytorch version 1.0, torch.inverse now supports batches of tensors. See here. So you can simply use the built-in function torch.inverse

    老答复

    有计划很快实现批量反向 . 有关讨论,请参阅例如issue 7500issue 9102 . 但是,截至编写本文时,当前的稳定版本(0.4.1),没有批量反向操作可用 .

    话虽如此,最近又增加了对 torch.gesv 的批量支持 . 这可以(ab)用于沿着以下行定义您自己的批量逆操作:

    def b_inv(b_mat):
        eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
        b_inv, _ = torch.gesv(eye, b_mat)
        return b_inv
    

    我发现在GPU上运行时,这会在for循环中提供良好的加速 .

  • 4

    您可以使用 torch.functional.unbind() 拆分张量,对结果的每个元素应用inverse,然后堆叠回:

    a = torch.arange(0,8).view(-1,2,2)
    b = [t.inverse() for t in torch.functional.unbind(a)]
    c = torch.functional.stack(b)
    

相关问题