首页 文章

如何有效地计算Pytorch中的张量?

提问于
浏览
4

我有一个张量 xx.shape=(batch_size,10) ,现在我想采取

x[i][0] = x[i][0]*x[i][1]*...*x[i][9] for i in range(batch_size)

这是我的代码:

for i in range(batch_size):
    for k in range(1, 10):
        x[i][0] = x[i][0] * x[i][k]

但是当我在 forward() 中实现它并调用 loss.backward() 时,反向传播的速度非常慢 . 为什么它很慢,有没有办法有效地实现它?

2 回答

  • 4

    它很慢,因为你使用两个for循环 .

    你可以使用 .prod 参见:https://pytorch.org/docs/stable/torch.html#torch.prod

    在你的情况下,

    x = torch.prod(x, dim=1)x = x.prod(dim=1)

    应该管用

  • 1

    当您使用两个循环来计算产品时,复杂性为n ^ 2 . 想象一下,在反向传播过程中多次这样做,你的代码变慢了 .

    向量运算加速了这些计算 .

相关问题