我想强制将矢量的所有元素归零,这些元素都低于某个阈值 . 而且我想这样做,以便我仍然可以通过非零传播渐变 .
例如,在theano我可以写:
B = theano.tensor.switch(A < .1, 0, A)
在pytorch中有解决方案吗?
从pytorch 0.4开始,您可以使用 torch.where 轻松完成(参见doc,Merged PR)
torch.where
它就像Theano一样容易 . 看一个例子:
import torch from torch.autograd import Variable x = Variable(torch.arange(0,4), requires_grad=True) # x = [0 1 2 3] zeros = Variable(torch.zeros(*x.shape)) # zeros = [0 0 0 0] y = x**2 # y = [0 1 4 9] z = torch.where(y < 5, zeros, y) # z = [0 0 0 9] # dz/dx = (dz/dy)(dy/dx) = (y < 5)(0) + (y ≥ 5)(2x) = 2x(x**2 ≥ 5) z.backward(torch.Tensor([1.0])) x.grad # (dz/dx) = [0 0 0 6]
我不认为在PyTorch中默认实现 switch . 但是,您可以通过extending the torch.autograd.Function在PyTorch中定义自己的函数
switch
所以,开关功能看起来像
class switchFunction(Function): @staticmethod def forward(ctx, flag, value, tensor): ctx.save_for_backward(flag) tensor[flag] = value return tensor @staticmethod def backward(ctx, grad_output): flag, = ctx.saved_variables grad_output[flag] = 0 return grad_output switch = switchFunction.apply
现在,您只需将 switch 称为 switch(A < 0.1, 0, A)
switch(A < 0.1, 0, A)
实际上有一个功能可以做到这一点 . 它被称为Threshold . 你可以像使用它一样
import torch.nn as nn m = nn.Threshold(0.1, 0) B = m(A)
2 回答
从pytorch 0.4开始,您可以使用
torch.where
轻松完成(参见doc,Merged PR)它就像Theano一样容易 . 看一个例子:
我不认为在PyTorch中默认实现
switch
. 但是,您可以通过extending the torch.autograd.Function在PyTorch中定义自己的函数所以,开关功能看起来像
现在,您只需将
switch
称为switch(A < 0.1, 0, A)
编辑
实际上有一个功能可以做到这一点 . 它被称为Threshold . 你可以像使用它一样