Home Articles

在pytorch中有theano.tensor.switch的模拟吗?

Asked
Viewed 1054 times
0

我想强制将矢量的所有元素归零,这些元素都低于某个阈值 . 而且我想这样做,以便我仍然可以通过非零传播渐变 .

例如,在theano我可以写:

B = theano.tensor.switch(A < .1, 0, A)

在pytorch中有解决方案吗?

2 Answers

  • 1

    从pytorch 0.4开始,您可以使用 torch.where 轻松完成(参见docMerged PR

    它就像Theano一样容易 . 看一个例子:

    import torch
    from torch.autograd import Variable
    
    x = Variable(torch.arange(0,4), requires_grad=True) # x     = [0 1 2 3]
    zeros = Variable(torch.zeros(*x.shape))             # zeros = [0 0 0 0]
    
    y = x**2                         # y = [0 1 4 9]
    z = torch.where(y < 5, zeros, y) # z = [0 0 0 9]
    
    # dz/dx = (dz/dy)(dy/dx) = (y < 5)(0) + (y ≥ 5)(2x) = 2x(x**2 ≥ 5) 
    z.backward(torch.Tensor([1.0])) 
    x.grad # (dz/dx) = [0 0 0 6]
    
  • 1

    我不认为在PyTorch中默认实现 switch . 但是,您可以通过extending the torch.autograd.Function在PyTorch中定义自己的函数

    所以,开关功能看起来像

    class switchFunction(Function):
        @staticmethod
        def forward(ctx, flag, value, tensor):
            ctx.save_for_backward(flag)
            tensor[flag] = value
            return tensor
    
        @staticmethod
        def backward(ctx, grad_output):
            flag, = ctx.saved_variables
            grad_output[flag] = 0
            return grad_output
    switch = switchFunction.apply
    

    现在,您只需将 switch 称为 switch(A < 0.1, 0, A)

    编辑

    实际上有一个功能可以做到这一点 . 它被称为Threshold . 你可以像使用它一样

    import torch.nn as nn
    m = nn.Threshold(0.1, 0)
    B = m(A)
    

Related