首页 文章

PyTorch - 变量和张量之间的元素乘法?

提问于
浏览
7

截至PyTorch 0.4,这个问题已不再有效 . 在0.4 Tensor s和 Variable 被合并 .

如何在PyTorch中使用变量和张量执行逐元素乘法?有两个张量工作正常 . 变量和标量工作正常 . 但是当尝试使用变量和张量执行逐元素乘法时,我得到:

XXXXXXXXXXX in mul
    assert not torch.is_tensor(other)
AssertionError

例如,运行以下内容时:

import torch

x_tensor = torch.Tensor([[1, 2], [3, 4]])
y_tensor = torch.Tensor([[5, 6], [7, 8]])

x_variable = torch.autograd.Variable(x_tensor)

print(x_tensor * y_tensor)
print(x_variable * 2)
print(x_variable * y_tensor)

我希望第一个和最后一个打印语句显示类似的结果 . 前两个乘法按预期工作,误差在第三个中出现 . 我在PyTorch中尝试了 * 的别名(即 x_variable.mul(y_tensor)torch.mul(y_tensor, x_variable) 等) .

考虑到错误和产生它的代码,似乎不支持张量和变量之间的元素乘法 . 它是否正确?还是有什么我想念的?谢谢!

1 回答

  • 11

    是的,你是对的 . 元素乘法(与大多数其他操作一样)仅支持 Tensor * TensorVariable * Variable ,但 not 支持 Tensor * Variable .

    要执行上面的乘法运算,请将 Tensor 包装为 Variable ,它不需要渐变 . 额外的开销是微不足道的 .

    y_variable = torch.autograd.Variable(y_tensor, requires_grad=False)
    x_variable * y_variable # returns Variable
    

    但显然,只有使用 Variables ,如果你真的需要通过图表自动区分 . 另外,您可以像在问题中一样直接在 Tensors 上执行操作 .

相关问题